@@ -44,6 +44,7 @@ class OnlineLearningTrainerXgb:
         self._model_optimized_list = []
         self._pipeline_original: PMMLPipeline
         self._pipeline_optimized: PMMLPipeline
+        self.model_optimized: xgb.XGBClassifier

         # Report template
         self._template_path = os.path.join(dirname(dirname(realpath(__file__))),
@@ -61,6 +62,12 @@ class OnlineLearningTrainerXgb:
             self._pipeline_original = joblib.load(path_model)
             self._pipeline_optimized = joblib.load(path_model)
             print(f"model load from【{path_model}】success.")
+            path_model = os.path.join(path, FileEnum.MODEL_XGB.value)
+            if os.path.isfile(path_model):
+                model = xgb.XGBClassifier()
+                model.load_model(path_model)
+                self._pipeline_optimized.steps[-1] = ("classifier", model)
+                print(f"model load from【{path_model}】success.")

     def _f_rewrite_pmml(self, path_pmml: str):
         with open(path_pmml, mode="r", encoding="utf-8") as f:
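A minimal standalone sketch of the restore step this hunk adds: if a separately saved XGBoost model file is found next to the pickled pipeline, the classifier is rebuilt from it with `load_model` and swapped back in as the pipeline's final step. The helper name and file names below are illustrative placeholders, not the project's actual `FileEnum` values.

```python
# Illustrative restore-on-load helper; "model.pkl" and "model_xgb.json"
# are placeholder file names, not the project's FileEnum values.
import os

import joblib
import xgboost as xgb


def load_pipeline_with_xgb(path: str):
    # The joblib pickle keeps the preprocessing steps, but after an
    # incremental update its classifier step may not round-trip cleanly.
    pipeline = joblib.load(os.path.join(path, "model.pkl"))

    # If the booster was saved separately in XGBoost's native format,
    # rebuild the classifier from it and put it back as the last step.
    path_model_xgb = os.path.join(path, "model_xgb.json")
    if os.path.isfile(path_model_xgb):
        model = xgb.XGBClassifier()
        model.load_model(path_model_xgb)
        pipeline.steps[-1] = ("classifier", model)
    return pipeline
```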
@@ -184,7 +191,7 @@ class OnlineLearningTrainerXgb:

         model_original: xgb.XGBClassifier = self._pipeline_original.steps[-1][1]
         ntree = model_original.n_estimators if model_original.best_ntree_limit is None else model_original.best_ntree_limit
-        model_optimized = xgb.XGBClassifier(
+        self.model_optimized = xgb.XGBClassifier(
             n_estimators=n_estimators if n_estimators else ntree,
             updater="refresh",
             process_type="update",
@@ -192,7 +199,7 @@ class OnlineLearningTrainerXgb:
             learning_rate=self._ol_config.lr,
             random_state=self._ol_config.random_state,
         )
-        self._pipeline_optimized.steps[-1] = ("classifier", model_optimized)
+        self._pipeline_optimized.steps[-1] = ("classifier", self.model_optimized)
         with silent_print():
             self._pipeline_optimized.fit(train_data, train_data[y_column],
                                          classifier__verbose=False,
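For context on the `process_type="update"` / `updater="refresh"` pair: XGBoost keeps the existing tree structures and only re-estimates their node statistics and leaf values on the new batch, which is what makes this an incremental refresh rather than a retrain. Below is a hedged, self-contained sketch on plain numpy arrays, outside the PMMLPipeline; the data, sizes and 0.1 learning rate are made up. In the plain sklearn API the model being refreshed is supplied to `fit` via `xgb_model`, a detail the hunk above does not show.

```python
# Hedged sketch of refresh-based incremental learning with synthetic data;
# array sizes and the 0.1 learning rate are illustrative only.
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X_old, y_old = rng.normal(size=(500, 5)), rng.integers(0, 2, size=500)
X_new, y_new = rng.normal(size=(200, 5)), rng.integers(0, 2, size=200)

# Model originally trained on the old batch.
model_original = xgb.XGBClassifier(n_estimators=50)
model_original.fit(X_old, y_old)

# process_type="update" + updater="refresh" keeps the 50 existing trees and
# only refreshes their node statistics/leaf values on the new batch.
model_refreshed = xgb.XGBClassifier(
    n_estimators=50,
    process_type="update",
    updater="refresh",
    refresh_leaf=True,
    learning_rate=0.1,
)
# The model being refreshed is passed explicitly via xgb_model.
model_refreshed.fit(X_new, y_new, xgb_model=model_original.get_booster())
```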
@@ -237,6 +244,10 @@ class OnlineLearningTrainerXgb:
         path_model = self._ol_config.f_get_save_path(FileEnum.MODEL.value)
         joblib.dump(self._pipeline_optimized, path_model)
         print(f"model save to【{path_model}】success. ")
+        # Saving the pipeline directly fails under xgb incremental learning, so the xgb model is saved separately here and restored on load
+        path_model = self._ol_config.f_get_save_path(FileEnum.MODEL_XGB.value)
+        self.model_optimized.save_model(path_model)
+        print(f"model save to【{path_model}】success. ")

     @staticmethod
     def load(path: str):
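And the companion save side: the fitted pipeline still goes through joblib, while the refreshed booster is additionally written in XGBoost's native format so the load path above can swap it back in. A short sketch with the same placeholder file names:

```python
# Companion to the load sketch above; file names are placeholders.
import os

import joblib


def save_pipeline_with_xgb(pipeline, model_optimized, path: str):
    # Pickle the whole pipeline for the preprocessing steps...
    joblib.dump(pipeline, os.path.join(path, "model.pkl"))
    # ...and also save the refreshed booster in XGBoost's own format, since
    # the pickled classifier is reported to break after an incremental update.
    model_optimized.save_model(os.path.join(path, "model_xgb.json"))
```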