# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/1
@desc: Model training pipeline
"""
from typing import Dict

from entitys import DataPreparedEntity, TrainConfigEntity, MetricFucEntity
from init import f_get_save_path
from model import f_get_model
from monitor.report_generate import Report


class TrainPipeline:
    """Wire a training configuration to a model and to report generation.

    The concrete model class is looked up with ``f_get_model`` using the
    configuration's ``model_type`` and is instantiated with the same
    configuration object.
    """

    def __init__(self, train_config: TrainConfigEntity):
        """Store *train_config* and build the model it selects."""
        self._train_config = train_config
        # f_get_model returns a class (not an instance); it is called with
        # the full configuration to create the model.
        model_cls = f_get_model(self._train_config.model_type)
        self.model = model_cls(self._train_config)

    def train(self, data: DataPreparedEntity) -> Dict[str, MetricFucEntity]:
        """Train the wrapped model on *data* and return what it returns.

        This is a pure delegation to ``self.model.train``; no
        post-processing is applied to the result.
        """
        return self.model.train(data)

    def generate_report(self, metric_value_dict: Dict[str, MetricFucEntity]):
        """Render a report for *metric_value_dict*.

        Delegates to ``Report.generate_report`` with the template path taken
        from the training configuration and a save path obtained from
        ``f_get_save_path`` for a fixed ``.docx`` file name.
        """
        template_path = self._train_config.template_path
        report_path = f_get_save_path("模型报告.docx")
        Report.generate_report(
            metric_value_dict,
            template_path,
            save_path=report_path,
        )


if __name__ == "__main__":
    pass