Răsfoiți Sursa

add: DataLoaderHive

yq 4 luni în urmă
părinte
comite
94a4fa69d8

+ 1 - 1
config/data_process_config_template.json

@@ -3,7 +3,7 @@
   "bin_search_interval": 0.05,
   "feature_search_strategy": "iv",
   "x_candidate_num": 10,
-  "special_values": null,
+  "special_values": {"age_in_years": [36]},
   "breaks_list": {
     "duration_in_month": [13, 17,  47],
     "credit_amount": [2001, 3000, 4000, 5000,  10000],

+ 7 - 0
config/hive_config.json

@@ -0,0 +1,7 @@
+{
+  "host": "101.126.81.2",
+  "port": 18004,
+  "user": null,
+  "passwd": null,
+  "db": "default"
+}

+ 2 - 1
data/__init__.py

@@ -7,9 +7,10 @@
 from .insight.data_explore import DataExplore
 from .loader.data_loader_base import DataLoaderBase
 from .loader.data_loader_excel import DataLoaderExcel
+from .loader.data_loader_hive import DataLoaderHive
 from .loader.data_loader_mysql import DataLoaderMysql
 
-__all__ = ['DataLoaderBase', 'DataLoaderMysql', 'DataLoaderExcel', 'DataExplore']
+__all__ = ['DataLoaderBase', 'DataLoaderMysql', 'DataLoaderHive', 'DataLoaderExcel', 'DataExplore']
 
 if __name__ == "__main__":
     pass

+ 47 - 0
data/loader/data_loader_hive.py

@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/10/31
+@desc: Data loader that reads SQL query results from Hive via PyHive.
+"""
+import pandas as pd
+from pyhive import hive
+from commom import get_logger
+from entitys import DbConfigEntity
+from .data_loader_base import DataLoaderBase
+
+logger = get_logger()
+
+
+class DataLoaderHive(DataLoaderBase):
+    """Data loader that runs SQL against Hive through PyHive.
+
+    The connection is opened lazily on first use and is closed again at the
+    end of every ``get_data`` call.
+    """
+
+    def __init__(self, db_config: DbConfigEntity):
+        """Store the connection settings; no connection is opened here.
+
+        :param db_config: settings object; ``host``, ``port``, ``user``,
+            ``passwd`` and ``db`` are read from it when connecting.
+        """
+        self.db_config = db_config
+        self.conn = None
+
+    def get_connect(self):
+        """Return the cached Hive connection, opening it on first use."""
+        # TODO: switch to a thread pool later
+        if self.conn is None:
+            # Treat an empty password the same as "no password".
+            password = self.db_config.passwd or None
+            self.conn = hive.connect(
+                host=self.db_config.host,
+                port=self.db_config.port,
+                # The account name goes in ``username``; PyHive's ``auth`` is
+                # the authentication mechanism, not the user.
+                username=self.db_config.user,
+                password=password,
+                # PyHive only accepts a password in LDAP/CUSTOM mode, so a
+                # mode is selected only when a password is configured.
+                # NOTE(review): confirm LDAP matches the server's setting.
+                auth="LDAP" if password else None,
+                database=self.db_config.db,
+            )
+        return self.conn
+
+    def close_connect(self):
+        """Close the cached connection, if any, and forget it.
+
+        Closing is best-effort: a failure is logged and not re-raised.
+        """
+        if self.conn is not None:
+            try:
+                self.conn.close()
+            except Exception as msg:
+                logger.error("关闭数据库失败:\n" + str(msg))
+            finally:
+                self.conn = None
+
+    def get_data(self, sql: str) -> pd.DataFrame:
+        """Execute ``sql`` and return the full result set as a DataFrame.
+
+        :param sql: query text passed unchanged to the cursor.
+        :return: one row per fetched record; columns are named after the
+            first field of each ``cursor.description`` entry.
+        """
+        try:
+            cursor = self.get_connect().cursor()
+            try:
+                cursor.execute(sql)
+                sql_results = cursor.fetchall()
+                # Column names are the first field of each description entry.
+                columns = [desc[0] for desc in cursor.description]
+            finally:
+                cursor.close()
+        finally:
+            # Release the connection even when the query fails.
+            self.close_connect()
+        # The result is a sequence of row tuples; turn each row into a list
+        # before building the DataFrame.
+        return pd.DataFrame([list(row) for row in sql_results], columns=columns)

+ 10 - 5
metric_test.py

@@ -6,7 +6,7 @@
 """
 import pandas as pd
 
-from data import DataLoaderMysql, DataLoaderBase
+from data import DataLoaderMysql, DataLoaderBase, DataLoaderHive
 from entitys import DbConfigEntity, MetricFucEntity
 from metrics import MetricBase
 from monitor import MonitorMetric
@@ -30,7 +30,12 @@ class A(MetricBase):
 
 if __name__ == "__main__":
     # f_register_metric_func(A)
-    data_loader = DataLoaderMysql(DbConfigEntity.from_config("./config/mysql_config.json"))
-    monitor_metric = MonitorMetric("./config/model_monitor_config_template.json")
-    monitor_metric.calculate_metric(data_loader=data_loader)
-    monitor_metric.generate_report()
+
+    data_loader = DataLoaderHive(DbConfigEntity.from_config("./config/hive_config.json"))
+    df = data_loader.get_data("select * from pokes")
+    print(df.head())
+
+    # data_loader = DataLoaderMysql(DbConfigEntity.from_config("./config/mysql_config.json"))
+    # monitor_metric = MonitorMetric("./config/model_monitor_config_template.json")
+    # monitor_metric.calculate_metric(data_loader=data_loader)
+    # monitor_metric.generate_report()

+ 4 - 0
requirements-py310.txt

@@ -6,4 +6,8 @@ dataframe_image==0.1.14
 gradio==5.8.0
 matplotlib==3.9.3
 numpy==1.26.4
+pandas==1.5.3
 scikit-learn==1.1.3
+pyhive==0.7.0
+thrift==0.21.0
+thrift-sasl==0.4.3

+ 8 - 1
requirements.txt

@@ -2,4 +2,11 @@ pymysql==1.0.2
 python-docx==0.8.11
 xlrd==1.2.0
 scorecardpy==0.1.9.7
-dataframe_image==0.1.14
+dataframe_image==0.1.14
+matplotlib==3.3.4
+numpy==1.19.5
+pandas==1.1.1
+scikit-learn==0.24.2
+pyhive==0.7.0
+thrift==0.21.0
+thrift-sasl==0.4.3