12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- # -*- coding: utf-8 -*-
- """
- @author: yq
- @time: 2024/10/31
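- @desc: MySQL data loader that runs a SQL query through pymysql and returns the result as a pandas DataFrame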
- """
- import pandas as pd
- import pymysql
- from commom.logger import get_logger
- from .data_loader_base import DataLoaderBase
- logger = get_logger()
class DataLoaderMysql(DataLoaderBase):
    """Data loader that runs SQL against MySQL and returns pandas DataFrames.

    One pymysql connection is opened lazily on first use and cached on the
    instance. ``get_data`` closes that connection after every query, so each
    call to ``get_data`` currently opens a fresh connection.
    """

    def __init__(self, host: str, port: int, user: str, passwd: str, db: str):
        """Store the connection settings; no connection is opened here.

        Args:
            host: Forwarded as ``host`` to ``pymysql.connect``.
            port: Forwarded as ``port`` to ``pymysql.connect``.
            user: Forwarded as ``user`` to ``pymysql.connect``.
            passwd: Forwarded as ``passwd`` to ``pymysql.connect``.
            db: Forwarded as ``db`` to ``pymysql.connect``.
        """
        self.host = host
        self.port = port
        self.user = user
        self.passwd = passwd
        self.db = db
        # Cached connection; created on demand by get_connect().
        self.conn = None

    def get_connect(self):
        """Return the cached connection, opening it first if there is none."""
        # TODO: switch to a pool later (the original note says "thread pool").
        if self.conn is None:
            self.conn = pymysql.connect(
                host=self.host,
                port=self.port,
                user=self.user,
                passwd=self.passwd,
                db=self.db,
            )
        return self.conn

    def close_connect(self):
        """Close the cached connection, if any, and drop the reference.

        An ``Exception`` raised by ``close()`` is logged rather than
        propagated; the cached reference is cleared afterwards either way so
        the next ``get_connect`` call opens a new connection.
        """
        if self.conn is None:
            return
        try:
            self.conn.close()
        except Exception as msg:
            logger.error("关闭数据库失败:\n" + str(msg))
        self.conn = None

    def get_data(self, sql: str) -> pd.DataFrame:
        """Execute ``sql`` and return all fetched rows as a DataFrame.

        Args:
            sql: The SQL statement passed verbatim to ``cursor.execute``.

        Returns:
            A DataFrame whose columns are named after the first field of each
            ``cursor.description`` entry. If the statement produced no result
            set (``cursor.description`` is ``None``), an empty DataFrame with
            no columns is returned.

        Raises:
            Whatever ``pymysql`` raises while connecting, executing or
            fetching. The cursor and the connection are closed before the
            exception propagates.
        """
        try:
            cursor = self.get_connect().cursor()
            try:
                cursor.execute(sql)
                rows = cursor.fetchall()
                description = cursor.description
            finally:
                # Release the cursor even when execute/fetchall fails.
                cursor.close()
        finally:
            # Always release the connection so a failed query cannot leak it.
            self.close_connect()

        # Column names are the first item of every description entry;
        # description is None for statements that return no rows.
        columns = [field[0] for field in description or ()]
        # Rows are presumably tuples (default cursor); convert each row to a
        # list before building the DataFrame.
        return pd.DataFrame([list(row) for row in rows], columns=columns)
if __name__ == "__main__":
    # Intentionally empty: running this module as a script does nothing.
    pass
|