
bugfix: xgb online learning model save issue

yq 19 hours ago
Parent
Commit
ed1e60a7d0
2 changed files with 14 additions and 2 deletions
  1. enums/file_enum.py: 1 addition, 0 deletions
  2. online_learning/trainer_xgb.py: 13 additions, 2 deletions

+ 1 - 0
enums/file_enum.py

@@ -16,6 +16,7 @@ class FileEnum(Enum):
     CARD_CFG = "card.cfg"
     COEF = "coef.json"
     MODEL = "model.pkl"
+    MODEL_XGB = "xgb.bin"
     PMML = "model.pmml"
 
 

+ 13 - 2
online_learning/trainer_xgb.py

@@ -44,6 +44,7 @@ class OnlineLearningTrainerXgb:
         self._model_optimized_list = []
         self._pipeline_original: PMMLPipeline
         self._pipeline_optimized: PMMLPipeline
+        self.model_optimized: xgb.XGBClassifier
 
         # Report template
         self._template_path = os.path.join(dirname(dirname(realpath(__file__))),
@@ -61,6 +62,12 @@ class OnlineLearningTrainerXgb:
         self._pipeline_original = joblib.load(path_model)
         self._pipeline_optimized = joblib.load(path_model)
         print(f"model load from【{path_model}】success.")
+        path_model = os.path.join(path, FileEnum.MODEL_XGB.value)
+        if os.path.isfile(path_model):
+            model = xgb.XGBClassifier()
+            model.load_model(path_model)
+            self._pipeline_optimized.steps[-1] = ("classifier", model)
+        print(f"model load from【{path_model}】success.")
 
     def _f_rewrite_pmml(self, path_pmml: str):
         with open(path_pmml, mode="r", encoding="utf-8") as f:
@@ -184,7 +191,7 @@ class OnlineLearningTrainerXgb:
 
         model_original: xgb.XGBClassifier = self._pipeline_original.steps[-1][1]
         ntree = model_original.n_estimators if model_original.best_ntree_limit is None else model_original.best_ntree_limit
-        model_optimized = xgb.XGBClassifier(
+        self.model_optimized = xgb.XGBClassifier(
             n_estimators=n_estimators if n_estimators else ntree,
             updater="refresh",
             process_type="update",
@@ -192,7 +199,7 @@ class OnlineLearningTrainerXgb:
             learning_rate=self._ol_config.lr,
             random_state=self._ol_config.random_state,
         )
-        self._pipeline_optimized.steps[-1] = ("classifier", model_optimized)
+        self._pipeline_optimized.steps[-1] = ("classifier", self.model_optimized)
         with silent_print():
             self._pipeline_optimized.fit(train_data, train_data[y_column],
                                          classifier__verbose=False,
@@ -237,6 +244,10 @@ class OnlineLearningTrainerXgb:
         path_model = self._ol_config.f_get_save_path(FileEnum.MODEL.value)
         joblib.dump(self._pipeline_optimized, path_model)
         print(f"model save to【{path_model}】success. ")
+        # Under xgb incremental learning, saving the pipeline directly raises an error, so the xgb model is saved separately here and restored when loading.
+        path_model = self._ol_config.f_get_save_path(FileEnum.MODEL_XGB.value)
+        self.model_optimized.save_model(path_model)
+        print(f"model save to【{path_model}】success. ")
 
     @staticmethod
     def load(path: str):
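
For context, below is a minimal sketch of the save/restore pattern this commit introduces, assuming a fitted PMMLPipeline whose last step is ("classifier", xgb.XGBClassifier). The helper names save_pipeline/load_pipeline and the directory layout are illustrative only, not part of the project code.

# Illustrative sketch, not project code: persist the pipeline with joblib and
# the xgb classifier in xgboost's native format, then swap it back in on load.
import os

import joblib
import xgboost as xgb

MODEL_PKL = "model.pkl"   # pickled pipeline (FileEnum.MODEL)
MODEL_XGB = "xgb.bin"     # native xgboost dump (FileEnum.MODEL_XGB)


def save_pipeline(pipeline, path: str) -> None:
    """Save the pipeline and, separately, the xgb classifier it wraps."""
    joblib.dump(pipeline, os.path.join(path, MODEL_PKL))
    classifier: xgb.XGBClassifier = pipeline.steps[-1][1]
    # The native format keeps the booster intact after incremental updates,
    # which is what reportedly breaks when pickling the whole pipeline.
    classifier.save_model(os.path.join(path, MODEL_XGB))


def load_pipeline(path: str):
    """Load the pickled pipeline and restore the natively saved classifier."""
    pipeline = joblib.load(os.path.join(path, MODEL_PKL))
    path_xgb = os.path.join(path, MODEL_XGB)
    if os.path.isfile(path_xgb):
        model = xgb.XGBClassifier()
        model.load_model(path_xgb)
        pipeline.steps[-1] = ("classifier", model)
    return pipeline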