Browse Source

modify: 模型绑定策略

yq 3 days ago
parent
commit
bdc836283d
2 changed files with 11 additions and 4 deletions
  1. 6 2
      entitys/ml_config_entity.py
  2. 5 2
      pipeline/pipeline.py

+ 6 - 2
entitys/ml_config_entity.py

@@ -10,7 +10,7 @@ from typing import List, Union
 
 
 from commom import GeneralException, f_get_datetime
 from commom import GeneralException, f_get_datetime
 from config import BaseConfig
 from config import BaseConfig
-from enums import ResultCodesEnum, FileEnum
+from enums import ResultCodesEnum, FileEnum, ModelEnum, FeatureStrategyEnum
 from init import warning_ignore
 from init import warning_ignore
 
 
 
 
@@ -138,7 +138,11 @@ class MlConfigEntity():
 
 
     @property
     @property
     def feature_strategy(self):
     def feature_strategy(self):
-        return self._feature_strategy
+        if ModelEnum.LR.value == self._model_type:
+            return FeatureStrategyEnum.WOE.value
+
+        if ModelEnum.XGB.value == self._model_type:
+            return FeatureStrategyEnum.NORM.value
 
 
     @property
     @property
     def params_xgb(self):
     def params_xgb(self):

+ 5 - 2
pipeline/pipeline.py

@@ -24,10 +24,13 @@ class Pipeline():
             self._ml_config = ml_config
             self._ml_config = ml_config
         else:
         else:
             self._ml_config = MlConfigEntity(*args, **kwargs)
             self._ml_config = MlConfigEntity(*args, **kwargs)
-        feature_strategy_clazz = FeatureStrategyFactory.get_strategy(self._ml_config.feature_strategy)
-        self._feature_strategy: FeatureStrategyBase = feature_strategy_clazz(self._ml_config)
+
         model_clazz = ModelFactory.get_model(self._ml_config.model_type)
         model_clazz = ModelFactory.get_model(self._ml_config.model_type)
         self._model: ModelBase = model_clazz(self._ml_config)
         self._model: ModelBase = model_clazz(self._ml_config)
+
+        feature_strategy_clazz = FeatureStrategyFactory.get_strategy(self._ml_config.feature_strategy)
+        self._feature_strategy: FeatureStrategyBase = feature_strategy_clazz(self._ml_config)
+
         self._data = data
         self._data = data
 
 
     def train(self, ):
     def train(self, ):