train.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc: 模型训练管道
  6. """
  7. from entitys import DataSplitEntity
  8. from feature.filter_strategy_base import FilterStrategyBase
  9. from init import f_get_save_path
  10. from model import ModelBase
  11. from monitor.report_generate import Report
  12. class TrainPipeline():
  13. def __init__(self, filter_strategy: FilterStrategyBase, model: ModelBase, data: DataSplitEntity):
  14. self._filter_strategy = filter_strategy
  15. self._model = model
  16. self._data = data
  17. def train(self, ):
  18. # 处理数据,获取候选特征
  19. candidate_feature = self._filter_strategy.filter(self._data)
  20. # 生成训练数据
  21. data_prepared = self._filter_strategy.feature_generate(self._data, candidate_feature)
  22. # 特征信息
  23. metric_value_dict_feature = self._filter_strategy.feature_report(self._data, candidate_feature)
  24. metric_value_dict_train = self._model.train(data_prepared, *data_prepared.args, **data_prepared.kwargs)
  25. self.metric_value_dict = {**metric_value_dict_feature, **metric_value_dict_train}
  26. def generate_report(self, ):
  27. Report.generate_report(self.metric_value_dict, self._model.get_template_path(),
  28. save_path=f_get_save_path("模型报告.docx"))
  29. if __name__ == "__main__":
  30. pass