Переглянути джерело

modify: df转图片输出优化

yq 3 місяців тому
батько
коміт
aa9e2340a8

+ 323 - 0
commom/matplotlib_table.py

@@ -0,0 +1,323 @@
+import base64
+import io
+import textwrap
+
+import numpy as np
+from bs4 import BeautifulSoup
+from matplotlib import lines as mlines
+from matplotlib import patches as mpatches
+from matplotlib.backends.backend_agg import RendererAgg
+from matplotlib.figure import Figure
+from matplotlib.transforms import Bbox
+
+
+class TableMaker:
+    def __init__(
+            self,
+            fontsize=14,
+            encode_base64=True,
+            limit_crop=True,
+            for_document=True,
+            savefig_dpi=None,
+    ):
+        self.original_fontsize = fontsize
+        self.encode_base64 = encode_base64
+        self.limit_crop = limit_crop
+        self.for_document = for_document
+        self.figwidth = 1
+        self.figheight = 1
+        self.wrap_length = 30
+        if self.for_document:
+            self.figwidth = 20
+            self.figheight = 4
+            self.wrap_length = 10
+        self.dpi = 100
+        self.savefig_dpi = savefig_dpi
+
+    def parse_html(self, html):
+        html = html.replace("<br>", "\n")
+        rows, num_header_rows = self.parse_into_rows(html)
+        num_cols = sum(val[-1] for val in rows[0])
+        new_rows = []
+        rowspan = {}
+        # deal with muti-row or multi col cells
+        for i, row in enumerate(rows):
+            new_row = []
+            col_loc = 0
+            j = 0
+            while col_loc < num_cols:
+                if col_loc in rowspan:
+                    val = rowspan[col_loc]
+                    val[-2] -= 1
+                    cur_col_loc = 0
+                    for _ in range(val[-1]):
+                        cur_col_loc += 1
+                        new_row.append(val[:5])
+                    if val[-2] == 1:
+                        del rowspan[col_loc]
+                    col_loc += cur_col_loc
+                else:
+                    val = row[j]
+                    if val[-2] > 1:  # new rowspan detected
+                        rowspan[col_loc] = val
+                    col_loc += val[-1]  # usually 1
+                    for _ in range(val[-1]):
+                        new_row.append(val[:5])
+                    j += 1
+            new_rows.append(new_row)
+        return new_rows, num_header_rows
+
+    def get_text_align(self, element):
+        style = element.get("style", "").lower()
+        if "text-align" in style:
+            idx = style.find("text-align")
+            text_align = style[idx + 10:].split(":")[1].strip()
+            for val in ("left", "right", "center"):
+                if text_align.startswith(val):
+                    return val
+
+    def parse_into_rows(self, html):
+        def get_property(class_name, property_name):
+            for rule in sheet:
+                selectors = rule.selectorText.replace(" ", "").split(",")
+                if class_name in selectors:
+                    for style_property in rule.style:
+                        if style_property.name == property_name:
+                            return style_property.value
+
+        def parse_row(row):
+            values = []
+            rowspan_dict = {}
+            colspan_total = 0
+            row_align = self.get_text_align(row)
+            for el in row.find_all(["td", "th"]):
+                bold = el.name == "th"
+                colspan = int(el.attrs.get("colspan", 1))
+                rowspan = int(el.attrs.get("rowspan", 1))
+                text_align = self.get_text_align(el) or row_align
+                text = el.get_text()
+                if "id" in el.attrs:
+                    values.append(
+                        [
+                            text,
+                            bold,
+                            text_align,
+                            get_property("#" + el.attrs["id"], "background-color"),
+                            get_property("#" + el.attrs["id"], "color"),
+                            rowspan,
+                            colspan,
+                        ]
+                    )
+                else:
+                    values.append(
+                        [text, bold, text_align, "#ffffff", "#000000", rowspan, colspan]
+                    )
+            return values
+
+        soup = BeautifulSoup(html, features="lxml")
+        style = soup.find("style")
+        # if style:
+        #     sheet = cssutils.parseString(style.text)
+        # else:
+        #     sheet = []
+        # get number of columns from first row
+        # num_cols = sum(int(el.get('colspan', 1)) for el in soup.find('tr').find_all(['td', 'th']))
+        thead = soup.find("thead")
+        tbody = soup.find("tbody")
+
+        rows = []
+        if thead:
+            head_rows = thead.find_all("tr")
+            if head_rows:
+                for row in head_rows:
+                    rows.append(parse_row(row))
+            else:
+                rows.append(parse_row(thead))
+
+        num_header_rows = len(rows)
+
+        if tbody:
+            for row in tbody.find_all("tr"):
+                rows.append(parse_row(row))
+
+        if not thead and not tbody:
+            for row in soup.find_all("tr"):
+                rows.append(parse_row(row))
+        return rows, num_header_rows
+
+    def get_text_width(self, text, weight=None):
+        fig = self.text_fig
+        t = fig.text(0, 0, text, size=self.fontsize, weight=weight)
+        bbox = t.get_window_extent(renderer=self.renderer)
+        return bbox.width
+
+    def get_all_text_widths(self, rows):
+        all_text_widths = []
+        for i, row in enumerate(rows):
+            row_widths = []
+            for vals in row:
+                cell_max_width = 0
+                for text in vals[0].split("\n"):
+                    weight = "bold" if i == 0 else None
+                    cell_max_width = max(
+                        cell_max_width, self.get_text_width(text, weight)
+                    )
+                row_widths.append(cell_max_width)
+            all_text_widths.append(row_widths)
+        pad = 10  # number of pixels to pad columns with
+        return np.array(all_text_widths) + 15
+
+    def calculate_col_widths(self):
+        all_text_widths = self.get_all_text_widths(self.rows)
+        max_col_widths = all_text_widths.max(axis=0)
+        mult = 1
+        total_width = self.figwidth * self.dpi
+        if self.for_document and sum(max_col_widths) >= total_width:
+            while mult > 0.5:
+                mult *= 0.9
+                for idx in np.argsort(-max_col_widths):
+                    col_widths = all_text_widths[:, idx]
+                    rows = self.wrap_col(idx, col_widths, mult)
+                    all_text_widths = self.get_all_text_widths(rows)
+                    max_col_widths = all_text_widths.max(axis=0)
+                    if sum(max_col_widths) < total_width:
+                        break
+
+            if mult <= 0.5 and self.fontsize > 12:
+                self.fontsize *= 0.9
+                return self.calculate_col_widths()
+            else:
+                self.rows = rows
+                total_width = sum(max_col_widths)
+
+        col_prop = [width / total_width for width in max_col_widths]
+        return col_prop
+
+    def wrap_col(self, idx, col_widths, mult):
+        rows = self.rows.copy()
+        max_width = max(col_widths)
+        texts = [row[idx][0] for row in self.rows]
+        new_texts = []
+        new_max_width = 0
+        for i, (text, col_width) in enumerate(zip(texts, col_widths)):
+            if col_width > mult * max_width and len(text) > self.wrap_length:
+                width = max(self.wrap_length, int(len(text) * mult))
+                new_text = textwrap.fill(text, width, break_long_words=False)
+                new_texts.append(new_text)
+                new_max_width = max(new_max_width, self.get_text_width(new_text))
+            else:
+                new_texts.append(text)
+
+        if new_max_width < max_width:
+            for row, text in zip(rows, new_texts):
+                row[idx][0] = text
+        return rows
+
+    def get_row_heights(self):
+        row_heights = []
+        for row in self.rows:
+            row_count = max([val[0].count("\n") + 1 for val in row])
+            height = (row_count + 1) * self.fontsize / 72
+            row_heights.append(height)
+
+        return row_heights
+
+    def create_figure(self):
+        figheight = sum(self.row_heights)
+        fig = Figure(dpi=self.dpi, figsize=(self.figwidth, figheight))
+        return fig
+
+    def print_table(self):
+        row_colors = ["#f5f5f5", "#ffffff"]
+        # padding 0.5 em
+        padding = self.fontsize / (self.figwidth * self.dpi) * 0.5
+        total_width = sum(self.col_widths)
+        figheight = self.fig.get_figheight()
+        row_locs = [height / figheight for height in self.row_heights]
+
+        header_text_align = [vals[2] for vals in self.rows[0]]
+        x0 = (1 - total_width) / 2
+        x = x0
+        yd = row_locs[0]
+        y = 1
+
+        for i, (yd, row) in enumerate(zip(row_locs, self.rows)):
+            x = x0
+            y -= yd
+            # table zebra stripes
+            diff = i - self.num_header_rows
+            if diff >= 0 and diff % 2 == 0:
+                p = mpatches.Rectangle(
+                    (x0, y),
+                    width=total_width,
+                    height=yd,
+                    fill=True,
+                    color=row_colors[0],
+                    transform=self.fig.transFigure,
+                )
+                self.fig.add_artist(p)
+            for j, (xd, val) in enumerate(zip(self.col_widths, row)):
+                text = val[0]
+                weight = "bold" if val[1] else None
+                ha = val[2] or header_text_align[j] or "right"
+                fg = val[4] if val[4] else "#000000"
+                bg = val[3] if val[3] else None
+
+                if bg:
+                    rect_bg = mpatches.Rectangle(
+                        (x, y),
+                        width=xd,
+                        height=yd,
+                        fill=True,
+                        color=bg,
+                        transform=self.fig.transFigure,
+                    )
+                    self.fig.add_artist(rect_bg)
+
+                if ha == "right":
+                    x_pos = x + xd - padding
+                elif ha == "center":
+                    x_pos = x + xd / 2
+                elif ha == "left":
+                    x_pos = x + padding
+
+                self.fig.text(
+                    x_pos,
+                    y + yd / 2,
+                    text,
+                    size=self.fontsize,
+                    ha=ha,
+                    va="center",
+                    weight=weight,
+                    color=fg,
+                    # backgroundcolor=bg
+                )
+                x += xd
+
+            if i == self.num_header_rows - 1:
+                line = mlines.Line2D([x0, x0 + total_width], [y, y], color="black")
+                self.fig.add_artist(line)
+
+        w, h = self.fig.get_size_inches()
+        start = self.figwidth * min(x0, 0.1)
+        end = self.figwidth - start
+        bbox = Bbox([[start - 0.1, y * h], [end + 0.1, h]])
+        buffer = io.BytesIO()
+        self.fig.savefig(buffer, bbox_inches=bbox, dpi=self.savefig_dpi)
+        img_str = buffer.getvalue()
+        if self.encode_base64:
+            img_str = base64.b64encode(img_str).decode()
+        return img_str
+
+    def run(self, df, filename: str):
+        html = df.to_html(notebook=False)  # notebook 控制单元内容超长后是否截断
+        self.fontsize = self.original_fontsize
+        self.text_fig = Figure(dpi=self.dpi)
+        self.renderer = RendererAgg(self.figwidth, self.figheight, self.dpi)
+        self.rows, self.num_header_rows = self.parse_html(html)
+        self.col_widths = self.calculate_col_widths()
+        self.row_heights = self.get_row_heights()
+        self.fig = self.create_figure()
+        img_str = self.print_table()
+        with open(filename, "wb") as f:
+            f.write(img_str)

+ 36 - 3
commom/utils.py

@@ -12,11 +12,11 @@ import os
 from json import JSONEncoder
 from typing import Union
 
-import dataframe_image as dfi
 import pandas as pd
 import pytz
 
 from config import BaseConfig
+from .matplotlib_table import TableMaker
 
 
 def f_format_float(num: float, n=3):
@@ -49,8 +49,41 @@ def f_save_train_df(file_name: str, df: pd.DataFrame):
     df.to_excel(f"{file_path}.xlsx", index=False)
 
 
-def f_df_to_image(df, filename, fontsize=12):
-    dfi.export(obj=df, filename=filename, fontsize=fontsize, table_conversion='matplotlib')
+def f_df_to_image(df: pd.DataFrame, filename, fontsize=12):
+    converter = TableMaker(fontsize=fontsize, encode_base64=False, for_document=False)
+    converter.run(df, filename)
+
+    # if importlib.util.find_spec("dataframe_image"):
+    #     import dataframe_image as dfi
+    #
+    #     dfi.export(obj=df, filename=filename, fontsize=fontsize, table_conversion='matplotlib')
+    # elif importlib.util.find_spec("plotly"):
+    #     import plotly.graph_objects as go
+    #     import plotly.figure_factory as ff
+    #     import plotly.io as pio
+    #
+    #     fig = ff.create_table(df)
+    #     fig.update_layout()
+    #     fig.write_image(filename)
+    #
+    #     fig = go.Figure(data=go.Table(
+    #         header=dict(
+    #             values=df.columns.to_list(),
+    #             font=dict(color='black', size=fontsize),
+    #             fill_color="white",
+    #             line_color='black',
+    #             align="center"
+    #         ),
+    #         cells=dict(
+    #             values=[df[k].tolist() for k in df.columns],
+    #             font=dict(color='black', size=fontsize),
+    #             fill_color="white",
+    #             line_color='black',
+    #             align="center")
+    #     )).update_layout()
+    #     pio.write_image(fig, filename)
+    # else:
+    #     raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"缺少画图依赖【dataframe_image】或者【plotly】")
 
 
 def _f_image_to_base64(image_path):

+ 6 - 2
data/loader/data_loader_hive.py

@@ -4,8 +4,12 @@
 @time: 2024/10/31
 @desc: 
 """
+import importlib.util
+
+if importlib.util.find_spec("hive"):
+    from pyhive import hive
+
 import pandas as pd
-from pyhive import hive
 from commom import get_logger
 from entitys import DbConfigEntity
 from .data_loader_base import DataLoaderBase
@@ -22,7 +26,7 @@ class DataLoaderHive(DataLoaderBase):
         #  TODO 后续改成线程池
         if self.conn == None:
             self.conn = hive.connect(host=self.db_config.host, port=self.db_config.port, auth=self.db_config.user,
-                                        password=self.db_config.passwd, database=self.db_config.db)
+                                     password=self.db_config.passwd, database=self.db_config.db)
         return self.conn
 
     def close_connect(self):

+ 5 - 2
data/loader/data_loader_mysql.py

@@ -4,9 +4,12 @@
 @time: 2024/10/31
 @desc: 
 """
-import pandas as pd
-import pymysql
+import importlib.util
+
+if importlib.util.find_spec("pymysql"):
+    import pymysql
 
+import pandas as pd
 from commom import get_logger
 from entitys import DbConfigEntity
 from .data_loader_base import DataLoaderBase

+ 19 - 0
requirements-analysis.txt

@@ -0,0 +1,19 @@
+#pymysql==1.0.2
+python-docx==0.8.11
+xlrd==2.0.1
+scorecardpy==0.1.9.2
+#dataframe_image==0.1.14
+matplotlib==3.3.4
+numpy==1.18.2
+pandas==1.1.5
+scikit-learn==0.24.2
+#pyhive==0.7.0
+thrift==0.16.0
+#thrift-sasl==0.4.3
+seaborn==0.11.2
+contextvars==2.4
+tqdm==4.64.0
+plotly==5.10.0
+kaleido==0.2.1
+statsmodels==0.12.2
+beautifulsoup4==4.11.1