Browse Source

add: test_case

yq 5 hours ago
parent
commit
f259a15096
3 changed files with 24 additions and 0 deletions
  1. 1 0
      enums/file_enum.py
  2. 13 0
      model/model_lr.py
  3. 10 0
      model/model_xgb.py

+ 1 - 0
enums/file_enum.py

@@ -19,6 +19,7 @@ class FileEnum(Enum):
     MODEL = "model.pkl"
     MODEL_XGB = "xgb.bin"
     PMML = "model.pmml"
+    TEST_CASE = "test_case.csv"
 
 
 

+ 13 - 0
model/model_lr.py

@@ -34,6 +34,7 @@ class ModelLr(ModelBase):
         self.card = None
         self.card_cfg = None
         self.coef = None
+        self._test_case = None
 
     def get_report_template_path(self):
         return self._template_path
@@ -106,6 +107,10 @@ class ModelLr(ModelBase):
         df_var_mapping.to_csv(path, encoding="utf-8")
         print(f"model save to【{path}】success. ")
 
+        path = self.ml_config.f_get_save_path(FileEnum.TEST_CASE.value)
+        self._test_case.to_csv(path, encoding="utf-8")
+        print(f"test case save to【{path}】success. ")
+
     def model_load(self, path: str, *args, **kwargs):
         if not os.path.isdir(path):
             raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"【{path}】不是文件夹")
@@ -227,6 +232,14 @@ class ModelLr(ModelBase):
         if self.ml_config.jupyter_print:
             self.jupyter_print(metric_value_dict)
 
+        # 测试案例
+        self._test_case = data.test_data.copy()
+        if len(self.ml_config.rules) != 0:
+            test_score = self.score_rule(self._test_case)
+        else:
+            test_score = self.score(self._test_case)
+        self._test_case["score"] = test_score
+
         return metric_value_dict
 
     def jupyter_print(self, metric_value_dict=Dict[str, MetricFucResultEntity], *args, **kwargs):

+ 10 - 0
model/model_xgb.py

@@ -42,6 +42,7 @@ class ModelXgb(ModelBase):
         self._template_path = os.path.join(dirname(dirname(realpath(__file__))), "./template/模型开发报告模板_xgb.docx")
         self.pipeline: PMMLPipeline
         self.model = xgb.XGBClassifier
+        self._test_case = None
 
     def _f_rewrite_pmml(self, path_pmml: str):
         with open(path_pmml, mode="r", encoding="utf-8") as f:
@@ -170,6 +171,10 @@ class ModelXgb(ModelBase):
         joblib.dump(self.pipeline, path_model)
         print(f"model save to【{path_model}】success. ")
 
+        path = self.ml_config.f_get_save_path(FileEnum.TEST_CASE.value)
+        self._test_case.to_csv(path, encoding="utf-8")
+        print(f"test case save to【{path}】success. ")
+
     def model_load(self, path: str, *args, **kwargs):
         if not os.path.isdir(path):
             raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"【{path}】不是文件夹")
@@ -265,6 +270,11 @@ class ModelXgb(ModelBase):
         if self.ml_config.jupyter_print:
             self.jupyter_print(metric_value_dict)
 
+        # 测试案例
+        self._test_case = data.test_data.copy()
+        test_score = self.prob(test_data)
+        self._test_case["score"] = test_score
+
         return metric_value_dict
 
     def jupyter_print(self, metric_value_dict=Dict[str, MetricFucResultEntity], *args, **kwargs):