|
@@ -4,26 +4,34 @@
|
|
|
@time: 2024/11/1
|
|
|
@desc: 模型训练管道
|
|
|
"""
|
|
|
-from typing import Dict
|
|
|
|
|
|
-from entitys import DataPreparedEntity, TrainConfigEntity, MetricFucEntity
|
|
|
+from entitys import DataSplitEntity
|
|
|
+from feature.filter_strategy_base import FilterStrategyBase
|
|
|
from init import f_get_save_path
|
|
|
-from model import f_get_model
|
|
|
+from model import ModelBase
|
|
|
from monitor.report_generate import Report
|
|
|
|
|
|
|
|
|
class TrainPipeline():
|
|
|
- def __init__(self, train_config: TrainConfigEntity):
|
|
|
- self._train_config = train_config
|
|
|
- model_clazz = f_get_model(self._train_config.model_type)
|
|
|
- self.model = model_clazz(self._train_config)
|
|
|
+ def __init__(self, filter_strategy: FilterStrategyBase, model: ModelBase, data: DataSplitEntity):
|
|
|
+ self._filter_strategy = filter_strategy
|
|
|
+ self._model = model
|
|
|
+ self._data = data
|
|
|
|
|
|
- def train(self, data: DataPreparedEntity) -> Dict[str, MetricFucEntity]:
|
|
|
- metric_value_dict = self.model.train(data, *data.args, **data.kwargs)
|
|
|
- return metric_value_dict
|
|
|
+ def train(self, ):
|
|
|
+ # 处理数据,获取候选特征
|
|
|
+ candidate_feature = self._filter_strategy.filter(self._data)
|
|
|
+ # 生成训练数据
|
|
|
+ data_prepared = self._filter_strategy.feature_generate(self._data, candidate_feature)
|
|
|
+ # 特征信息
|
|
|
+ metric_value_dict_feature = self._filter_strategy.feature_report(self._data, candidate_feature)
|
|
|
|
|
|
- def generate_report(self, metric_value_dict: Dict[str, MetricFucEntity]):
|
|
|
- Report.generate_report(metric_value_dict, self._train_config.template_path, save_path=f_get_save_path("模型报告.docx"))
|
|
|
+ metric_value_dict_train = self._model.train(data_prepared, *data_prepared.args, **data_prepared.kwargs)
|
|
|
+ self.metric_value_dict = metric_value_dict_feature.update(metric_value_dict_train)
|
|
|
+
|
|
|
+ def generate_report(self, ):
|
|
|
+ Report.generate_report(self.metric_value_dict, self._model.get_template_path(),
|
|
|
+ save_path=f_get_save_path("模型报告.docx"))
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|