Bläddra i källkod

modify: 代码优化

yq 3 dagar sedan
förälder
incheckning
f9be3dfca4
4 ändrade filer med 8 tillägg och 8 borttagningar
  1. 2 2
      __init__.py
  2. 2 2
      ol_test.py
  3. 2 2
      online_learning/__init__.py
  4. 2 2
      online_learning/trainer_lr.py

+ 2 - 2
__init__.py

@@ -10,7 +10,7 @@ from os.path import dirname, realpath
 
 sys.path.append(dirname(realpath(__file__)))
 
-from online_learning import OnlineLearningTrainer
+from online_learning import OnlineLearningTrainerLr
 from pipeline import Pipeline
 from data import DataLoaderMysql
 from entitys import DbConfigEntity, DataSplitEntity
@@ -18,4 +18,4 @@ from monitor import MonitorMetric
 from metrics import MetricBase
 
 __all__ = ['MonitorMetric', 'MetricBase', 'DataLoaderMysql', 'DbConfigEntity',
-           'DataSplitEntity', 'Pipeline', 'OnlineLearningTrainer']
+           'DataSplitEntity', 'Pipeline', 'OnlineLearningTrainerLr']

+ 2 - 2
ol_test.py

@@ -7,7 +7,7 @@
 import time
 
 from entitys import DataSplitEntity
-from online_learning import OnlineLearningTrainer
+from online_learning import OnlineLearningTrainerLr
 
 
 if __name__ == "__main__":
@@ -47,7 +47,7 @@ if __name__ == "__main__":
     }
 
     # 训练并生成报告
-    trainer = OnlineLearningTrainer(data=data, **cfg)
+    trainer = OnlineLearningTrainerLr(data=data, **cfg)
     trainer.train()
     trainer.report()
 

+ 2 - 2
online_learning/__init__.py

@@ -5,6 +5,6 @@
 @desc: 
 """
 
-from .trainer import OnlineLearningTrainer
+from .trainer_lr import OnlineLearningTrainerLr
 
-__all__ = ['OnlineLearningTrainer']
+__all__ = ['OnlineLearningTrainerLr']

+ 2 - 2
online_learning/trainer.py → online_learning/trainer_lr.py

@@ -33,7 +33,7 @@ from .utils import LR
 init()
 
 
-class OnlineLearningTrainer:
+class OnlineLearningTrainerLr:
     def __init__(self, data: DataSplitEntity = None, ol_config: OnlineLearningConfigEntity = None, *args, **kwargs):
         if ol_config is not None:
             self._ol_config = ol_config
@@ -313,7 +313,7 @@ class OnlineLearningTrainer:
     def load(path: str):
         ol_config = OnlineLearningConfigEntity.from_config(path)
         ol_config._path_resources = path
-        return OnlineLearningTrainer(ol_config=ol_config)
+        return OnlineLearningTrainerLr(ol_config=ol_config)
 
     def report(self, epoch: int = None):
         self._model_optimized = self._f_get_best_model(self._df_param_optimized, epoch)