# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/10/31
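@desc: MySQL data loader that runs a SQL query through pymysql and returns the result as a pandas DataFrame.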
"""
import importlib.util

if importlib.util.find_spec("pymysql"):
    import pymysql

import pandas as pd
from commom import get_logger
from entitys import DbConfigEntity
from .data_loader_base import DataLoaderBase

# Module-level logger obtained from the project's get_logger() helper;
# used below to report failures when closing the database connection.
logger = get_logger()


class DataLoaderMysql(DataLoaderBase):
    """Load query results from a MySQL database into pandas DataFrames.

    A single pymysql connection is created lazily on first use and cached on
    the instance in ``self.conn``. ``get_data`` closes that connection after
    every query, so each call opens a fresh connection.
    """

    def __init__(self, db_config: DbConfigEntity):
        """Store the connection settings; no connection is opened here.

        Args:
            db_config: Object providing ``host``, ``port``, ``user``,
                ``passwd`` and ``db`` attributes used to connect.
        """
        self.db_config = db_config
        # Lazily created by get_connect(); None means "not connected".
        self.conn = None

    def get_connect(self):
        """Return the cached connection, opening it first if necessary.

        Returns:
            The pymysql connection stored in ``self.conn``.
        """
        # TODO: switch to a thread pool later
        if self.conn is None:
            self.conn = pymysql.connect(
                host=self.db_config.host,
                port=self.db_config.port,
                user=self.db_config.user,
                passwd=self.db_config.passwd,
                db=self.db_config.db,
            )
        return self.conn

    def close_connect(self):
        """Close the cached connection, if any, and forget it.

        Errors raised while closing are logged rather than propagated, and
        ``self.conn`` is reset to ``None`` either way so the next
        ``get_connect`` call opens a new connection.
        """
        if self.conn is None:
            return
        try:
            self.conn.close()
        except Exception as msg:
            # Best-effort close: report the failure but do not raise.
            logger.error("关闭数据库失败:\n" + str(msg))
        finally:
            self.conn = None

    def get_data(self, sql: str) -> pd.DataFrame:
        """Execute ``sql`` and return all fetched rows as a DataFrame.

        The cursor and the connection are always released, even when the
        query fails; the original exception is then re-raised to the caller.

        Args:
            sql: The SQL statement to execute.

        Returns:
            A DataFrame whose columns are named after the first field of each
            entry in ``cursor.description``. If the cursor reports no
            description (``None``), the DataFrame has no columns.
        """
        try:
            cursor = self.get_connect().cursor()
            try:
                cursor.execute(sql)
                sql_results = cursor.fetchall()
                column_desc = cursor.description
            finally:
                cursor.close()
        finally:
            # Release the connection whether or not the query succeeded.
            self.close_connect()

        # Column name is the first item of each description entry; guard
        # against a None description so it does not raise TypeError.
        columns = [desc[0] for desc in (column_desc or ())]
        # Rows are fetched as a sequence of row sequences; convert each row
        # to a list before building the DataFrame.
        return pd.DataFrame([list(row) for row in sql_results], columns=columns)