train.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc: 模型训练管道
  6. """
  7. from typing import Dict
  8. from entitys import DataSplitEntity, MetricFucEntity
  9. from feature.filter_strategy_base import FilterStrategyBase
  10. from init import init
  11. from model import ModelBase
  12. from monitor.report_generate import Report
# Project-level initialisation, executed as a side effect when this module is imported
# (it runs before TrainPipeline is defined).
# NOTE(review): what init() configures is not visible from this file — confirm in init.py.
init()
  14. class TrainPipeline():
  15. def __init__(self, filter_strategy: FilterStrategyBase, model: ModelBase, data: DataSplitEntity):
  16. self._filter_strategy = filter_strategy
  17. self._model = model
  18. self._data = data
  19. self._model._train_config.set_save_path_func(self._filter_strategy.data_process_config.f_get_save_path)
  20. self._model._data_process_config = self._filter_strategy.data_process_config
  21. def train(self, ) -> Dict[str, MetricFucEntity]:
  22. # 处理数据,获取候选特征
  23. candidate_feature, numeric_candidate_dict_all = self._filter_strategy.filter(self._data)
  24. # 生成训练数据
  25. data_prepared = self._filter_strategy.feature_generate(self._data, candidate_feature)
  26. # 特征信息
  27. metric_value_dict_feature = self._filter_strategy.feature_report(self._data, candidate_feature, numeric_candidate_dict_all)
  28. metric_value_dict_train = self._model.train(data_prepared, *data_prepared.args, **data_prepared.kwargs)
  29. self.metric_value_dict = {**metric_value_dict_feature, **metric_value_dict_train}
  30. return self.metric_value_dict
  31. def generate_report(self, ):
  32. save_path = self._filter_strategy.data_process_config.f_get_save_path("模型报告.docx")
  33. Report.generate_report(self.metric_value_dict, self._model.get_template_path(), save_path=save_path)
  34. print(f"模型报告文件储存路径:{save_path}")
  35. if __name__ == "__main__":
  36. pass