train_config_entity.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc: 模型训练超参数配置类
  6. """
  7. import json
  8. import os
  9. from commom import GeneralException
  10. from enums import ResultCodesEnum, ModelEnum
  11. class TrainConfigEntity():
  12. def __init__(self, model_type=str, lr: float = None):
  13. # 模型类型
  14. self._model_type = model_type
  15. # 学习率
  16. self._lr = lr
  17. # 报告模板
  18. if model_type == ModelEnum.LR.value:
  19. self._template_path = "./template/模型开发报告模板_lr.docx"
  20. @property
  21. def template_path(self):
  22. return self._template_path
  23. @property
  24. def model_type(self):
  25. return self._model_type
  26. @property
  27. def lr(self):
  28. return self._lr
  29. @staticmethod
  30. def from_config(config_path: str):
  31. """
  32. 从配置文件生成实体类
  33. """
  34. if os.path.exists(config_path):
  35. with open(config_path, mode="r", encoding="utf-8") as f:
  36. j = json.loads(f.read())
  37. else:
  38. raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")
  39. return TrainConfigEntity(**j)
# Script entry point: running this module directly is a no-op.
if __name__ == "__main__":
    pass