Browse Source

add: DataLoaderHive

yq 4 months ago
parent
commit
94a4fa69d8

+ 1 - 1
config/data_process_config_template.json

@@ -3,7 +3,7 @@
   "bin_search_interval": 0.05,
   "bin_search_interval": 0.05,
   "feature_search_strategy": "iv",
   "feature_search_strategy": "iv",
   "x_candidate_num": 10,
   "x_candidate_num": 10,
-  "special_values": null,
+  "special_values": {"age_in_years": [36]},
   "breaks_list": {
   "breaks_list": {
     "duration_in_month": [13, 17,  47],
     "duration_in_month": [13, 17,  47],
     "credit_amount": [2001, 3000, 4000, 5000,  10000],
     "credit_amount": [2001, 3000, 4000, 5000,  10000],

+ 7 - 0
config/hive_config.json

@@ -0,0 +1,7 @@
+{
+  "host": "101.126.81.2",
+  "port": 18004,
+  "user": null,
+  "passwd": null,
+  "db": "default"
+}

+ 2 - 1
data/__init__.py

@@ -7,9 +7,10 @@
 from .insight.data_explore import DataExplore
 from .insight.data_explore import DataExplore
 from .loader.data_loader_base import DataLoaderBase
 from .loader.data_loader_base import DataLoaderBase
 from .loader.data_loader_excel import DataLoaderExcel
 from .loader.data_loader_excel import DataLoaderExcel
+from .loader.data_loader_hive import DataLoaderHive
 from .loader.data_loader_mysql import DataLoaderMysql
 from .loader.data_loader_mysql import DataLoaderMysql
 
 
-__all__ = ['DataLoaderBase', 'DataLoaderMysql', 'DataLoaderExcel', 'DataExplore']
+__all__ = ['DataLoaderBase', 'DataLoaderMysql', 'DataLoaderHive', 'DataLoaderExcel', 'DataExplore']
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     pass
     pass

+ 47 - 0
data/loader/data_loader_hive.py

@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/10/31
+@desc: 
+"""
+import pandas as pd
+from pyhive import hive
+from commom import get_logger
+from entitys import DbConfigEntity
+from .data_loader_base import DataLoaderBase
+
+logger = get_logger()
+
+
+class DataLoaderHive(DataLoaderBase):
+    def __init__(self, db_config: DbConfigEntity):
+        self.db_config = db_config
+        self.conn = None
+
+    def get_connect(self):
+        #  TODO 后续改成线程池
+        if self.conn == None:
+            self.conn = hive.connect(host=self.db_config.host, port=self.db_config.port, auth=self.db_config.user,
+                                        password=self.db_config.passwd, database=self.db_config.db)
+        return self.conn
+
+    def close_connect(self):
+        if self.conn != None:
+            try:
+                self.conn.close()
+            except Exception as msg:
+                logger.error("关闭数据库失败:\n" + str(msg))
+            self.conn = None
+
+    def get_data(self, sql: str) -> pd.DataFrame:
+        cursor = self.get_connect().cursor()
+        cursor.execute(sql)
+        sql_results = cursor.fetchall()
+        column_desc = cursor.description
+        # 获取列名
+        columns = [column_desc[i][0] for i in range(len(column_desc))]
+        # 得到的data为二维元组,逐行取出,转化为列表,再转化为df
+        df = pd.DataFrame([list(i) for i in sql_results], columns=columns)
+        cursor.close()
+        self.close_connect()
+        return df

+ 10 - 5
metric_test.py

@@ -6,7 +6,7 @@
 """
 """
 import pandas as pd
 import pandas as pd
 
 
-from data import DataLoaderMysql, DataLoaderBase
+from data import DataLoaderMysql, DataLoaderBase, DataLoaderHive
 from entitys import DbConfigEntity, MetricFucEntity
 from entitys import DbConfigEntity, MetricFucEntity
 from metrics import MetricBase
 from metrics import MetricBase
 from monitor import MonitorMetric
 from monitor import MonitorMetric
@@ -30,7 +30,12 @@ class A(MetricBase):
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     # f_register_metric_func(A)
     # f_register_metric_func(A)
-    data_loader = DataLoaderMysql(DbConfigEntity.from_config("./config/mysql_config.json"))
-    monitor_metric = MonitorMetric("./config/model_monitor_config_template.json")
-    monitor_metric.calculate_metric(data_loader=data_loader)
-    monitor_metric.generate_report()
+
+    data_loader = DataLoaderHive(DbConfigEntity.from_config("./config/hive_config.json"))
+    df = data_loader.get_data("select * from pokes")
+    print(df.head())
+
+    # data_loader = DataLoaderMysql(DbConfigEntity.from_config("./config/mysql_config.json"))
+    # monitor_metric = MonitorMetric("./config/model_monitor_config_template.json")
+    # monitor_metric.calculate_metric(data_loader=data_loader)
+    # monitor_metric.generate_report()

+ 4 - 0
requirements-py310.txt

@@ -6,4 +6,8 @@ dataframe_image==0.1.14
 gradio==5.8.0
 gradio==5.8.0
 matplotlib==3.9.3
 matplotlib==3.9.3
 numpy==1.26.4
 numpy==1.26.4
+pandas==1.5.3
 scikit-learn==1.1.3
 scikit-learn==1.1.3
+pyhive==0.7.0
+thrift==0.21.0
+thrift-sasl==0.4.3

+ 8 - 1
requirements.txt

@@ -2,4 +2,11 @@ pymysql==1.0.2
 python-docx==0.8.11
 python-docx==0.8.11
 xlrd==1.2.0
 xlrd==1.2.0
 scorecardpy==0.1.9.7
 scorecardpy==0.1.9.7
-dataframe_image==0.1.14
+dataframe_image==0.1.14
+matplotlib==3.3.4
+numpy==1.19.5
+pandas==1.1.1
+scikit-learn==0.24.2
+pyhive==0.7.0
+thrift==0.21.0
+thrift-sasl==0.4.3