123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- # -*- coding: utf-8 -*-
- """
- @author: yq
- @time: 2024/11/1
- @desc: Configuration class for OnlineLearning parameters
- """
- import json
- import os
- from typing import List
- from commom import GeneralException, f_get_datetime
- from config import BaseConfig
- from enums import ResultCodesEnum, FileEnum
- from init import warning_ignore
class OnlineLearningConfigEntity():
    """Configuration holder for an online-learning run.

    Stores the options passed to the constructor, exposes them through
    read-only properties, and resolves (and creates) the project directory
    under ``BaseConfig.train_path`` in which files such as the saved
    configuration are placed.
    """

    def __init__(self,
                 path_resources: str,
                 y_column: str,
                 project_name: str = None,
                 lr: float = 0.01,
                 batch_size: int = 64,
                 epochs: int = 50,
                 columns_anns: dict = None,
                 jupyter_print=False,
                 stress_test=False,
                 stress_sample_times=100,
                 stress_bad_rate_list: List[float] = None,
                 *args, **kwargs):
        """Store the options and create the project directory.

        Args:
            path_resources: Resource path (stored as-is; not inspected here).
            y_column: Name of the target (y) column.
            project_name: Project name; determines the project directory.
                When ``None`` or empty, the value of ``f_get_datetime()``
                is used as the directory name instead.
            lr: Learning rate.
            batch_size: Amount of data used for a single model update.
            epochs: Maximum number of training epochs.
            columns_anns: Column annotations. Defaults to a fresh empty dict.
            jupyter_print: Whether to produce output for Jupyter; when true,
                ``warning_ignore()`` is called.
            stress_test: Presumably toggles stress testing -- confirm
                against the consumer of this config.
            stress_sample_times: Presumably the number of sampling rounds
                for the stress test -- confirm against the consumer.
            stress_bad_rate_list: Presumably the bad rates to simulate in
                the stress test -- confirm against the consumer.
                Defaults to a fresh empty list.
            *args, **kwargs: Accepted and ignored. This lets a dict written
                by ``config_save`` (which also contains ``base_dir``) be
                passed straight back through ``from_config``.
        """
        self._path_resources = path_resources
        # Target (y) variable.
        self._y_column = y_column
        # Project name; determines the project/cache directory.
        self._project_name = project_name
        # Learning rate.
        self._lr = lr
        # Amount of data used for a single model update.
        self._batch_size = batch_size
        # Maximum number of training epochs.
        self._epochs = epochs
        # Column annotations. A new dict is created per instance so that
        # instances never share one mutable default object.
        self._columns_anns = {} if columns_anns is None else columns_anns
        # Whether to produce output for Jupyter.
        self._jupyter_print = jupyter_print
        # NOTE(review): the original comments on the three stress_* fields
        # were copy-pasted from jupyter_print; meanings below are inferred
        # from the names and should be confirmed.
        self._stress_test = stress_test
        self._stress_sample_times = stress_sample_times
        # New list per instance for the same reason as columns_anns.
        self._stress_bad_rate_list = (
            [] if stress_bad_rate_list is None else stress_bad_rate_list
        )

        # Fall back to a datetime-named directory when no project name is set.
        if not self._project_name:
            dir_name = f"{f_get_datetime()}"
        else:
            dir_name = self._project_name
        self._base_dir = os.path.join(BaseConfig.train_path, dir_name)
        os.makedirs(self._base_dir, exist_ok=True)
        print(f"项目路径:【{self._base_dir}】")

        if self._jupyter_print:
            warning_ignore()

    @property
    def path_resources(self):
        """Resource path given at construction."""
        return self._path_resources

    @property
    def y_column(self):
        """Name of the target (y) column."""
        return self._y_column

    @property
    def project_name(self):
        """Project name given at construction (may be ``None``)."""
        return self._project_name

    @property
    def lr(self):
        """Learning rate."""
        return self._lr

    @property
    def batch_size(self):
        """Amount of data used for a single model update."""
        return self._batch_size

    @property
    def epochs(self):
        """Maximum number of training epochs."""
        return self._epochs

    @property
    def columns_anns(self):
        """Column annotations dict."""
        return self._columns_anns

    @property
    def jupyter_print(self):
        """Whether Jupyter output is enabled."""
        return self._jupyter_print

    @property
    def stress_test(self):
        """The ``stress_test`` flag given at construction."""
        return self._stress_test

    @property
    def stress_sample_times(self):
        """The ``stress_sample_times`` value given at construction."""
        return self._stress_sample_times

    @property
    def stress_bad_rate_list(self):
        """The ``stress_bad_rate_list`` value given at construction."""
        return self._stress_bad_rate_list

    @classmethod
    def from_config(cls, config_path: str):
        """Build an entity from a JSON configuration file.

        Args:
            config_path: Path to the config file, or to a directory that
                contains a file named ``FileEnum.OLCFG.value``.

        Returns:
            A new instance constructed from the JSON object's keys.

        Raises:
            GeneralException: With ``ResultCodesEnum.NOT_FOUND`` when the
                resolved path does not exist.
        """
        if os.path.isdir(config_path):
            config_path = os.path.join(config_path, FileEnum.OLCFG.value)

        if not os.path.exists(config_path):
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")

        with open(config_path, mode="r", encoding="utf-8") as f:
            j = json.load(f)
        print(f"olcfg load from【{config_path}】success. ")
        return cls(**j)

    def config_save(self):
        """Write the instance attributes as JSON into the project directory.

        Every entry of ``self.__dict__`` is saved with its leading
        underscores stripped from the key, so the file also contains
        ``base_dir``.
        """
        path = self.f_get_save_path(FileEnum.OLCFG.value)
        # Serialize before opening so that a serialization error cannot
        # leave a truncated config file behind.
        payload = {k.lstrip("_"): v for k, v in self.__dict__.items()}
        text = json.dumps(payload, ensure_ascii=False)
        with open(path, mode="w", encoding="utf-8") as f:
            f.write(text)
        print(f"olcfg save to【{path}】success. ")

    def f_get_save_path(self, file_name: str) -> str:
        """Return ``file_name`` joined onto the project directory."""
        return os.path.join(self._base_dir, file_name)
# Script entry point: this module does nothing when run directly.
if __name__ == "__main__":
    pass
|