train.py 906 B

123456789101112131415161718192021222324252627282930
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc: 模型训练管道
  6. """
  7. from typing import Dict
  8. from entitys import DataPreparedEntity, TrainConfigEntity, MetricFucEntity
  9. from init import f_get_save_path
  10. from model import f_get_model
  11. from monitor.report_generate import Report
  12. class TrainPipeline():
  13. def __init__(self, train_config: TrainConfigEntity):
  14. self._train_config = train_config
  15. model_clazz = f_get_model(self._train_config.model_type)
  16. self.model = model_clazz(self._train_config)
  17. def train(self, data: DataPreparedEntity):
  18. metric_train = self.model.train(data)
  19. print(metric_train)
  20. def generate_report(self, metric_value_dict: Dict[str, MetricFucEntity]):
  21. Report.generate_report(metric_value_dict, self._train_config.template_path, save_path=f_get_save_path("模型报告.docx"))
  22. if __name__ == "__main__":
  23. pass