123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- # -*- coding: utf-8 -*-
- """
- @author: yq
- @time: 2024/10/31
- @desc:
- """
- import importlib.util
- if importlib.util.find_spec("pymysql"):
- import pymysql
- import pandas as pd
- from commom import get_logger
- from entitys import DbConfigEntity
- from .data_loader_base import DataLoaderBase
- logger = get_logger()
class DataLoaderMysql(DataLoaderBase):
    """Data loader that runs SQL against MySQL (via pymysql) and returns a DataFrame.

    A single connection is opened lazily by ``get_connect`` and is closed again
    at the end of every ``get_data`` call, whether the query succeeds or fails.
    """

    def __init__(self, db_config: DbConfigEntity):
        """Store the connection settings; no connection is opened here.

        Args:
            db_config: connection settings; ``host``, ``port``, ``user``,
                ``passwd`` and ``db`` are read from it when connecting.
        """
        self.db_config = db_config
        # Lazily created pymysql connection; None while disconnected.
        self.conn = None

    def get_connect(self):
        """Return the cached connection, opening a new one if none is cached.

        NOTE: ``pymysql`` is only imported at module level when the package is
        installed; without it this method raises ``NameError``.
        """
        # TODO: switch to a pool later (the original note says "thread pool";
        #  presumably a connection pool is intended).
        if self.conn is None:
            # `password`/`database` are the current pymysql keyword names;
            # `passwd`/`db` are documented as deprecated aliases.
            self.conn = pymysql.connect(host=self.db_config.host, port=self.db_config.port,
                                        user=self.db_config.user, password=self.db_config.passwd,
                                        database=self.db_config.db)
        return self.conn

    def close_connect(self):
        """Close the cached connection, if any, and forget it.

        Errors raised while closing are logged, not propagated. The cached
        connection is cleared either way, so the next call reconnects.
        """
        if self.conn is not None:
            try:
                self.conn.close()
            except Exception as msg:
                # Best-effort close: log the failure instead of raising.
                logger.error("关闭数据库失败:\n" + str(msg))
            self.conn = None

    def get_data(self, sql: str, params=None) -> pd.DataFrame:
        """Execute ``sql`` and return all fetched rows as a DataFrame.

        The cursor and the connection are always closed before returning,
        including when the query raises.

        Args:
            sql: SQL statement to execute.
            params: optional query arguments, passed unchanged to
                ``cursor.execute``. Prefer this over formatting values into
                ``sql`` yourself. ``None`` (the default) sends ``sql`` as-is.

        Returns:
            A DataFrame with one row per fetched record. Column names come from
            ``cursor.description``. An empty DataFrame is returned for
            statements that produce no result set (``description`` is None).
        """
        try:
            cursor = self.get_connect().cursor()
            try:
                cursor.execute(sql, params)
                rows = cursor.fetchall()
                description = cursor.description
            finally:
                cursor.close()
        finally:
            # The connection is per-call: release it even when the query fails,
            # so a broken connection is not reused by the next call.
            self.close_connect()

        if description is None:
            # PEP 249: description is None for operations that return no rows.
            return pd.DataFrame()
        # The first element of each description entry is the column name.
        columns = [col[0] for col in description]
        # Rows arrive as a sequence of tuples; turn each into a list for the DataFrame.
        return pd.DataFrame([list(row) for row in rows], columns=columns)
|