data_loader_mysql.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/10/31
  5. @desc:
  6. """
  7. import importlib.util
  8. if importlib.util.find_spec("pymysql"):
  9. import pymysql
  10. import pandas as pd
  11. from commom import get_logger
  12. from entitys import DbConfigEntity
  13. from .data_loader_base import DataLoaderBase
  14. logger = get_logger()
  15. class DataLoaderMysql(DataLoaderBase):
  16. def __init__(self, db_config: DbConfigEntity):
  17. self.db_config = db_config
  18. self.conn = None
  19. def get_connect(self):
  20. # TODO 后续改成线程池
  21. if self.conn == None:
  22. self.conn = pymysql.connect(host=self.db_config.host, port=self.db_config.port, user=self.db_config.user,
  23. passwd=self.db_config.passwd, db=self.db_config.db)
  24. return self.conn
  25. def close_connect(self):
  26. if self.conn != None:
  27. try:
  28. self.conn.close()
  29. except Exception as msg:
  30. logger.error("关闭数据库失败:\n" + str(msg))
  31. self.conn = None
  32. def get_data(self, sql: str) -> pd.DataFrame:
  33. cursor = self.get_connect().cursor()
  34. cursor.execute(sql)
  35. sql_results = cursor.fetchall()
  36. column_desc = cursor.description
  37. # 获取列名
  38. columns = [column_desc[i][0] for i in range(len(column_desc))]
  39. # 得到的data为二维元组,逐行取出,转化为列表,再转化为df
  40. df = pd.DataFrame([list(i) for i in sql_results], columns=columns)
  41. cursor.close()
  42. self.close_connect()
  43. return df