# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/1
@desc: OnlineLearning参数配置类 (configuration entity for online learning)
"""
import json
import os
from typing import List

from commom import GeneralException, f_get_datetime
from config import BaseConfig
from enums import ResultCodesEnum, FileEnum
from init import warning_ignore


class OnlineLearningConfigEntity:
    """Configuration entity for an online-learning run.

    Holds training hyper-parameters and stress-test settings, and owns the
    project output directory, which is created under ``BaseConfig.train_path``
    on construction. Instances can be persisted with :meth:`config_save` and
    restored with :meth:`from_config`.
    """

    def __init__(self,
                 path_resources: str,
                 y_column: str,
                 project_name: str = None,
                 lr: float = 0.01,
                 batch_size: int = 64,
                 epochs: int = 50,
                 columns_anns: dict = None,
                 jupyter_print=False,
                 stress_test=False,
                 stress_sample_times=100,
                 stress_bad_rate_list: List[float] = None,
                 *args, **kwargs):
        """Store the configuration and create the project directory.

        Args:
            path_resources: resource path; stored as-is.
            y_column: name of the target (y) column.
            project_name: project name; determines the output directory.
                When None or empty, ``f_get_datetime()`` is used as the
                directory name instead.
            lr: learning rate.
            batch_size: number of samples used per model update.
            epochs: maximum number of training epochs.
            columns_anns: column annotations. Defaults to a new empty dict.
            jupyter_print: whether to print output for Jupyter. When True,
                ``warning_ignore()`` is called.
            stress_test: whether stress testing is enabled.
            stress_sample_times: presumably the number of stress-test
                sampling rounds (TODO confirm against the consumer).
            stress_bad_rate_list: bad-rate values for the stress test.
                Defaults to a new empty list.
            *args, **kwargs: ignored. This lets ``from_config`` accept extra
                keys, e.g. ``base_dir``, which ``config_save`` writes.
        """
        self._path_resources = path_resources
        # Target (y) column name.
        self._y_column = y_column
        # Project name; determines the output (cache) directory.
        self._project_name = project_name
        # Learning rate.
        self._lr = lr
        # Number of samples used per model update.
        self._batch_size = batch_size
        # Maximum number of training epochs.
        self._epochs = epochs

        # Column annotations. Each instance gets its own dict: a shared `{}`
        # default would leak mutations between configs.
        self._columns_anns = {} if columns_anns is None else columns_anns

        # Whether to print output for Jupyter.
        self._jupyter_print = jupyter_print

        # Whether stress testing is enabled.
        self._stress_test = stress_test

        # Stress-test sampling count.
        self._stress_sample_times = stress_sample_times

        # Stress-test bad-rate values. Each instance gets its own list.
        self._stress_bad_rate_list = [] if stress_bad_rate_list is None else stress_bad_rate_list

        if not self._project_name:
            self._base_dir = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
        else:
            self._base_dir = os.path.join(BaseConfig.train_path, self._project_name)
        os.makedirs(self._base_dir, exist_ok=True)
        print(f"项目路径:【{self._base_dir}】")

        if self._jupyter_print:
            warning_ignore()

    @property
    def path_resources(self):
        return self._path_resources

    @property
    def y_column(self):
        return self._y_column

    @property
    def project_name(self):
        return self._project_name

    @property
    def lr(self):
        return self._lr

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def epochs(self):
        return self._epochs

    @property
    def columns_anns(self):
        return self._columns_anns

    @property
    def jupyter_print(self):
        return self._jupyter_print

    @property
    def stress_test(self):
        return self._stress_test

    @property
    def stress_sample_times(self):
        return self._stress_sample_times

    @property
    def stress_bad_rate_list(self):
        return self._stress_bad_rate_list

    @classmethod
    def from_config(cls, config_path: str):
        """Build an instance from a saved JSON config file.

        Args:
            config_path: path to the config file, or to a directory that
                contains a file named ``FileEnum.OLCFG.value``.

        Returns:
            A new instance constructed from the file's key/value pairs.

        Raises:
            GeneralException: with ``ResultCodesEnum.NOT_FOUND`` when the
                resolved file does not exist.
        """
        if os.path.isdir(config_path):
            config_path = os.path.join(config_path, FileEnum.OLCFG.value)

        if not os.path.exists(config_path):
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")

        with open(config_path, mode="r", encoding="utf-8") as f:
            j = json.load(f)
        print(f"olcfg load from【{config_path}】success. ")
        return cls(**j)

    def config_save(self):
        """Write this config as JSON into the project directory.

        Every instance attribute is written with its leading underscores
        stripped, so the keys match the ``__init__`` parameter names. The
        extra ``base_dir`` key is absorbed by ``__init__``'s ``**kwargs``.
        """
        path = self.f_get_save_path(FileEnum.OLCFG.value)
        j = {k.lstrip("_"): v for k, v in self.__dict__.items()}
        # Serialize before opening the file. Opening with mode "w" truncates
        # it, so a serialization error would otherwise clobber the old config.
        content = json.dumps(j, ensure_ascii=False)
        with open(path, mode="w", encoding="utf-8") as f:
            f.write(content)
        print(f"olcfg save to【{path}】success. ")

    def f_get_save_path(self, file_name: str) -> str:
        """Return the path of ``file_name`` inside the project directory."""
        path = os.path.join(self._base_dir, file_name)
        return path


if "__main__" == __name__:
    # Nothing to run when executed directly; the module only defines the
    # configuration entity.
    ...