|
@@ -42,6 +42,7 @@ class ModelXgb(ModelBase):
|
|
self._template_path = os.path.join(dirname(dirname(realpath(__file__))), "./template/模型开发报告模板_xgb.docx")
|
|
self._template_path = os.path.join(dirname(dirname(realpath(__file__))), "./template/模型开发报告模板_xgb.docx")
|
|
self.pipeline: PMMLPipeline
|
|
self.pipeline: PMMLPipeline
|
|
self.model = xgb.XGBClassifier
|
|
self.model = xgb.XGBClassifier
|
|
|
|
+ self._test_case = None
|
|
|
|
|
|
def _f_rewrite_pmml(self, path_pmml: str):
|
|
def _f_rewrite_pmml(self, path_pmml: str):
|
|
with open(path_pmml, mode="r", encoding="utf-8") as f:
|
|
with open(path_pmml, mode="r", encoding="utf-8") as f:
|
|
@@ -170,6 +171,10 @@ class ModelXgb(ModelBase):
|
|
joblib.dump(self.pipeline, path_model)
|
|
joblib.dump(self.pipeline, path_model)
|
|
print(f"model save to【{path_model}】success. ")
|
|
print(f"model save to【{path_model}】success. ")
|
|
|
|
|
|
|
|
+ path = self.ml_config.f_get_save_path(FileEnum.TEST_CASE.value)
|
|
|
|
+ self._test_case.to_csv(path, encoding="utf-8")
|
|
|
|
+ print(f"test case save to【{path}】success. ")
|
|
|
|
+
|
|
def model_load(self, path: str, *args, **kwargs):
|
|
def model_load(self, path: str, *args, **kwargs):
|
|
if not os.path.isdir(path):
|
|
if not os.path.isdir(path):
|
|
raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"【{path}】不是文件夹")
|
|
raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"【{path}】不是文件夹")
|
|
@@ -265,6 +270,11 @@ class ModelXgb(ModelBase):
|
|
if self.ml_config.jupyter_print:
|
|
if self.ml_config.jupyter_print:
|
|
self.jupyter_print(metric_value_dict)
|
|
self.jupyter_print(metric_value_dict)
|
|
|
|
|
|
|
|
+ # 测试案例
|
|
|
|
+ self._test_case = data.test_data.copy()
|
|
|
|
+ test_score = self.prob(test_data)
|
|
|
|
+ self._test_case["score"] = test_score
|
|
|
|
+
|
|
return metric_value_dict
|
|
return metric_value_dict
|
|
|
|
|
|
def jupyter_print(self, metric_value_dict=Dict[str, MetricFucResultEntity], *args, **kwargs):
|
|
def jupyter_print(self, metric_value_dict=Dict[str, MetricFucResultEntity], *args, **kwargs):
|