data_loader_mysql.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/10/31
  5. @desc:
  6. """
  7. import pandas as pd
  8. import pymysql
  9. from commom.logger import get_logger
  10. from .data_loader_base import DataLoaderBase
# Module-level logger from the project's logging helper; DataLoaderMysql uses it
# to report failures when closing the database connection.
logger = get_logger()
  12. class DataLoaderMysql(DataLoaderBase):
  13. def __init__(self, host: str, port: int, user: str, passwd: str, db: str):
  14. self.host = host
  15. self.port = port
  16. self.user = user
  17. self.passwd = passwd
  18. self.db = db
  19. self.conn = None
  20. def get_connect(self):
  21. # TODO 后续改成线程池
  22. if self.conn == None:
  23. self.conn = pymysql.connect(host=self.host, port=self.port, user=self.user, passwd=self.passwd,
  24. db=self.db)
  25. return self.conn
  26. def close_connect(self):
  27. if self.conn != None:
  28. try:
  29. self.conn.close()
  30. except Exception as msg:
  31. logger.error("关闭数据库失败:\n" + str(msg))
  32. self.conn = None
  33. def get_data(self, sql: str) -> pd.DataFrame:
  34. cursor = self.get_connect().cursor()
  35. cursor.execute(sql)
  36. sql_results = cursor.fetchall()
  37. column_desc = cursor.description
  38. # 获取列名
  39. columns = [column_desc[i][0] for i in range(len(column_desc))]
  40. # 得到的data为二维元组,逐行取出,转化为列表,再转化为df
  41. df = pd.DataFrame([list(i) for i in sql_results], columns=columns)
  42. cursor.close()
  43. self.close_connect()
  44. return df
if __name__ == "__main__":
    # No standalone behaviour: running this module directly does nothing.
    pass