Переглянути джерело

modify: 修改代码交互

yq 4 місяців тому
батько
коміт
a0d8561b44

+ 2 - 15
entitys/train_config_entity.py

@@ -8,26 +8,13 @@ import json
 import os
 
 from commom import GeneralException
-from enums import ResultCodesEnum, ModelEnum
+from enums import ResultCodesEnum
 
 
 class TrainConfigEntity():
-    def __init__(self, model_type=str, lr: float = None):
-        # 模型类型
-        self._model_type = model_type
+    def __init__(self, lr: float = None):
         # 学习率
         self._lr = lr
-        # 报告模板
-        if model_type == ModelEnum.LR.value:
-            self._template_path = "./template/模型开发报告模板_lr.docx"
-
-    @property
-    def template_path(self):
-        return self._template_path
-
-    @property
-    def model_type(self):
-        return self._model_type
 
     @property
     def lr(self):

+ 5 - 2
feature/filter_strategy_base.py

@@ -12,8 +12,11 @@ from entitys import DataProcessConfigEntity, DataPreparedEntity, CandidateFeatur
 
 class FilterStrategyBase(metaclass=abc.ABCMeta):
 
-    def __init__(self, data_process_config: DataProcessConfigEntity, *args, **kwargs):
-        self._data_process_config = data_process_config
+    def __init__(self, data_process_config: DataProcessConfigEntity = None, *args, **kwargs):
+        if data_process_config is not None:
+            self._data_process_config = data_process_config
+        else:
+            self._data_process_config = DataProcessConfigEntity(*args, **kwargs)
 
     @property
     def data_process_config(self):

+ 10 - 7
feature/filter_strategy_factory.py

@@ -4,20 +4,23 @@
 @time: 2024/11/25
 @desc: 特征筛选策略工厂
 """
-from entitys import DataProcessConfigEntity
-from enums import FilterStrategyEnum
+from typing import Type
+
+from commom import GeneralException
+from enums import FilterStrategyEnum, ResultCodesEnum
 from .filter_strategy_base import FilterStrategyBase
 from .strategy_iv import StrategyIv
 
 
 class FilterStrategyFactory():
 
-    def __init__(self, data_process_config: DataProcessConfigEntity, *args, **kwargs):
-        self._data_process_config = data_process_config
+    def __init__(self, ):
         self.strategy_map = {
-            FilterStrategyEnum.IV.value: StrategyIv(data_process_config, *args, **kwargs)
+            FilterStrategyEnum.IV.value: StrategyIv
         }
 
-    def get_strategy(self, ) -> FilterStrategyBase:
-        strategy = self.strategy_map.get(self._data_process_config.feature_search_strategy)
+    def get_strategy(self, strategy: str) -> Type[FilterStrategyBase]:
+        if strategy not in self.strategy_map.keys():
+            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"特征搜索策略【{strategy}】不存在")
+        strategy = self.strategy_map.get(strategy)
         return strategy

+ 2 - 18
model/__init__.py

@@ -4,23 +4,7 @@
 @time: 2023/12/28
 @desc: 模型相关
 """
-from commom import GeneralException
-from enums import ModelEnum, ResultCodesEnum
 from .model_base import ModelBase
-from .model_lr import ModelLr
+from .model_factory import ModelFactory
 
-__all__ = ['ModelBase', 'f_get_model']
-
-model_map = {
-    ModelEnum.LR.value: ModelLr
-}
-
-
-def f_get_model(model_type: str):
-    if model_type not in model_map.keys():
-        raise GeneralException(ResultCodesEnum.ILLEGAL_PARAMS, message=f"模型【{model_type}】没有实现")
-    return model_map[model_type]
-
-
-if __name__ == "__main__":
-    pass
+__all__ = ['ModelBase', 'ModelFactory']

+ 9 - 2
model/model_base.py

@@ -14,8 +14,15 @@ from entitys import TrainConfigEntity, DataPreparedEntity, MetricFucEntity
 
 class ModelBase(metaclass=abc.ABCMeta):
 
-    def __init__(self, train_config: TrainConfigEntity):
-        self._train_config = train_config
+    def __init__(self, train_config: TrainConfigEntity = None, *args, **kwargs):
+        if train_config is not None:
+            self._train_config = train_config
+        else:
+            self._train_config = TrainConfigEntity(*args, **kwargs)
+
+    @abc.abstractmethod
+    def get_template_path(self, ) -> str:
+        pass
 
     @abc.abstractmethod
     def train(self, data: DataPreparedEntity, *args, **kwargs) -> Dict[str, MetricFucEntity]:

+ 25 - 0
model/model_factory.py

@@ -0,0 +1,25 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/12/3
+@desc: 模型工厂
+"""
+from typing import Type
+
+from commom import GeneralException
+from enums import ModelEnum, ResultCodesEnum
+from model import ModelBase
+from .model_lr import ModelLr
+
+
+class ModelFactory():
+
+    def __init__(self, ):
+        self.model_map = {
+            ModelEnum.LR.value: ModelLr
+        }
+
+    def get_model(self, model_type: str) -> Type[ModelBase]:
+        if model_type not in self.model_map.keys():
+            raise GeneralException(ResultCodesEnum.ILLEGAL_PARAMS, message=f"模型【{model_type}】没有实现")
+        return self.model_map.get(model_type)

+ 8 - 2
model/model_lr.py

@@ -18,10 +18,16 @@ from .model_base import ModelBase
 
 
 class ModelLr(ModelBase):
-    def __init__(self, train_config: TrainConfigEntity):
-        super().__init__(train_config)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # 报告模板
+        self._template_path = "./template/模型开发报告模板_lr.docx"
         self.lr = LogisticRegression(penalty='l1', C=0.9, solver='saga', n_jobs=-1)
 
+    @property
+    def get_template_path(self):
+        return self._template_path
+
     def train(self, data: DataPreparedEntity, *args, **kwargs) -> Dict[str, MetricFucEntity]:
         bins = kwargs["bins"]
         data_split_original: DataSplitEntity = kwargs["data_split_original"]

+ 21 - 14
train_test.py

@@ -6,30 +6,37 @@
 """
 import time
 
-from entitys import DataSplitEntity, DataProcessConfigEntity, TrainConfigEntity
+from entitys import DataSplitEntity
 from feature import FilterStrategyFactory
+from model import ModelFactory
 from trainer import TrainPipeline
 
 if __name__ == "__main__":
     time_now = time.time()
     import scorecardpy as sc
 
+    # 加载数据
     dat = sc.germancredit()
     dat["creditability"] = dat["creditability"].apply(lambda x: 1 if x == "bad" else 0)
-    data = DataSplitEntity(dat[:709], None, dat[709:])
+    data = DataSplitEntity(train_data=dat[:709], val_data=None, test_data=dat[709:])
 
     # 特征处理
-    filter_strategy_factory = FilterStrategyFactory(
-        DataProcessConfigEntity.from_config('./config/data_process_config_template.json'))
-    strategy = filter_strategy_factory.get_strategy()
-    candidate_feature = strategy.filter(data)
-    data_prepared = strategy.feature_generate(data, candidate_feature)
-    # 训练
-    train_pipeline = TrainPipeline(TrainConfigEntity.from_config('./config/train_config_template.json'))
-    metric_value_dict_train = train_pipeline.train(data_prepared)
-    # 报告生成
-    metric_value_dict_feature = strategy.feature_report(data, candidate_feature)
-    metric_value_dict_train.update(metric_value_dict_feature)
-    train_pipeline.generate_report(metric_value_dict_train)
+    ## 获取特征筛选策略
+    filter_strategy_factory = FilterStrategyFactory()
+    filter_strategy_clazz = filter_strategy_factory.get_strategy("iv")
+    ## 可传入参数
+    filter_strategy = filter_strategy_clazz(y_column="creditability")
+    ## 也可从配置文件加载
+    # filter_strategy = filter_strategy_clazz(DataProcessConfigEntity.from_config('./config/data_process_config_template.json'))
+
+    # 选择模型
+    model_factory = ModelFactory()
+    model_clazz = model_factory.get_model("lr")
+    model = model_clazz()
+
+    # 训练并生成报告
+    train_pipeline = TrainPipeline(filter_strategy, model, data)
+    train_pipeline.train()
+    train_pipeline.generate_report()
 
     print(time.time() - time_now)

+ 20 - 12
trainer/train.py

@@ -4,26 +4,34 @@
 @time: 2024/11/1
 @desc: 模型训练管道
 """
-from typing import Dict
 
-from entitys import DataPreparedEntity, TrainConfigEntity, MetricFucEntity
+from entitys import DataSplitEntity
+from feature.filter_strategy_base import FilterStrategyBase
 from init import f_get_save_path
-from model import f_get_model
+from model import ModelBase
 from monitor.report_generate import Report
 
 
 class TrainPipeline():
-    def __init__(self, train_config: TrainConfigEntity):
-        self._train_config = train_config
-        model_clazz = f_get_model(self._train_config.model_type)
-        self.model = model_clazz(self._train_config)
+    def __init__(self, filter_strategy: FilterStrategyBase, model: ModelBase, data: DataSplitEntity):
+        self._filter_strategy = filter_strategy
+        self._model = model
+        self._data = data
 
-    def train(self, data: DataPreparedEntity) -> Dict[str, MetricFucEntity]:
-        metric_value_dict = self.model.train(data, *data.args, **data.kwargs)
-        return metric_value_dict
+    def train(self, ):
+        # 处理数据,获取候选特征
+        candidate_feature = self._filter_strategy.filter(self._data)
+        # 生成训练数据
+        data_prepared = self._filter_strategy.feature_generate(self._data, candidate_feature)
+        # 特征信息
+        metric_value_dict_feature = self._filter_strategy.feature_report(self._data, candidate_feature)
 
-    def generate_report(self, metric_value_dict: Dict[str, MetricFucEntity]):
-        Report.generate_report(metric_value_dict, self._train_config.template_path, save_path=f_get_save_path("模型报告.docx"))
+        metric_value_dict_train = self._model.train(data_prepared, *data_prepared.args, **data_prepared.kwargs)
+        self.metric_value_dict = metric_value_dict_feature.update(metric_value_dict_train)
+
+    def generate_report(self, ):
+        Report.generate_report(self.metric_value_dict, self._model.get_template_path(),
+                               save_path=f_get_save_path("模型报告.docx"))
 
 
 if __name__ == "__main__":