ol_config_entity.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc: OnlineLearning参数配置类
  6. """
  7. import json
  8. import os
  9. from typing import List
  10. from commom import GeneralException, f_get_datetime
  11. from config import BaseConfig
  12. from enums import ResultCodesEnum, FileEnum
  13. from init import warning_ignore
  14. class OnlineLearningConfigEntity():
  15. def __init__(self,
  16. path_resources: str,
  17. y_column: str,
  18. project_name: str = None,
  19. lr: float = 0.01,
  20. batch_size: int = 64,
  21. epochs: int = 50,
  22. columns_anns: dict = {},
  23. jupyter_print=False,
  24. stress_test=False,
  25. stress_sample_times=100,
  26. stress_bad_rate_list: List[float] = [],
  27. *args, **kwargs):
  28. self._path_resources = path_resources
  29. # 定义y变量
  30. self._y_column = y_column
  31. # 项目名称,和缓存路径有关
  32. self._project_name = project_name
  33. # 学习率
  34. self._lr = lr
  35. # 模型单次更新使用数据量
  36. self._batch_size = batch_size
  37. # 最大训练轮数
  38. self._epochs = epochs
  39. # 变量注释
  40. self._columns_anns = columns_anns
  41. # jupyter下输出内容
  42. self._jupyter_print = jupyter_print
  43. # 是否开启下输出内容
  44. self._stress_test = stress_test
  45. # jupyter下输出内容
  46. self._stress_sample_times = stress_sample_times
  47. # jupyter下输出内容
  48. self._stress_bad_rate_list = stress_bad_rate_list
  49. if self._project_name is None or len(self._project_name) == 0:
  50. self._base_dir = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
  51. else:
  52. self._base_dir = os.path.join(BaseConfig.train_path, self._project_name)
  53. os.makedirs(self._base_dir, exist_ok=True)
  54. print(f"项目路径:【{self._base_dir}】")
  55. if self._jupyter_print:
  56. warning_ignore()
  57. @property
  58. def path_resources(self):
  59. return self._path_resources
  60. @property
  61. def y_column(self):
  62. return self._y_column
  63. @property
  64. def lr(self):
  65. return self._lr
  66. @property
  67. def batch_size(self):
  68. return self._batch_size
  69. @property
  70. def epochs(self):
  71. return self._epochs
  72. @property
  73. def columns_anns(self):
  74. return self._columns_anns
  75. @property
  76. def jupyter_print(self):
  77. return self._jupyter_print
  78. @property
  79. def stress_test(self):
  80. return self._stress_test
  81. @property
  82. def stress_sample_times(self):
  83. return self._stress_sample_times
  84. @property
  85. def stress_bad_rate_list(self):
  86. return self._stress_bad_rate_list
  87. @staticmethod
  88. def from_config(config_path: str):
  89. """
  90. 从配置文件生成实体类
  91. """
  92. if os.path.isdir(config_path):
  93. config_path = os.path.join(config_path, FileEnum.OLCFG.value)
  94. if os.path.exists(config_path):
  95. with open(config_path, mode="r", encoding="utf-8") as f:
  96. j = json.loads(f.read())
  97. else:
  98. raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")
  99. print(f"olcfg load from【{config_path}】success. ")
  100. return OnlineLearningConfigEntity(**j)
  101. def config_save(self):
  102. path = self.f_get_save_path(FileEnum.OLCFG.value)
  103. with open(path, mode="w", encoding="utf-8") as f:
  104. j = {k.lstrip("_"): v for k, v in self.__dict__.items()}
  105. j = json.dumps(j, ensure_ascii=False)
  106. f.write(j)
  107. print(f"olcfg save to【{path}】success. ")
  108. def f_get_save_path(self, file_name: str) -> str:
  109. path = os.path.join(self._base_dir, file_name)
  110. return path
  111. if __name__ == "__main__":
  112. pass