123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- # -*- coding: utf-8 -*-
- """
- @author: yq
- @time: 2024/11/1
- @desc: 模型训练超参数配置类
- """
- import json
- import os
- from commom import GeneralException
- from enums import ResultCodesEnum, ModelEnum
class TrainConfigEntity:
    """Hyper-parameter configuration used for model training.

    The values are fixed at construction time and exposed through
    read-only properties.
    """

    def __init__(self, model_type: str = None, lr: float = None):
        """Store the training settings.

        Args:
            model_type: Model type identifier; it is compared against
                ``ModelEnum.LR.value`` to choose the report template.
            lr: Learning rate.
        """
        # Model type
        self._model_type = model_type
        # Learning rate
        self._lr = lr
        # Report template. Always define the attribute so that reading
        # ``template_path`` never raises AttributeError for model types
        # that have no template configured.
        self._template_path = None
        if model_type == ModelEnum.LR.value:
            self._template_path = "./template/模型开发报告模板_lr.docx"

    @property
    def template_path(self):
        """Path of the report template, or ``None`` if the model type has none."""
        return self._template_path

    @property
    def model_type(self):
        """Model type identifier given at construction."""
        return self._model_type

    @property
    def lr(self):
        """Learning rate given at construction."""
        return self._lr

    @classmethod
    def from_config(cls, config_path: str):
        """Build an entity from a JSON configuration file.

        The decoded JSON object is passed to the constructor as keyword
        arguments, so its keys must match the ``__init__`` parameter names.

        Args:
            config_path: Path of the UTF-8 encoded JSON file.

        Returns:
            A new instance populated from the file.

        Raises:
            GeneralException: With ``ResultCodesEnum.NOT_FOUND`` when
                ``config_path`` does not exist.
        """
        if not os.path.exists(config_path):
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")

        with open(config_path, mode="r", encoding="utf-8") as f:
            params = json.load(f)

        # Use ``cls`` so that subclasses get instances of their own type.
        return cls(**params)
# No script behaviour is defined: running this module directly does nothing.
if __name__ == "__main__":
    pass
|