소스 검색

modify: 优化指标计算框架逻辑

yq 5 달 전
부모
커밋
2540bca550

+ 4 - 0
commom/__init__.py

@@ -4,3 +4,7 @@
 @time: 2021/11/9
 @desc: 
 """
+from .utils import f_get_clazz_in_module
+
+
+__all__ = ['f_get_clazz_in_module', ]

+ 11 - 1
commom/utils.py

@@ -2,5 +2,15 @@
 """
 @author: yq
 @time: 2023/12/28
-@desc: 
+@desc:  各种工具类
 """
+
+import inspect
+
+
+def f_get_clazz_in_module(module):
+    """
+    Collect every class object exposed as a member of ``module``.
+
+    Members are enumerated with ``inspect.getmembers``, so the result follows
+    that function's ordering. No check is made on where a class was defined,
+    so classes that ``module`` merely imported are returned as well.
+
+    :param module: object whose members are inspected (normally a module)
+    :return: list of the class objects found; empty when there are none
+    """
+    return [member for _, member in inspect.getmembers(module, inspect.isclass)]

+ 5 - 2
entitys/__init__.py

@@ -6,9 +6,12 @@
 """
 from .data_feaure_entity import DataFeatureEntity
 from .db_config_entity import DbConfigEntity
-from .metric_entity import MetricTrainEntity
+from .metric_config_entity import MetricConfigEntity
+from .metric_entity import MetricTrainEntity, MetricFucEntity
+from .monitor_config_entity import ModelMonitorConfigEntity
 
-__all__ = ['DataFeatureEntity', 'DbConfigEntity', 'MetricTrainEntity']
+__all__ = ['DataFeatureEntity', 'DbConfigEntity', 'MetricTrainEntity', 'ModelMonitorConfigEntity', 'MetricConfigEntity',
+           'MetricFucEntity']
 
 if __name__ == "__main__":
     pass

+ 36 - 0
entitys/metric_config_entity.py

@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/11/1
+@desc: 指标配置
+"""
+
+
+class MetricConfigEntity:
+    """
+    Configuration of one metric: its code, the name of its calculation
+    function, and the extra arguments captured for that function.
+    """
+
+    def __init__(self, metric_code: str, metric_func: str, *args, **kwargs):
+        """
+        :param metric_code: key used to look the metric up when filling the template
+        :param metric_func: name used to find the matching metric calculation function
+        :param args: extra positional arguments, stored unchanged
+        :param kwargs: extra keyword arguments, stored unchanged
+        """
+        self._metric_code = metric_code
+        self._metric_func = metric_func
+        self._args, self._kwargs = args, kwargs
+
+    @property
+    def metric_code(self):
+        """Metric code given at construction."""
+        return self._metric_code
+
+    @property
+    def metric_func(self):
+        """Metric function name given at construction."""
+        return self._metric_func
+
+    @property
+    def args(self):
+        """Tuple of the extra positional arguments."""
+        return self._args
+
+    @property
+    def kwargs(self):
+        """Dict of the extra keyword arguments."""
+        return self._kwargs
+
+
+if __name__ == "__main__":
+    pass

+ 31 - 1
entitys/metric_entity.py

@@ -2,10 +2,16 @@
 """
 @author: yq
 @time: 2024/11/1
-@desc: 
+@desc:  常用指标实体集合
 """
+import pandas as pd
+
 
 class MetricTrainEntity():
+    """
+    模型训练结果指标类
+    """
+
     def __init__(self, auc: float, ks: float):
         self._auc = auc
         self._ks = ks
@@ -18,5 +24,29 @@ class MetricTrainEntity():
     def ks(self):
         return self._ks
 
+
+class MetricFucEntity:
+    """
+    Result of a metric calculation function.
+
+    Holds up to three optional parts: a table, a value and an image path.
+    Each one defaults to ``None`` and is exposed through a read-only property.
+    """
+
+    def __init__(self, table: pd.DataFrame = None, value: str = None, image_path: str = None):
+        """
+        :param table: tabular part of the result
+        :param value: value part of the result
+        :param image_path: path of an image belonging to the result
+        """
+        self._image_path = image_path
+        self._value = value
+        self._table = table
+
+    @property
+    def image_path(self):
+        """Image path given at construction, or ``None``."""
+        return self._image_path
+
+    @property
+    def value(self):
+        """Value given at construction, or ``None``."""
+        return self._value
+
+    @property
+    def table(self):
+        """Table given at construction, or ``None``."""
+        return self._table
+
+
 if __name__ == "__main__":
     pass

+ 45 - 0
entitys/monitor_config_entity.py

@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/11/1
+@desc: 指标监控配置
+"""
+from typing import List, Dict
+
+from entitys import MetricConfigEntity
+from metrics import f_get_metric_clazz_dict, MetricBase
+
+
+class ModelMonitorConfigEntity():
+    """
+    Monitoring configuration: a template path plus the metric objects built
+    from a list of metric configs.
+
+    Each MetricConfigEntity is resolved to a class through its ``metric_func``
+    name (looked up in the dict returned by ``f_get_metric_clazz_dict()``),
+    instantiated with the config's ``args`` / ``kwargs`` and stored under its
+    ``metric_code``.
+    """
+
+    def __init__(self, metric_config_list: List[MetricConfigEntity], template_path: str):
+        """
+        :param metric_config_list: configs of the metrics to instantiate
+        :param template_path: path of the template, stored as given
+        :raises KeyError: a config names a metric function that is not registered
+        :raises ValueError: two configs share the same metric_code
+        """
+        self._template_path = template_path
+        self._metric_clazz_dict = f_get_metric_clazz_dict()
+        self._metric_dict: Dict[str, MetricBase] = self._init_metric(metric_config_list)
+
+    @property
+    def template_path(self):
+        """Template path given at construction."""
+        return self._template_path
+
+    @property
+    def metric_dict(self):
+        """Mapping of metric_code to the instantiated metric object."""
+        return self._metric_dict
+
+    def _init_metric(self, metric_config_list: List[MetricConfigEntity]) -> Dict[str, MetricBase]:
+        """
+        Instantiate one metric object per config, keyed by metric_code.
+
+        :param metric_config_list: configs of the metrics to instantiate
+        :return: dict mapping metric_code to the metric instance
+        :raises KeyError: a config names a metric function that is not registered
+        :raises ValueError: two configs share the same metric_code
+        """
+        metric_dict = {}
+        for metric_config in metric_config_list:
+            metric_func_name = metric_config.metric_func
+            metric_code = metric_config.metric_code
+            # Unknown metric function: still a KeyError (as the plain dict lookup
+            # raised before), but now naming the offending config.
+            if metric_func_name not in self._metric_clazz_dict:
+                raise KeyError(
+                    f"metric function '{metric_func_name}' (metric_code '{metric_code}') "
+                    f"is not registered; available: {sorted(self._metric_clazz_dict)}"
+                )
+            # Duplicate metric_code: refuse instead of silently replacing the
+            # metric that was registered first.
+            if metric_code in metric_dict:
+                raise ValueError(f"metric_code '{metric_code}' is configured more than once")
+            metric_clazz = self._metric_clazz_dict[metric_func_name]
+            metric_dict[metric_code] = metric_clazz(*metric_config.args, **metric_config.kwargs)
+        return metric_dict
+
+
+if __name__ == "__main__":
+    pass

+ 2 - 2
metric_test.py

@@ -6,11 +6,11 @@
 """
 from data import DataLoaderMysql
 from entitys import DbConfigEntity
-from metrics import MetricBySqlGeneral
+from metrics import MetricBySqlGeneral, f_get_metric_clazz_dict
 
 if __name__ == "__main__":
     db_config = DbConfigEntity(host="101.126.81.2", port=18001, user="root", passwd="Cqrcb2024", db="test")
     data_loader = DataLoaderMysql(db_config)
-    metric_clzz = MetricBySqlGeneral(metric_name="auc", metric_code="auc")
-    metric = metric_clzz.calculate(data_loader, "select * from test.t1")
-    print(metric.head(5))
+    # The loader and SQL are constructor arguments now; calculate() takes no
+    # arguments and returns a MetricFucEntity whose DataFrame is in .table.
+    metric_clzz = MetricBySqlGeneral(data_loader=data_loader, sql="select * from test.t1")
+    metric = metric_clzz.calculate()
+    print(metric.table.head(5))

+ 13 - 1
metrics/__init__.py

@@ -4,10 +4,22 @@
 @time: 2022/10/24
 @desc: 指标计算相关
 """
+from typing import Dict
+
+from commom import f_get_clazz_in_module
 from .metric_base import MetricBase
 from .metric_by_sql_general import MetricBySqlGeneral
 
-__all__ = ['MetricBase', 'MetricBySqlGeneral']
+__all__ = ['f_get_metric_clazz_dict', 'MetricBase', 'MetricBySqlGeneral']
+
+
+def f_get_metric_clazz_dict():
+    """
+    Build a lookup of the classes available in this module's namespace.
+
+    :return: dict mapping each class's ``__name__`` to the class object, for
+        every class returned by ``f_get_clazz_in_module`` for this module
+    """
+    import sys
+
+    # sys.modules[__name__] is always this very module. __import__(__name__)
+    # returns the top-level package when __name__ is dotted, which would scan
+    # the wrong namespace once this package is nested inside another one.
+    current_module = sys.modules[__name__]
+    return {cls.__name__: cls for cls in f_get_clazz_in_module(current_module)}
+
 
 if __name__ == "__main__":
     pass

+ 5 - 6
metrics/metric_base.py

@@ -4,15 +4,14 @@
 @time: 2024/1/2
 @desc: 指标计算基类
 """
-import pandas as pd
 import abc
 
+import pandas as pd
 
-class MetricBase(metaclass=abc.ABCMeta):
+from entitys import MetricFucEntity
 
-    def __init__(self, metric_name: str, metric_code: str):
-        self.metric_name = metric_name
-        self.metric_code = metric_code
+
+class MetricBase(metaclass=abc.ABCMeta):
 
     @abc.abstractmethod
     def validate_data(self):
@@ -23,5 +22,5 @@ class MetricBase(metaclass=abc.ABCMeta):
         pass
 
     @abc.abstractmethod
-    def calculate(self, *args, **kwargs) -> pd.DataFrame:
+    def calculate(self, *args, **kwargs) -> MetricFucEntity:
         pass

+ 9 - 7
metrics/metric_by_sql_general.py

@@ -7,20 +7,22 @@
 import pandas as pd
 
 from data import DataLoaderBase
+from entitys import MetricFucEntity
 from .metric_base import MetricBase
 
 
 class MetricBySqlGeneral(MetricBase):
 
-    def __init__(self, metric_name: str, metric_code: str):
-        super().__init__(metric_name, metric_code)
+    def __init__(self, data_loader: DataLoaderBase = None, sql: str = None, **kwargs):
+        self._data_loader = data_loader
+        self._sql = sql
 
     def validate_data(self):
         pass
 
-    def load_data(self, data_loader: DataLoaderBase, sql: str) -> pd.DataFrame:
-        data = data_loader.get_data(sql)
-        return data
+    def load_data(self, ) -> pd.DataFrame:
+        return self._data_loader.get_data(self._sql)
+
+    def calculate(self, ) -> MetricFucEntity:
+        return MetricFucEntity(table=self.load_data())
 
-    def calculate(self, data_loader: DataLoaderBase, sql: str) -> pd.DataFrame:
-        return self.load_data(data_loader, sql)

+ 2 - 1
model/__init__.py

@@ -6,8 +6,9 @@
 """
 
 from .model_base import ModelBase
+from .model_lr import ModelLr
 
-__all__ = ['ModelBase',]
+__all__ = ['ModelBase', 'ModelLr']
 
 if __name__ == "__main__":
     pass

+ 1 - 1
model/model_lr.py

@@ -8,7 +8,7 @@ import pandas as pd
 from sklearn.linear_model import LogisticRegression
 
 from entitys import DataFeatureEntity, MetricTrainEntity
-from model_base import ModelBase
+from .model_base import ModelBase
 
 
 class ModelLr(ModelBase):

+ 4 - 0
monitor/__init__.py

@@ -5,5 +5,9 @@
 @desc: 指标监控
 """
 
+from .monitor_model import MonitorModel
+
+__all__ = ['MonitorModel']
+
 if __name__ == "__main__":
     pass

+ 33 - 0
monitor/monitor_model.py

@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/11/1
+@desc: 监控报告
+"""
+import threading
+from typing import Dict
+
+from entitys import ModelMonitorConfigEntity, MetricFucEntity
+
+
+class MonitorModel:
+    """
+    Runs the metrics of a ModelMonitorConfigEntity and keeps each result
+    under its metric code.
+    """
+
+    def __init__(self, model_monitor_config: ModelMonitorConfigEntity):
+        """
+        :param model_monitor_config: configuration whose ``metric_dict`` is evaluated
+        """
+        self._metric_value_dict: Dict[str, MetricFucEntity] = {}
+        self.lock = threading.Lock()
+        self._model_monitor_config = model_monitor_config
+
+    def _update_metric_value_dict(self, key, value):
+        """Store ``value`` under ``key`` in the result dict while holding the lock."""
+        with self.lock:
+            self._metric_value_dict[key] = value
+
+    # TODO: compute the metrics with multiple threads
+    def calculate_metric(self):
+        """Call ``calculate()`` on every configured metric, one after another, and record the result."""
+        for code, metric in self._model_monitor_config.metric_dict.items():
+            result = metric.calculate()
+            self._update_metric_value_dict(code, result)
+
+
+if __name__ == "__main__":
+    pass