瀏覽代碼

modify: 在线学习代码优化

yq 4 天之前
父節點
當前提交
5f36d1e476
共有 1 個文件被更改,包括 7 次插入8 次删除
  1. 7 8
      online_learning/trainer.py

+ 7 - 8
online_learning/trainer.py

@@ -43,6 +43,7 @@ class OnlineLearningTrainer:
         self._columns = None
         self._model_original: LR
         self._model_optimized: LR
+        self._df_param_optimized = None
         self.sc_woebin = None
         self.card_cfg = None
         self.card = None
@@ -262,10 +263,11 @@ class OnlineLearningTrainer:
         optimizer = optim.Adam(self._model_optimized.parameters(), lr=self._ol_config.lr)
 
         df_param_columns = self._columns + ["auc_test", "ks_test", "epoch", "loss_train", "loss_test"]
-        df_param = pd.DataFrame(columns=df_param_columns)
+        self._df_param_optimized = pd.DataFrame(columns=df_param_columns)
+        
         # 优化前
         loss_train = 0
-        df_param.loc[len(df_param)] = _get_param_optimized(self._model_original, -1)
+        self._df_param_optimized.loc[len(self._df_param_optimized)] = _get_param_optimized(self._model_original, -1)
         for epoch in tqdm(range(epochs)):
             data_len = len(train_x)
             for i in range(math.ceil(data_len / batch_size)):
@@ -279,9 +281,7 @@ class OnlineLearningTrainer:
                 optimizer.step()
                 loss_train = loss.detach().item()
             # 测试集评估
-            df_param.loc[len(df_param)] = _get_param_optimized(self._model_optimized, epoch)
-
-        context.set(ContextEnum.PARAM_OPTIMIZED, df_param)
+            self._df_param_optimized.loc[len(self._df_param_optimized)] = _get_param_optimized(self._model_optimized, epoch)
 
     def save(self):
 
@@ -316,13 +316,12 @@ class OnlineLearningTrainer:
         return OnlineLearningTrainer(ol_config=ol_config)
 
     def report(self, epoch: int = None):
-        df_param = context.get(ContextEnum.PARAM_OPTIMIZED)
-        self._model_optimized = self._f_get_best_model(df_param, epoch)
+        self._model_optimized = self._f_get_best_model(self._df_param_optimized, epoch)
 
         if self._ol_config.jupyter_print:
             from IPython import display
             f_display_title(display, "模型系数优化过程")
-            display.display(df_param)
+            display.display(self._df_param_optimized)
 
         metric_value_dict = {}