# -*- coding:utf-8 -*-
"""
@author: yq
@time: 2022/10/24
@desc: 指标计算相关
"""

import sys
import threading
from typing import Type

from commom import f_get_clazz_in_module, GeneralException
from enums import ResultCodesEnum
from .metric_base import MetricBase
from .metric_by_sql_general import MetricBySqlGeneral

# Public API of this package: the registry helpers plus the metric base classes.
__all__ = ['f_get_metric_clazz_dict', 'f_register_metric_func', 'MetricBase', 'MetricBySqlGeneral']

# Held by _update_metric_clazz_dict around its duplicate-check-then-insert so two
# concurrent registrations of the same name cannot both pass the check.
# NOTE(review): f_get_metric_clazz_dict reads the registry without taking this lock.
lock = threading.Lock()
# Registry of metric classes; f_register_metric_func keys entries by the class ``__name__``.
metric_clazz_dict = {}


def _update_metric_clazz_dict(key, value):
    """Insert ``value`` into the metric registry under ``key``, rejecting duplicates.

    The duplicate check and the insert both happen while ``lock`` is held, so two
    threads registering the same key cannot both succeed.

    :param key: registry key (the metric class name when called from f_register_metric_func).
    :param value: the object to store, normally the metric class itself.
    :raises GeneralException: with ``ResultCodesEnum.ILLEGAL_PARAMS`` if ``key`` is already registered.
    """
    with lock:
        # Test membership on the dict directly; going through ``.keys()`` adds nothing.
        if key in metric_clazz_dict:
            raise GeneralException(ResultCodesEnum.ILLEGAL_PARAMS, message=f"自定义指标函数【{key}】已注册或出现重名")
        metric_clazz_dict[key] = value


def f_register_metric_func(clazz: Type[MetricBase]):
    """Register a custom metric class so it can be looked up by its class name.

    A class is accepted only if it has a ``_symbol`` attribute equal to
    ``MetricBase._symbol``. This marker comparison is how the module decides that
    the class "inherits MetricBase".

    :param clazz: the metric class itself (not an instance); ``clazz.__name__`` is
        used as the registry key.
    :raises GeneralException: ``ResultCodesEnum.ILLEGAL_PARAMS`` if ``clazz`` lacks the
        MetricBase marker, or if a metric with the same class name is already registered.
    """
    # NOTE(review): a marker attribute is compared rather than using issubclass();
    # presumably deliberate (e.g. to tolerate duplicate imports of MetricBase) — confirm.
    if not hasattr(clazz, '_symbol') or clazz._symbol != MetricBase._symbol:
        raise GeneralException(ResultCodesEnum.ILLEGAL_PARAMS, message="自定义指标函数没有继承类【MetricBase】")
    _update_metric_clazz_dict(clazz.__name__, clazz)


def f_get_metric_clazz_dict():
    """Return the registry of metric classes (class name -> metric class).

    The module-level dict itself is handed back, not a copy. Classes registered
    later are therefore visible through it, and any mutation by the caller
    changes the registry directly.
    """
    registry = metric_clazz_dict
    return registry


# Import-time auto-registration of the metric classes reachable from this module.
# ``sys.modules[__name__]`` is used instead of ``__import__(__name__)``: for a dotted
# module name such as ``pkg.metric``, ``__import__`` returns the top-level ``pkg``
# package, so the scan would inspect the wrong module.
# NOTE(review): f_get_clazz_in_module is assumed to return the classes visible in the
# given module — confirm against its implementation in ``commom``.
all_classes = f_get_clazz_in_module(sys.modules[__name__])
for clazz in all_classes:
    # Skip classes without MetricBase's ``_symbol`` marker. f_register_metric_func
    # enforces the same check, but would raise instead of skipping.
    # NOTE(review): MetricBase itself satisfies this check and may get registered too.
    if not hasattr(clazz, '_symbol') or clazz._symbol != MetricBase._symbol:
        continue
    f_register_metric_func(clazz)