# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/1
@desc: Parameter configuration entity for OnlineLearning.
"""
import json
import os
from typing import List, Optional

from commom import GeneralException, f_get_datetime
from config import BaseConfig
from enums import ResultCodesEnum, FileEnum
from init import warning_ignore


class OnlineLearningConfigEntity:
    """Settings for an online-learning run, loadable from / savable to a JSON file.

    Constructing an instance has side effects: it creates the project
    directory under ``BaseConfig.train_path`` (``os.makedirs(..., exist_ok=True)``),
    prints its location, and calls ``warning_ignore()`` when ``jupyter_print`` is set.
    """

    def __init__(self, path_resources: str, y_column: str, project_name: Optional[str] = None, lr: float = 0.01,
                 batch_size: int = 64, epochs: int = 50, columns_anns: Optional[dict] = None,
                 jupyter_print=False, stress_test=False, stress_sample_times=100,
                 stress_bad_rate_list: Optional[List[float]] = None, *args, **kwargs):
        """
        Args:
            path_resources: resource path; stored as given.
            y_column: name of the target (y) column.
            project_name: project name; determines the output/cache directory
                ``BaseConfig.train_path/<project_name>``. When None or empty, a
                directory named from ``f_get_datetime()`` is used instead.
            lr: learning rate.
            batch_size: amount of data used for a single model update.
            epochs: maximum number of training epochs.
            columns_anns: column annotations. Defaults to a fresh empty dict per instance.
            jupyter_print: jupyter output mode; when truthy, ``warning_ignore()`` is called.
            stress_test: whether stress testing is enabled.
            stress_sample_times: stress-test sampling count (presumably the number of
                resampling rounds — confirm against the stress-test implementation).
            stress_bad_rate_list: bad rates used by the stress test. Defaults to a fresh
                empty list per instance.
            *args, **kwargs: accepted and ignored. This is what allows a file written by
                ``config_save`` (which also contains ``base_dir``) to be reloaded.
        """
        self._path_resources = path_resources
        # Target (y) variable.
        self._y_column = y_column
        # Project name; drives the cache/output path.
        self._project_name = project_name
        # Learning rate.
        self._lr = lr
        # Amount of data used for a single model update.
        self._batch_size = batch_size
        # Maximum number of training epochs.
        self._epochs = epochs
        # Column annotations. Create a new dict per instance instead of sharing a
        # mutable default across all instances.
        self._columns_anns = {} if columns_anns is None else columns_anns
        # Output mode for jupyter.
        self._jupyter_print = jupyter_print
        # Whether stress testing is enabled.
        self._stress_test = stress_test
        # Stress-test sampling count.
        self._stress_sample_times = stress_sample_times
        # Bad rates for the stress test. Same per-instance default as above.
        self._stress_bad_rate_list = [] if stress_bad_rate_list is None else stress_bad_rate_list

        # No project name -> fall back to a timestamp-named directory.
        if not self._project_name:
            self._base_dir = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
        else:
            self._base_dir = os.path.join(BaseConfig.train_path, self._project_name)
        os.makedirs(self._base_dir, exist_ok=True)
        print(f"项目路径:【{self._base_dir}】")

        if self._jupyter_print:
            warning_ignore()

    @property
    def path_resources(self):
        return self._path_resources

    @property
    def y_column(self):
        return self._y_column

    @property
    def project_name(self):
        return self._project_name

    @property
    def base_dir(self):
        """Project directory created in ``__init__``; save paths are resolved under it."""
        return self._base_dir

    @property
    def lr(self):
        return self._lr

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def epochs(self):
        return self._epochs

    @property
    def columns_anns(self):
        return self._columns_anns

    @property
    def jupyter_print(self):
        return self._jupyter_print

    @property
    def stress_test(self):
        return self._stress_test

    @property
    def stress_sample_times(self):
        return self._stress_sample_times

    @property
    def stress_bad_rate_list(self):
        return self._stress_bad_rate_list

    @staticmethod
    def from_config(config_path: str):
        """Build an entity from a JSON config file.

        Args:
            config_path: path to the config file, or a directory; for a directory the
                file name ``FileEnum.OLCFG.value`` is appended.

        Returns:
            A new ``OnlineLearningConfigEntity`` built from the file's keys.

        Raises:
            GeneralException: with ``ResultCodesEnum.NOT_FOUND`` if the file does not exist.
        """
        if os.path.isdir(config_path):
            config_path = os.path.join(config_path, FileEnum.OLCFG.value)
        if not os.path.exists(config_path):
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")
        with open(config_path, mode="r", encoding="utf-8") as f:
            j = json.load(f)
        print(f"olcfg load from【{config_path}】success. ")
        return OnlineLearningConfigEntity(**j)

    def config_save(self):
        """Write every instance attribute, with leading underscores stripped, as JSON
        to ``<base_dir>/FileEnum.OLCFG.value``."""
        path = self.f_get_save_path(FileEnum.OLCFG.value)
        with open(path, mode="w", encoding="utf-8") as f:
            j = {k.lstrip("_"): v for k, v in self.__dict__.items()}
            j = json.dumps(j, ensure_ascii=False)
            f.write(j)
        print(f"olcfg save to【{path}】success. ")

    def f_get_save_path(self, file_name: str) -> str:
        """Return ``file_name`` joined onto the project directory."""
        path = os.path.join(self._base_dir, file_name)
        return path


if __name__ == "__main__":
    pass