123456789101112131415161718192021222324252627282930 |
- # -*- coding: utf-8 -*-
- """
- @author: yq
- @time: 2024/11/1
- @desc: 模型训练管道
- """
- from typing import Dict
- from entitys import DataPreparedEntity, TrainConfigEntity, MetricFucEntity
- from init import f_get_save_path
- from model import f_get_model
- from monitor.report_generate import Report
class TrainPipeline:
    """Model training pipeline.

    Builds the model class selected by the training config, trains it on
    prepared data, and renders a report from the resulting metrics.
    """

    def __init__(self, train_config: TrainConfigEntity):
        """Keep the config and instantiate the configured model.

        Args:
            train_config: Training configuration. Its ``model_type`` is passed
                to ``f_get_model`` to pick the model class, and the model class
                is constructed with this same config object.
        """
        self._train_config = train_config
        self.model = f_get_model(train_config.model_type)(train_config)

    def train(self, data: DataPreparedEntity) -> Dict[str, MetricFucEntity]:
        """Train the model on ``data``.

        Args:
            data: Prepared dataset handed straight to ``self.model.train``.

        Returns:
            The metric mapping produced by ``self.model.train``, unchanged.
        """
        return self.model.train(data)

    def generate_report(self, metric_value_dict: Dict[str, MetricFucEntity]):
        """Render a report document from the given metrics.

        Uses the template at ``template_path`` from the training config. The
        output location comes from ``f_get_save_path`` applied to the fixed
        file name "模型报告.docx" ("model report").

        Args:
            metric_value_dict: Metric mapping, typically the return value of
                ``train``.
        """
        template = self._train_config.template_path
        target = f_get_save_path("模型报告.docx")
        Report.generate_report(metric_value_dict, template, save_path=target)
if __name__ == "__main__":
    # This module defines no command-line entry point; running it directly does nothing.
    ...
|