12345678910111213141516171819202122232425262728293031323334353637383940414243 |
- # -*- coding: utf-8 -*-
- """
- @author: yq
- @time: 2024/11/1
- @desc: 模型训练超参数配置类
- """
- import json
- import os
- from commom import GeneralException
- from enums import ResultCodesEnum
class TrainConfigEntity:
    """Hyper-parameter configuration for model training.

    Stores the learning rate and a slot for a save-path callable
    (``f_get_save_path``) that is filled in through
    ``set_save_path_func``.
    """

    def __init__(self, lr: float = None, *args, **kwargs):
        """Create a configuration.

        Args:
            lr: Learning rate; defaults to ``None``.
            *args: Accepted but not used by this class.
            **kwargs: Accepted but not used by this class, so a config
                file may carry keys this class does not consume.
        """
        # Learning rate.
        self._lr = lr
        # Save-path callable; stays None until set_save_path_func() is called.
        # NOTE(review): the original comment said this function is meant to
        # be provided via inheritance -- confirm the intended usage.
        self.f_get_save_path = None

    @property
    def lr(self):
        """Return the configured learning rate."""
        return self._lr

    def set_save_path_func(self, f):
        """Store ``f`` as the save-path callable."""
        self.f_get_save_path = f

    @classmethod
    def from_config(cls, config_path: str):
        """Build an entity from a JSON configuration file.

        The top-level JSON value is unpacked as keyword arguments into the
        constructor of the class the method is called on, so subclasses get
        an instance of their own type rather than of the base class.

        Args:
            config_path: Path to a UTF-8 encoded JSON file.

        Returns:
            A new instance of ``cls`` populated from the file.

        Raises:
            GeneralException: With ``ResultCodesEnum.NOT_FOUND`` when
                ``config_path`` does not exist.
        """
        # Guard clause: fail early with the project's NOT_FOUND error.
        if not os.path.exists(config_path):
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")
        with open(config_path, mode="r", encoding="utf-8") as f:
            config = json.load(f)
        # Use cls (not the hard-coded base class) so inheritance works.
        return cls(**config)
# Script entry point: currently a no-op when the module is run directly.
if __name__ == "__main__":
    pass
|