ol_config_entity.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc: Configuration class for OnlineLearning parameters
  6. """
  7. import json
  8. import os
  9. from typing import List
  10. from commom import GeneralException, f_get_datetime
  11. from config import BaseConfig
  12. from enums import ResultCodesEnum, FileEnum
  13. from init import warning_ignore
  14. class OnlineLearningConfigEntity():
  15. def __init__(self,
  16. path_resources: str,
  17. y_column: str,
  18. project_name: str = None,
  19. lr: float = 0.01,
  20. random_state: int = 2025,
  21. batch_size: int = 64,
  22. epochs: int = 50,
  23. columns_anns: dict = {},
  24. jupyter_print=False,
  25. save_pmml=True,
  26. stress_test=False,
  27. stress_sample_times=100,
  28. stress_bad_rate_list: List[float] = [],
  29. *args, **kwargs):
  30. self._path_resources = path_resources
  31. # 定义y变量
  32. self._y_column = y_column
  33. # 项目名称,和缓存路径有关
  34. self._project_name = project_name
  35. # 学习率
  36. self._lr = lr
  37. # 学习率
  38. self._random_state = random_state
  39. # 模型单次更新使用数据量
  40. self._batch_size = batch_size
  41. # 最大训练轮数
  42. self._epochs = epochs
  43. # 变量注释
  44. self._columns_anns = columns_anns
  45. # jupyter下输出内容
  46. self._jupyter_print = jupyter_print
  47. self._save_pmml = save_pmml
  48. # 是否开启下输出内容
  49. self._stress_test = stress_test
  50. # jupyter下输出内容
  51. self._stress_sample_times = stress_sample_times
  52. # jupyter下输出内容
  53. self._stress_bad_rate_list = stress_bad_rate_list
  54. if self._project_name is None or len(self._project_name) == 0:
  55. self._base_dir = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
  56. else:
  57. self._base_dir = os.path.join(BaseConfig.train_path, self._project_name)
  58. os.makedirs(self._base_dir, exist_ok=True)
  59. print(f"项目路径:【{self._base_dir}】")
  60. if self._jupyter_print:
  61. warning_ignore()
  62. @property
  63. def path_resources(self):
  64. return self._path_resources
  65. @property
  66. def y_column(self):
  67. return self._y_column
  68. @property
  69. def lr(self):
  70. return self._lr
  71. @property
  72. def random_state(self):
  73. return self._random_state
  74. @property
  75. def batch_size(self):
  76. return self._batch_size
  77. @property
  78. def epochs(self):
  79. return self._epochs
  80. @property
  81. def columns_anns(self):
  82. return self._columns_anns
  83. @property
  84. def jupyter_print(self):
  85. return self._jupyter_print
  86. @property
  87. def save_pmml(self):
  88. return self._save_pmml
  89. @property
  90. def stress_test(self):
  91. return self._stress_test
  92. @property
  93. def stress_sample_times(self):
  94. return self._stress_sample_times
  95. @property
  96. def stress_bad_rate_list(self):
  97. return self._stress_bad_rate_list
  98. @staticmethod
  99. def from_config(config_path: str):
  100. """
  101. 从配置文件生成实体类
  102. """
  103. if os.path.isdir(config_path):
  104. config_path = os.path.join(config_path, FileEnum.OLCFG.value)
  105. if os.path.exists(config_path):
  106. with open(config_path, mode="r", encoding="utf-8") as f:
  107. j = json.loads(f.read())
  108. else:
  109. raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")
  110. print(f"olcfg load from【{config_path}】success. ")
  111. return OnlineLearningConfigEntity(**j)
  112. def config_save(self):
  113. path = self.f_get_save_path(FileEnum.OLCFG.value)
  114. with open(path, mode="w", encoding="utf-8") as f:
  115. j = {k.lstrip("_"): v for k, v in self.__dict__.items()}
  116. j = json.dumps(j, ensure_ascii=False)
  117. f.write(j)
  118. print(f"olcfg save to【{path}】success. ")
  119. def f_get_save_path(self, file_name: str) -> str:
  120. path = os.path.join(self._base_dir, file_name)
  121. return path
  122. if __name__ == "__main__":
  123. pass