data_loader_mysql.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/10/31
  5. @desc:
  6. """
  7. import pandas as pd
  8. import pymysql
  9. from commom.logger import get_logger
  10. from entitys import DbConfigEntity
  11. from .data_loader_base import DataLoaderBase
  12. logger = get_logger()
  class DataLoaderMysql(DataLoaderBase):
      """Load query results from a MySQL database into pandas DataFrames.

      A single PyMySQL connection is opened lazily and cached on the instance;
      ``get_data`` closes it again after every query (success or failure).
      """

      def __init__(self, db_config: DbConfigEntity):
          """Store the connection settings; no connection is opened yet.

          Args:
              db_config: Connection settings. Its ``host``, ``port``, ``user``,
                  ``passwd`` and ``db`` attributes are read by ``get_connect``.
          """
          self.db_config = db_config
          # Created lazily by get_connect(); reset to None by close_connect().
          self.conn = None

      def get_connect(self):
          """Return the cached connection, opening a new one if none exists."""
          # TODO: replace with a pool later (original note: "thread pool").
          if self.conn is None:
              # ``password``/``database`` are PyMySQL's current keyword names;
              # ``passwd``/``db`` are deprecated aliases for them.
              self.conn = pymysql.connect(
                  host=self.db_config.host,
                  port=self.db_config.port,
                  user=self.db_config.user,
                  password=self.db_config.passwd,
                  database=self.db_config.db,
              )
          return self.conn

      def close_connect(self):
          """Close the cached connection, if any, and forget it.

          Best-effort: a failure while closing is logged rather than raised,
          and the cached connection is dropped either way.
          """
          if self.conn is None:
              return
          try:
              self.conn.close()
          except Exception as msg:
              logger.error("关闭数据库失败:\n" + str(msg))
          self.conn = None

      def get_data(self, sql: str, params=None) -> pd.DataFrame:
          """Run ``sql`` and return its result set as a DataFrame.

          The connection is closed afterwards, even when the query raises.

          Args:
              sql: SQL statement to execute.
              params: Optional query parameters forwarded to
                  ``cursor.execute`` so PyMySQL escapes them (use ``%s``
                  placeholders in ``sql``). When ``None`` (the default) the
                  SQL text is sent unchanged, exactly as before.

          Returns:
              One DataFrame row per fetched row, with columns named from
              ``cursor.description``. A statement that produces no result set
              yields an empty DataFrame instead of raising.
          """
          try:
              # The cursor context manager closes the cursor even on error.
              with self.get_connect().cursor() as cursor:
                  cursor.execute(sql, params)
                  rows = cursor.fetchall()
                  # description is None when the statement returns no result
                  # set (e.g. DDL/DML), so fall back to no columns.
                  columns = [col_desc[0] for col_desc in cursor.description or ()]
          finally:
              # Always release the connection, matching the original
              # close-after-each-query behavior, but also on failure.
              self.close_connect()
          # Rows arrive as tuples; convert each to a list before building
          # the DataFrame.
          return pd.DataFrame([list(row) for row in rows], columns=columns)