ol_config_entity.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc: OnlineLearning参数配置类
  6. """
  7. import json
  8. import os
  9. from typing import List
  10. from commom import GeneralException, f_get_datetime
  11. from config import BaseConfig
  12. from enums import ResultCodesEnum, FileEnum
  13. from init import warning_ignore
  14. class OnlineLearningConfigEntity():
  15. def __init__(self,
  16. path_resources: str,
  17. y_column: str,
  18. project_name: str = None,
  19. lr: float = 0.01,
  20. random_state: int = 2025,
  21. batch_size: int = 64,
  22. epochs: int = 50,
  23. columns_anns: dict = {},
  24. jupyter_print=False,
  25. stress_test=False,
  26. stress_sample_times=100,
  27. stress_bad_rate_list: List[float] = [],
  28. *args, **kwargs):
  29. self._path_resources = path_resources
  30. # 定义y变量
  31. self._y_column = y_column
  32. # 项目名称,和缓存路径有关
  33. self._project_name = project_name
  34. # 学习率
  35. self._lr = lr
  36. # 学习率
  37. self._random_state = random_state
  38. # 模型单次更新使用数据量
  39. self._batch_size = batch_size
  40. # 最大训练轮数
  41. self._epochs = epochs
  42. # 变量注释
  43. self._columns_anns = columns_anns
  44. # jupyter下输出内容
  45. self._jupyter_print = jupyter_print
  46. # 是否开启下输出内容
  47. self._stress_test = stress_test
  48. # jupyter下输出内容
  49. self._stress_sample_times = stress_sample_times
  50. # jupyter下输出内容
  51. self._stress_bad_rate_list = stress_bad_rate_list
  52. if self._project_name is None or len(self._project_name) == 0:
  53. self._base_dir = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
  54. else:
  55. self._base_dir = os.path.join(BaseConfig.train_path, self._project_name)
  56. os.makedirs(self._base_dir, exist_ok=True)
  57. print(f"项目路径:【{self._base_dir}】")
  58. if self._jupyter_print:
  59. warning_ignore()
  60. @property
  61. def path_resources(self):
  62. return self._path_resources
  63. @property
  64. def y_column(self):
  65. return self._y_column
  66. @property
  67. def lr(self):
  68. return self._lr
  69. @property
  70. def random_state(self):
  71. return self._random_state
  72. @property
  73. def batch_size(self):
  74. return self._batch_size
  75. @property
  76. def epochs(self):
  77. return self._epochs
  78. @property
  79. def columns_anns(self):
  80. return self._columns_anns
  81. @property
  82. def jupyter_print(self):
  83. return self._jupyter_print
  84. @property
  85. def stress_test(self):
  86. return self._stress_test
  87. @property
  88. def stress_sample_times(self):
  89. return self._stress_sample_times
  90. @property
  91. def stress_bad_rate_list(self):
  92. return self._stress_bad_rate_list
  93. @staticmethod
  94. def from_config(config_path: str):
  95. """
  96. 从配置文件生成实体类
  97. """
  98. if os.path.isdir(config_path):
  99. config_path = os.path.join(config_path, FileEnum.OLCFG.value)
  100. if os.path.exists(config_path):
  101. with open(config_path, mode="r", encoding="utf-8") as f:
  102. j = json.loads(f.read())
  103. else:
  104. raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")
  105. print(f"olcfg load from【{config_path}】success. ")
  106. return OnlineLearningConfigEntity(**j)
  107. def config_save(self):
  108. path = self.f_get_save_path(FileEnum.OLCFG.value)
  109. with open(path, mode="w", encoding="utf-8") as f:
  110. j = {k.lstrip("_"): v for k, v in self.__dict__.items()}
  111. j = json.dumps(j, ensure_ascii=False)
  112. f.write(j)
  113. print(f"olcfg save to【{path}】success. ")
  114. def f_get_save_path(self, file_name: str) -> str:
  115. path = os.path.join(self._base_dir, file_name)
  116. return path
  117. if __name__ == "__main__":
  118. pass