# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/1
@desc: 模型训练超参数配置类
"""
import json
import os
from typing import Callable, Optional

from commom import GeneralException
from enums import ResultCodesEnum


class TrainConfigEntity:
    """
    Hyper-parameter configuration for model training.

    Instances are typically built from a JSON file via ``from_config``; every
    top-level key of that file is forwarded to ``__init__`` as a keyword
    argument.
    """

    def __init__(self, lr: Optional[float] = None, *args, **kwargs):
        """
        Args:
            lr: Learning rate. ``None`` when not configured.
            *args, **kwargs: Accepted and ignored by this base class, so a
                config file may carry keys it does not use (for example,
                keys consumed by a subclass).
        """
        # Learning rate; exposed read-only through the ``lr`` property.
        self._lr = lr
        # Save-path callback. Stays None until set_save_path_func() is called;
        # the original note says this function is meant to be supplied by
        # inheritance. Its signature/return value are not defined here.
        # NOTE(review): presumably returns a filesystem path — confirm with callers.
        self.f_get_save_path: Optional[Callable] = None

    @property
    def lr(self) -> Optional[float]:
        """Learning rate passed at construction (``None`` if absent)."""
        return self._lr

    def set_save_path_func(self, f: Optional[Callable]) -> None:
        """Store ``f`` as the ``f_get_save_path`` callback."""
        self.f_get_save_path = f

    @classmethod
    def from_config(cls, config_path: str) -> "TrainConfigEntity":
        """
        Create an entity from a JSON config file.

        Args:
            config_path: Path to a UTF-8 encoded JSON file whose top level is
                an object; its keys become keyword arguments of ``cls``
                (a non-object top level raises TypeError on unpacking).

        Returns:
            An instance of ``cls``. Calling this on a subclass yields that
            subclass, not the base class.

        Raises:
            GeneralException: With ``ResultCodesEnum.NOT_FOUND`` when
                ``config_path`` is not an existing regular file.
            json.JSONDecodeError: When the file content is not valid JSON.
        """
        # isfile() rather than exists(): a directory path would pass exists()
        # and then fail inside open() with an unrelated OSError.
        if not os.path.isfile(config_path):
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")

        with open(config_path, mode="r", encoding="utf-8") as f:
            config = json.load(f)

        # cls instead of the hard-coded class name, so subclasses inherit a
        # working alternate constructor.
        return cls(**config)


if __name__ == "__main__":
    # Nothing to run standalone: this module only provides TrainConfigEntity.
    ...