train_config_entity.py 1001 B

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc: 模型训练超参数配置类
  6. """
  7. import json
  8. import os
  9. from commom import GeneralException
  10. from enums import ResultCodesEnum
  11. class TrainConfigEntity():
  12. def __init__(self, lr: float = None, *args, **kwargs):
  13. # 学习率
  14. self._lr = lr
  15. # 该函数需要去继承
  16. self.f_get_save_path = None
  17. @property
  18. def lr(self):
  19. return self._lr
  20. def set_save_path_func(self, f):
  21. self.f_get_save_path = f
  22. @staticmethod
  23. def from_config(config_path: str):
  24. """
  25. 从配置文件生成实体类
  26. """
  27. if os.path.exists(config_path):
  28. with open(config_path, mode="r", encoding="utf-8") as f:
  29. j = json.loads(f.read())
  30. else:
  31. raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")
  32. return TrainConfigEntity(**j)
  33. if __name__ == "__main__":
  34. pass