# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/1
@desc: Model training pipeline (feature filtering -> training data generation
       -> model training -> report generation).
"""
from typing import Dict, Optional

from entitys import DataSplitEntity, MetricFucEntity
from feature.filter_strategy_base import FilterStrategyBase
from init import init
from model import ModelBase
from monitor.report_generate import Report

# Project-level initialisation, executed once when this module is imported.
init()


class TrainPipeline:
    """Chain a feature filter strategy and a model into one train/report flow.

    Typical use: construct, call ``train()``, then ``generate_report()``.
    """

    def __init__(self, filter_strategy: FilterStrategyBase, model: ModelBase, data: DataSplitEntity):
        """
        Args:
            filter_strategy: strategy that selects candidate features, generates
                the prepared training data and produces the feature report
                metrics. Its ``data_process_config`` is shared with ``model``.
            model: model to train. Its train config's save-path function and its
                data-process config are overwritten here so that it uses the
                same configuration as ``filter_strategy``.
            data: dataset split handed to the filter strategy.
        """
        self._filter_strategy = filter_strategy
        self._model = model
        self._data = data
        # Filled in by train(); generate_report() refuses to run while this is None.
        self.metric_value_dict: Optional[Dict[str, MetricFucEntity]] = None
        # Let the model save its artifacts via the filter strategy's save-path function
        # and share the same data-process config.
        self._model._train_config.set_save_path_func(self._filter_strategy.data_process_config.f_get_save_path)
        self._model._data_process_config = self._filter_strategy.data_process_config

    def train(self) -> Dict[str, MetricFucEntity]:
        """Run feature selection, training-data generation and model training.

        Returns:
            The feature-report metrics merged with the model's training metrics
            (on a key collision the training metric wins). The result is also
            stored on ``self.metric_value_dict`` for ``generate_report()``.
        """
        # Process the data and obtain the candidate features.
        candidate_feature, numeric_candidate_dict_all = self._filter_strategy.filter(self._data)
        # Generate the training data from the selected candidate features.
        data_prepared = self._filter_strategy.feature_generate(self._data, candidate_feature)
        # Feature information metrics for the report.
        metric_value_dict_feature = self._filter_strategy.feature_report(self._data, candidate_feature,
                                                                         numeric_candidate_dict_all)
        # The prepared data object also carries the extra positional/keyword
        # arguments the model's train() expects.
        metric_value_dict_train = self._model.train(data_prepared, *data_prepared.args, **data_prepared.kwargs)
        self.metric_value_dict = {**metric_value_dict_feature, **metric_value_dict_train}
        return self.metric_value_dict

    def generate_report(self):
        """Render the model report from the metrics of the last ``train()`` call.

        The report is written to the path returned by the filter strategy's
        ``f_get_save_path("模型报告.docx")``, using the model's template, and
        that path is printed.

        Raises:
            RuntimeError: if ``train()`` has not been run yet, so no metrics exist.
        """
        if self.metric_value_dict is None:
            raise RuntimeError("train() must be called before generate_report()")
        save_path = self._filter_strategy.data_process_config.f_get_save_path("模型报告.docx")
        Report.generate_report(self.metric_value_dict, self._model.get_template_path(), save_path=save_path)
        print(f"模型报告文件储存路径:{save_path}")


if __name__ == "__main__":
    pass