import base64
import io
import textwrap

import numpy as np
from bs4 import BeautifulSoup
from matplotlib import lines as mlines
from matplotlib import patches as mpatches
from matplotlib.backends.backend_agg import RendererAgg
from matplotlib.figure import Figure
from matplotlib.transforms import Bbox


class TableMaker:
    """Draw the HTML table produced by ``DataFrame.to_html`` with matplotlib.

    Internally every table cell is a list. ``parse_into_rows`` produces
    ``[text, bold, text_align, background, color, rowspan, colspan]``;
    ``parse_html`` expands row/column spans and keeps only the first five
    entries, so later stages see ``[text, bold, text_align, background, color]``.
    """

    def __init__(
        self,
        fontsize=14,
        encode_base64=True,
        limit_crop=True,
        for_document=True,
        savefig_dpi=None,
    ):
        """Store rendering options.

        Args:
            fontsize: Starting font size; ``run`` resets to it on every call.
            encode_base64: If true, ``print_table`` returns a base64 ``str``
                instead of raw image bytes.
            limit_crop: Stored on the instance; not read by any method in
                this class.
            for_document: Selects the wide 20x4 figure and enables text
                wrapping / font shrinking in ``calculate_col_widths``.
            savefig_dpi: Passed as ``dpi`` to ``Figure.savefig``.
        """
        self.original_fontsize = fontsize
        self.encode_base64 = encode_base64
        self.limit_crop = limit_crop
        self.for_document = for_document
        self.figwidth = 1
        self.figheight = 1
        self.wrap_length = 30
        if self.for_document:
            self.figwidth = 20
            self.figheight = 4
            self.wrap_length = 10
        self.dpi = 100
        self.savefig_dpi = savefig_dpi

    def parse_html(self, html):
        """Parse *html* into a rectangular grid of 5-item cells.

        Cells that span several rows or columns are repeated so that every
        returned row has one entry per table column.

        Returns:
            ``(rows, num_header_rows)``.
        """
        # Literal <br> tags become newlines so they survive get_text().
        html = html.replace("<br>", "\n")
        rows, num_header_rows = self.parse_into_rows(html)
        num_cols = sum(val[-1] for val in rows[0])

        new_rows = []
        # column index -> cell that still has to be repeated in later rows
        rowspan = {}
        for row in rows:
            new_row = []
            col_loc = 0
            j = 0  # index of the next unread cell of the parsed row
            while col_loc < num_cols:
                if col_loc in rowspan:
                    # continuation of a cell opened in an earlier row
                    val = rowspan[col_loc]
                    val[-2] -= 1
                    if val[-2] == 1:
                        del rowspan[col_loc]
                else:
                    val = row[j]
                    j += 1
                    if val[-2] > 1:  # new rowspan detected
                        rowspan[col_loc] = val
                # One independent copy per spanned column: wrap_col later
                # mutates cells in place, so they must not be shared.
                new_row.extend(val[:5] for _ in range(val[-1]))
                col_loc += val[-1]  # usually 1
            new_rows.append(new_row)
        return new_rows, num_header_rows

    def get_text_align(self, element):
        """Return ``"left"``, ``"right"`` or ``"center"`` from the element's
        inline ``style`` attribute, or ``None`` when no such value is found.

        *element* only needs a dict-like ``get`` method.
        """
        style = element.get("style", "").lower()
        _, found, rest = style.partition("text-align")
        if not found:
            return None
        # partition() tolerates a declaration without a colon.
        value = rest.partition(":")[2].strip()
        for option in ("left", "right", "center"):
            if value.startswith(option):
                return option
        return None

    def parse_into_rows(self, html):
        """Parse *html* into raw rows of 7-item cells.

        Header rows (from ``<thead>``) come first, then ``<tbody>`` rows; if
        neither section exists every ``<tr>`` in the document is used.

        Returns:
            ``(rows, num_header_rows)``.
        """
        # Stylesheet parsing is disabled in this file, so there are no rules
        # to search and id-based colour lookups resolve to None.
        sheet = []

        def get_property(class_name, property_name):
            # Look up a CSS property for a selector among the parsed rules.
            for rule in sheet:
                selectors = rule.selectorText.replace(" ", "").split(",")
                if class_name in selectors:
                    for style_property in rule.style:
                        if style_property.name == property_name:
                            return style_property.value
            return None

        def parse_row(row):
            # Build one 7-item cell per <td>/<th> found under *row*.
            row_align = self.get_text_align(row)
            values = []
            for el in row.find_all(["td", "th"]):
                bold = el.name == "th"
                colspan = int(el.attrs.get("colspan", 1))
                rowspan = int(el.attrs.get("rowspan", 1))
                text_align = self.get_text_align(el) or row_align
                text = el.get_text()
                if "id" in el.attrs:
                    selector = "#" + el.attrs["id"]
                    background = get_property(selector, "background-color")
                    color = get_property(selector, "color")
                else:
                    # NOTE(review): print_table draws a rectangle for any
                    # truthy background, so this white default appears to
                    # paint over the zebra stripe - confirm this is intended.
                    background = "#ffffff"
                    color = "#000000"
                values.append(
                    [text, bold, text_align, background, color, rowspan, colspan]
                )
            return values

        soup = BeautifulSoup(html, features="lxml")
        thead = soup.find("thead")
        tbody = soup.find("tbody")

        rows = []
        if thead:
            head_rows = thead.find_all("tr")
            if head_rows:
                rows.extend(parse_row(row) for row in head_rows)
            else:
                rows.append(parse_row(thead))
        # Always bound, even when the table has no <thead>.
        num_header_rows = len(rows)

        if tbody:
            rows.extend(parse_row(row) for row in tbody.find_all("tr"))

        if not thead and not tbody:
            rows.extend(parse_row(row) for row in soup.find_all("tr"))
        return rows, num_header_rows

    def get_text_width(self, text, weight=None):
        """Return the rendered width of *text* at the current font size, as
        reported by the text's window extent on ``self.renderer``."""
        t = self.text_fig.text(0, 0, text, size=self.fontsize, weight=weight)
        width = t.get_window_extent(renderer=self.renderer).width
        # The artist is only needed for measuring; drop it so the scratch
        # figure does not accumulate one Text object per call.
        t.remove()
        return width

    def get_all_text_widths(self, rows):
        """Return a 2-D array (row x column) of padded cell widths.

        A cell's width is that of its widest line plus a fixed padding.
        """
        cell_padding = 15  # added to every measured width
        all_text_widths = []
        for i, row in enumerate(rows):
            # NOTE(review): only row 0 is measured bold, while print_table
            # draws every <th> cell bold - confirm this is intended.
            weight = "bold" if i == 0 else None
            row_widths = []
            for vals in row:
                cell_max_width = 0
                for text in vals[0].split("\n"):
                    cell_max_width = max(
                        cell_max_width, self.get_text_width(text, weight)
                    )
                row_widths.append(cell_max_width)
            all_text_widths.append(row_widths)
        return np.array(all_text_widths) + cell_padding

    def calculate_col_widths(self):
        """Return each column's width as a fraction of the figure width.

        When ``for_document`` is set and the table is wider than the figure,
        long cells are wrapped progressively harder (widest column first)
        until the table fits. If wrapping alone is not enough and the font is
        larger than 12, the font is shrunk by 10% and the whole calculation
        is repeated. After wrapping, the fractions are normalised to sum to 1.
        """
        all_text_widths = self.get_all_text_widths(self.rows)
        max_col_widths = all_text_widths.max(axis=0)
        total_width = self.figwidth * self.dpi

        if self.for_document and sum(max_col_widths) >= total_width:
            mult = 1
            fits = False
            # Stop as soon as the table fits; otherwise a table that already
            # fits would still be wrapped further and have its font shrunk.
            while mult > 0.5 and not fits:
                mult *= 0.9
                for idx in np.argsort(-max_col_widths):
                    col_widths = all_text_widths[:, idx]
                    self.rows = self.wrap_col(idx, col_widths, mult)
                    all_text_widths = self.get_all_text_widths(self.rows)
                    max_col_widths = all_text_widths.max(axis=0)
                    if sum(max_col_widths) < total_width:
                        fits = True
                        break

            if not fits and self.fontsize > 12:
                self.fontsize *= 0.9
                return self.calculate_col_widths()
            total_width = sum(max_col_widths)

        return [width / total_width for width in max_col_widths]

    def wrap_col(self, idx, col_widths, mult):
        """Wrap the long texts of column *idx*.

        A cell is wrapped when its width exceeds ``mult`` times the column's
        widest cell and its text is longer than ``self.wrap_length``. The new
        texts are applied only if the widest wrapped text is narrower than
        the previous column maximum.

        The returned list is a shallow copy of ``self.rows``: the cell lists
        are shared, so applied changes are also visible through ``self.rows``.
        """
        rows = self.rows.copy()
        max_width = max(col_widths)
        new_texts = []
        new_max_width = 0
        for row, col_width in zip(self.rows, col_widths):
            text = row[idx][0]
            if col_width > mult * max_width and len(text) > self.wrap_length:
                width = max(self.wrap_length, int(len(text) * mult))
                text = textwrap.fill(text, width, break_long_words=False)
                new_max_width = max(new_max_width, self.get_text_width(text))
            new_texts.append(text)

        if new_max_width < max_width:
            for row, text in zip(rows, new_texts):
                row[idx][0] = text
        return rows

    def get_row_heights(self):
        """Return one height per row: (max line count in the row + 1) times
        the font size, divided by 72."""
        row_heights = []
        for row in self.rows:
            line_count = max(val[0].count("\n") + 1 for val in row)
            row_heights.append((line_count + 1) * self.fontsize / 72)
        return row_heights

    def create_figure(self):
        """Create the output figure, as tall as the sum of the row heights."""
        figheight = sum(self.row_heights)
        return Figure(dpi=self.dpi, figsize=(self.figwidth, figheight))

    def print_table(self):
        """Draw the table on ``self.fig`` and return the saved image.

        Returns:
            The bytes written by ``Figure.savefig``, or their base64 text
            (``str``) when ``self.encode_base64`` is true.
        """
        stripe_color = "#f5f5f5"
        padding = self.fontsize / (self.figwidth * self.dpi) * 0.5
        total_width = sum(self.col_widths)
        figheight = self.fig.get_figheight()
        row_locs = [height / figheight for height in self.row_heights]
        header_text_align = [vals[2] for vals in self.rows[0]]
        x0 = (1 - total_width) / 2  # centre the table horizontally
        y = 1  # figure coordinates; rows are laid out from the top down

        for i, (yd, row) in enumerate(zip(row_locs, self.rows)):
            x = x0
            y -= yd

            # zebra stripe on every other body row
            diff = i - self.num_header_rows
            if diff >= 0 and diff % 2 == 0:
                stripe = mpatches.Rectangle(
                    (x0, y),
                    width=total_width,
                    height=yd,
                    fill=True,
                    color=stripe_color,
                    transform=self.fig.transFigure,
                )
                self.fig.add_artist(stripe)

            for j, (xd, val) in enumerate(zip(self.col_widths, row)):
                text = val[0]
                weight = "bold" if val[1] else None
                # cell alignment, else the first row's alignment for this
                # column, else right
                ha = val[2] or header_text_align[j] or "right"
                fg = val[4] if val[4] else "#000000"
                bg = val[3] if val[3] else None

                if bg:
                    rect_bg = mpatches.Rectangle(
                        (x, y),
                        width=xd,
                        height=yd,
                        fill=True,
                        color=bg,
                        transform=self.fig.transFigure,
                    )
                    self.fig.add_artist(rect_bg)

                if ha == "right":
                    x_pos = x + xd - padding
                elif ha == "center":
                    x_pos = x + xd / 2
                else:  # "left"
                    x_pos = x + padding

                self.fig.text(
                    x_pos,
                    y + yd / 2,
                    text,
                    size=self.fontsize,
                    ha=ha,
                    va="center",
                    weight=weight,
                    color=fg,
                )
                x += xd

            # rule under the last header row
            if i == self.num_header_rows - 1:
                line = mlines.Line2D([x0, x0 + total_width], [y, y], color="black")
                self.fig.add_artist(line)

        _, h = self.fig.get_size_inches()
        start = self.figwidth * min(x0, 0.1)
        end = self.figwidth - start
        # crop box in inches: from the bottom of the last row to the top
        bbox = Bbox([[start - 0.1, y * h], [end + 0.1, h]])
        buffer = io.BytesIO()
        self.fig.savefig(buffer, bbox_inches=bbox, dpi=self.savefig_dpi)
        img_str = buffer.getvalue()
        if self.encode_base64:
            img_str = base64.b64encode(img_str).decode()
        return img_str

    def run(self, df, filename: str):
        """Render *df* as a table image and write the image bytes to *filename*."""
        # notebook controls whether overly long cell content is truncated
        html = df.to_html(notebook=False)
        self.fontsize = self.original_fontsize
        self.text_fig = Figure(dpi=self.dpi)
        self.renderer = RendererAgg(self.figwidth, self.figheight, self.dpi)
        self.rows, self.num_header_rows = self.parse_html(html)
        self.col_widths = self.calculate_col_widths()
        self.row_heights = self.get_row_heights()
        self.fig = self.create_figure()
        img = self.print_table()
        # With encode_base64 enabled print_table returns base64 text, which
        # cannot be written to a binary file; turn it back into image bytes.
        if isinstance(img, str):
            img = base64.b64decode(img)
        with open(filename, "wb") as f:
            f.write(img)