Kaynağa Gözat

add: 在线学习 pmml模型结果一致率

yq 13 saat önce
ebeveyn
işleme
bfc20aa2e1

+ 7 - 0
entitys/ol_config_entity.py

@@ -25,6 +25,7 @@ class OnlineLearningConfigEntity():
                  epochs: int = 50,
                  columns_anns: dict = {},
                  jupyter_print=False,
+                 save_pmml=True,
                  stress_test=False,
                  stress_sample_times=100,
                  stress_bad_rate_list: List[float] = [],
@@ -50,6 +51,8 @@ class OnlineLearningConfigEntity():
         # jupyter下输出内容
         self._jupyter_print = jupyter_print
 
+        self._save_pmml = save_pmml
+
         # 是否开启下输出内容
         self._stress_test = stress_test
 
@@ -101,6 +104,10 @@ class OnlineLearningConfigEntity():
     def jupyter_print(self):
         return self._jupyter_print
 
+    @property
+    def save_pmml(self):
+        return self._save_pmml
+
     @property
     def stress_test(self):
         return self._stress_test

+ 1 - 1
model/model_xgb.py

@@ -130,8 +130,8 @@ class ModelXgb(ModelBase):
             path_pmml = self.ml_config.f_get_save_path(FileEnum.PMML.value)
             # pipeline = make_pmml_pipeline(self.model)
             sklearn2pmml(self.pipeline, path_pmml, with_repr=True, )
-            print(f"model save to【{path_pmml}】success. ")
             self._f_rewrite_pmml(path_pmml)
+            print(f"model save to【{path_pmml}】success. ")
             # pmml与原生模型结果一致性校验
             model_pmml = Model.fromFile(path_pmml)
             prob_pmml = model_pmml.predict(data.data)["probability(1)"]

+ 29 - 2
online_learning/trainer_xgb.py

@@ -12,7 +12,8 @@ import joblib
 import pandas as pd
 import scorecardpy as sc
 import xgboost as xgb
-from sklearn2pmml import PMMLPipeline
+from pypmml import Model
+from sklearn2pmml import PMMLPipeline, sklearn2pmml
 from tqdm import tqdm
 
 from commom import GeneralException, f_image_crop_white_borders, f_df_to_image, f_display_title, \
@@ -61,15 +62,41 @@ class OnlineLearningTrainerXgb:
         self._pipeline_optimized = joblib.load(path_model)
         print(f"model load from【{path_model}】success.")
 
+    def _f_rewrite_pmml(self, path_pmml: str):
+        with open(path_pmml, mode="r", encoding="utf-8") as f:
+            pmml = f.read()
+            pmml = pmml.replace('optype="categorical" dataType="double"', 'optype="categorical" dataType="string"')
+        with open(path_pmml, mode="w", encoding="utf-8") as f:
+            f.write(pmml)
+            f.flush()
+
     def _f_get_best_model(self, df_param: pd.DataFrame, ntree: int = None):
         if ntree is None:
             df_param_sort = df_param.sort_values(by=["ks_test", "auc_test"], ascending=[False, False])
             print(f"选择最佳参数:\n{df_param_sort.iloc[0].to_dict()}")
-            self._train(df_param_sort.iloc[0][2])
+            self._train(int(df_param_sort.iloc[0][2]))
         else:
             print(f"选择ntree:【{ntree}】的参数:\n{df_param[df_param['ntree'] == ntree].iloc[0].to_dict()}")
             self._train(ntree)
 
+        if self._ol_config.save_pmml:
+            data = self._data.data
+            path_pmml = self._ol_config.f_get_save_path(FileEnum.PMML.value)
+            # pipeline = make_pmml_pipeline(self.model)
+            sklearn2pmml(self._pipeline_optimized, path_pmml, with_repr=True, )
+            self._f_rewrite_pmml(path_pmml)
+            print(f"model save to【{path_pmml}】success. ")
+            # pmml与原生模型结果一致性校验
+            model_pmml = Model.fromFile(path_pmml)
+            prob_pmml = model_pmml.predict(data)["probability(1)"]
+            prob_pipeline = self._pipeline_optimized.predict_proba(data)[:, 1]
+            diff = pd.DataFrame()
+            diff["prob_pmml"] = prob_pmml
+            diff["prob_pipeline"] = prob_pipeline
+            diff["diff"] = diff["prob_pmml"] - diff["prob_pipeline"]
+            diff["diff_format"] = diff["diff"].apply(lambda x: 1 if abs(x) < 0.001 else 0)
+            print(f"pmml模型结果一致率(误差小于0.001):{len(diff) / diff['diff_format'].sum().round(3) * 100}%")
+
     def _f_get_metric_auc_ks(self, model_type: str):
         def _get_auc_ks(data, title):
             y = data[self._ol_config.y_column]