Переглянути джерело

add: 指标计算初步框架

yq 5 місяців тому
батько
коміт
affa1359fa

+ 9 - 0
data_loader/__init__.py

@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/10/30
+@desc: 
+"""
+
+if __name__ == "__main__":
+    pass

+ 24 - 0
data_loader/data_loader_base.py

@@ -0,0 +1,24 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2024/1/2
+@desc: 数据加载基类
+"""
+import abc
+
+import pandas as pd
+
+
+class DataLoaderBase(metaclass=abc.ABCMeta):
+
+    @abc.abstractmethod
+    def get_connect(self):
+        pass
+
+    @abc.abstractmethod
+    def close_connect(self):
+        pass
+
+    @abc.abstractmethod
+    def get_data(self, *args, **kwargs) -> pd.DataFrame:
+        pass

+ 55 - 0
data_loader/data_loader_mysql.py

@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/10/31
+@desc: 
+"""
+import pandas as pd
+import pymysql
+
+from commom.logger import get_logger
+from data_loader.data_loader_base import DataLoaderBase
+
+logger = get_logger()
+
+
+class DataLoaderMysql(DataLoaderBase):
+    def __init__(self, host: str, port: int, user: str, passwd: str, db: str):
+        self.host = host
+        self.port = port
+        self.user = user
+        self.passwd = passwd
+        self.db = db
+        self.conn = None
+
+    def get_connect(self):
+        # TODO 后续改成线程池
+        if self.conn == None:
+            self.conn = pymysql.connect(host=self.host, port=self.port, user=self.user, passwd=self.passwd,
+                                        db=self.db)
+        return self.conn
+
+    def close_connect(self):
+        if self.conn != None:
+            try:
+                self.conn.close()
+            except Exception as msg:
+                logger.error("关闭数据库失败:\n" + str(msg))
+            self.conn = None
+
+    def get_data(self, sql: str) -> pd.DataFrame:
+        cursor = self.get_connect().cursor()
+        cursor.execute(sql)
+        sql_results = cursor.fetchall()
+        column_desc = cursor.description
+        # 获取列名
+        columns = [column_desc[i][0] for i in range(len(column_desc))]
+        # 得到的data为二维元组,逐行取出,转化为列表,再转化为df
+        df = pd.DataFrame([list(i) for i in sql_results], columns=columns)
+        cursor.close()
+        self.close_connect()
+        return df
+
+
+if __name__ == "__main__":
+    pass

+ 9 - 0
init/__init__.py

@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/10/31
+@desc: 模型及指标计算类初始化
+"""
+
+if __name__ == "__main__":
+    pass

+ 27 - 0
metric/metric_base.py

@@ -0,0 +1,27 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2024/1/2
+@desc: 指标计算基类
+"""
+import pandas as pd
+import abc
+
+
+class MetricBase(metaclass=abc.ABCMeta):
+
+    def __init__(self, metric_name: str, metric_code: str):
+        self.metric_name = metric_name
+        self.metric_code = metric_code
+
+    @abc.abstractmethod
+    def validate_data(self):
+        pass
+
+    @abc.abstractmethod
+    def load_data(self, *args, **kwargs) -> pd.DataFrame:
+        pass
+
+    @abc.abstractmethod
+    def calculate(self, *args, **kwargs) -> pd.DataFrame:
+        pass

+ 35 - 0
metric/metric_by_sql_general.py

@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/10/31
+@desc: sql直接统计指标直出
+"""
+import pandas as pd
+
+from data_loader.data_loader_base import DataLoaderBase
+from metric.metric_base import MetricBase
+
+
+class MetricBySqlGeneral(MetricBase):
+
+    def __init__(self, metric_name: str, metric_code: str):
+        super().__init__(metric_name, metric_code)
+
+    def validate_data(self):
+        pass
+
+    def load_data(self, data_loader: DataLoaderBase, sql: str) -> pd.DataFrame:
+        data = data_loader.get_data(sql)
+        return data
+
+    def calculate(self, data_loader: DataLoaderBase, sql: str) -> pd.DataFrame:
+        return self.load_data(data_loader, sql)
+
+
+if __name__ == "__main__":
+    from data_loader.data_loader_mysql import DataLoaderMysql
+
+    data_loader = DataLoaderMysql(host="101.126.81.2", port=18001, user="root", passwd="Cqrcb2024", db="test")
+    metric_clzz = MetricBySqlGeneral(metric_name="auc", metric_code="auc")
+    metric = metric_clzz.calculate(data_loader, "select * from test.t1")
+    print(metric.head(5))

+ 5 - 1
model/model_base.py

@@ -2,7 +2,7 @@
 """
 @author: yq
 @time: 2024/1/2
-@desc: 
+@desc: 模型基类
 """
 import pandas as pd
 import abc
@@ -13,6 +13,10 @@ class ModelBase(metaclass=abc.ABCMeta):
     def train(self, data: pd.DataFrame):
         pass
 
+    @abc.abstractmethod
+    def predict_prob(self, x: pd.DataFrame):
+        pass
+
     @abc.abstractmethod
     def predict(self, x: pd.DataFrame):
         pass

+ 10 - 0
report/__init__.py

@@ -0,0 +1,10 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2022/10/24
+@desc: 报告生成相关
+"""
+
+
+if __name__ == "__main__":
+    pass

+ 1 - 1
requirements.txt

@@ -1,2 +1,2 @@
-psutil==5.9.5
+pymysql==1.0.2
 

+ 9 - 0
task/__init__.py

@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/10/31
+@desc: 任务相关
+"""
+
+if __name__ == "__main__":
+    pass