# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/1
@desc: Configuration entity for OnlineLearning.
"""
import json
import os
from typing import List, Optional

from commom import GeneralException, f_get_datetime
from config import BaseConfig
from enums import ResultCodesEnum, FileEnum
from init import warning_ignore


class OnlineLearningConfigEntity:
    """Configuration for an OnlineLearning run.

    Constructing an instance resolves the project output directory under
    ``BaseConfig.train_path``, creates it on disk, and prints its path.
    Instances can be persisted with ``config_save`` and restored with
    ``from_config``.
    """

    def __init__(self, path_resources: str, y_column: str, project_name: Optional[str] = None,
                 lr: float = 0.01, random_state: int = 2025, batch_size: int = 64, epochs: int = 50,
                 columns_anns: Optional[dict] = None, jupyter_print=False, save_pmml=True,
                 stress_test=False, stress_sample_times=100,
                 stress_bad_rate_list: Optional[List[float]] = None,
                 params_xgb: Optional[dict] = None, *args, **kwargs):
        """
        Args:
            path_resources: resource path; stored as-is and exposed via ``path_resources``.
            y_column: name of the target (y) column.
            project_name: project name; it determines the output directory. When None or
                empty, a directory named after ``f_get_datetime()`` is used instead.
            lr: learning rate.
            random_state: random seed.
            batch_size: number of samples used for a single model update.
            epochs: maximum number of training epochs.
            columns_anns: column annotations. Defaults to a new empty dict per instance.
            jupyter_print: output mode for Jupyter; when True, ``warning_ignore()`` is called.
            save_pmml: stored flag; presumably controls PMML export (verify with callers).
            stress_test: stored flag; presumably enables the stress test (verify with callers).
            stress_sample_times: stored stress-test sampling count.
            stress_bad_rate_list: stored stress-test bad-rate list. Defaults to a new empty list.
            params_xgb: overrides merged on top of the defaults returned by the
                ``params_xgb`` property. Defaults to a new empty dict.
            *args, **kwargs: accepted and ignored. ``from_config`` passes every key written by
                ``config_save``, including ``base_dir``, which is not a constructor parameter.
        """
        # NOTE: assignment order is preserved on purpose. config_save dumps self.__dict__,
        # so this order defines the key order of the saved config file.
        self._path_resources = path_resources
        # Target (y) column.
        self._y_column = y_column
        # Project name; determines the cache/output directory.
        self._project_name = project_name
        # Learning rate.
        self._lr = lr
        # Random seed.
        self._random_state = random_state
        # Number of samples used per model update.
        self._batch_size = batch_size
        # Maximum number of training epochs.
        self._epochs = epochs
        # Column annotations. A fresh dict per instance: a shared `{}` default would
        # leak mutations made through the property into every later instance.
        self._columns_anns = {} if columns_anns is None else columns_anns
        # Output mode for Jupyter.
        self._jupyter_print = jupyter_print
        self._save_pmml = save_pmml
        self._stress_test = stress_test
        self._stress_sample_times = stress_sample_times
        self._stress_bad_rate_list = [] if stress_bad_rate_list is None else stress_bad_rate_list
        self._params_xgb = {} if params_xgb is None else params_xgb

        if not self._project_name:
            # No project name: fall back to a timestamp-named directory.
            self._base_dir = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
        else:
            self._base_dir = os.path.join(BaseConfig.train_path, self._project_name)
        os.makedirs(self._base_dir, exist_ok=True)
        print(f"项目路径:【{self._base_dir}】")
        if self._jupyter_print:
            warning_ignore()

    @property
    def path_resources(self):
        return self._path_resources

    @property
    def y_column(self):
        return self._y_column

    @property
    def project_name(self):
        return self._project_name

    @property
    def base_dir(self):
        """Directory where files from ``f_get_save_path`` are placed."""
        return self._base_dir

    @property
    def lr(self):
        return self._lr

    @property
    def random_state(self):
        return self._random_state

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def epochs(self):
        return self._epochs

    @property
    def columns_anns(self):
        return self._columns_anns

    @property
    def jupyter_print(self):
        return self._jupyter_print

    @property
    def save_pmml(self):
        return self._save_pmml

    @property
    def stress_test(self):
        return self._stress_test

    @property
    def stress_sample_times(self):
        return self._stress_sample_times

    @property
    def stress_bad_rate_list(self):
        return self._stress_bad_rate_list

    @property
    def params_xgb(self):
        """Return XGBoost parameters: built-in defaults overridden by the user's ``params_xgb``.

        A new dict is built on every access, so mutating the result does not
        affect the stored configuration.
        """
        params = {
            'objective': 'binary:logistic',
            'eval_metric': 'auc',
            'learning_rate': 0.1,
            'max_depth': 3,
            'subsample': None,
            'colsample_bytree': None,
            'alpha': 0,
            'lambda': 1,
            'num_boost_round': 100,
            'early_stopping_rounds': 20,
            'verbose_eval': 10,
            'random_state': 2025,
            'save_pmml': True,
            'trees_print': False,
            # Online-learning mode; the original comment lists tree_add / tree_refresh.
            'oltype': "refresh",
            'add_columns': []
        }
        params.update(self._params_xgb)
        return params

    @classmethod
    def from_config(cls, config_path: str):
        """Build an entity from a JSON config file.

        Args:
            config_path: path to the config file, or a directory. For a directory,
                the file named ``FileEnum.OL_CFG.value`` inside it is used.

        Raises:
            GeneralException: with ``ResultCodesEnum.NOT_FOUND`` if the file does not exist.
        """
        if os.path.isdir(config_path):
            config_path = os.path.join(config_path, FileEnum.OL_CFG.value)
        if not os.path.exists(config_path):
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")
        with open(config_path, mode="r", encoding="utf-8") as f:
            j = json.load(f)
        print(f"olcfg load from【{config_path}】success. ")
        return cls(**j)

    def config_save(self):
        """Write every instance attribute (leading underscores stripped) as JSON into base_dir."""
        path = self.f_get_save_path(FileEnum.OL_CFG.value)
        j = {k.lstrip("_"): v for k, v in self.__dict__.items()}
        # Serialize BEFORE opening: mode "w" truncates the file, so a serialization
        # error after opening would leave an empty/wiped config file behind.
        content = json.dumps(j, ensure_ascii=False)
        with open(path, mode="w", encoding="utf-8") as f:
            f.write(content)
        print(f"olcfg save to【{path}】success. ")

    def f_get_save_path(self, file_name: str) -> str:
        """Return ``file_name`` joined onto the project base directory."""
        path = os.path.join(self._base_dir, file_name)
        return path


if __name__ == "__main__":
    pass