# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/1
@desc: 模型训练管道 (model training pipeline)
"""
- from typing import Dict
- from entitys import DataSplitEntity, MetricFucEntity
- from feature.filter_strategy_base import FilterStrategyBase
- from init import init
- from model import ModelBase
- from monitor.report_generate import Report
- init()
class TrainPipeline:
    """Wire a feature-filter strategy and a model into one training workflow.

    The pipeline asks the filter strategy for candidate features, builds the
    training data, trains the model, and can render the collected metrics
    into a report file.
    """

    def __init__(self, filter_strategy: FilterStrategyBase, model: ModelBase, data: DataSplitEntity):
        """Store the collaborators and share the save-path function with the model.

        Args:
            filter_strategy: Strategy used to filter features, generate the
                training data and produce the feature report.
            model: Model to train; also supplies the report template path.
            data: Data set handed to the filter strategy.
        """
        self._filter_strategy = filter_strategy
        self._model = model
        self._data = data
        # Filled in by train(). None means no training run has completed yet,
        # which lets generate_report() fail with a clear message instead of an
        # AttributeError on a missing attribute.
        self.metric_value_dict = None
        # NOTE(review): reaches into the model's protected _train_config so the
        # model resolves save paths with the same function as the filter
        # strategy's data-process config; confirm there is no public accessor.
        self._model._train_config.set_save_path_func(
            self._filter_strategy.data_process_config.f_get_save_path
        )

    def train(self) -> Dict[str, MetricFucEntity]:
        """Run feature selection and model training.

        Returns:
            The feature-report metrics merged with the training metrics; on
            duplicate keys the training metric wins. The same dict is stored
            on ``self.metric_value_dict`` for generate_report().
        """
        # Process the data and obtain the candidate features.
        candidate_feature = self._filter_strategy.filter(self._data)
        # Generate the training data.
        data_prepared = self._filter_strategy.feature_generate(self._data, candidate_feature)
        # Feature information.
        metric_value_dict_feature = self._filter_strategy.feature_report(self._data, candidate_feature)
        metric_value_dict_train = self._model.train(data_prepared, *data_prepared.args, **data_prepared.kwargs)
        self.metric_value_dict = {**metric_value_dict_feature, **metric_value_dict_train}
        return self.metric_value_dict

    def generate_report(self):
        """Render the metrics of the last train() call into the report file.

        The output path is resolved through the filter strategy's
        data-process config and printed once the report has been generated.

        Raises:
            RuntimeError: If train() has not completed yet, so there are no
                metrics to report.
        """
        if self.metric_value_dict is None:
            raise RuntimeError("train() must be called before generate_report()")
        save_path = self._filter_strategy.data_process_config.f_get_save_path("模型报告.docx")
        Report.generate_report(self.metric_value_dict, self._model.get_template_path(), save_path=save_path)
        print(f"模型报告文件储存路径:{save_path}")
# Script entry point: currently a no-op, so running this module directly only
# executes the module-level imports and the init() call.
if __name__ == "__main__":
    pass
|