123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323 |
- import base64
- import io
- import textwrap
- import numpy as np
- from bs4 import BeautifulSoup
- from matplotlib import lines as mlines
- from matplotlib import patches as mpatches
- from matplotlib.backends.backend_agg import RendererAgg
- from matplotlib.figure import Figure
- from matplotlib.transforms import Bbox
class TableMaker:
    """Render an HTML table (typically ``DataFrame.to_html`` output) as a PNG.

    Pipeline (see ``run``):

    * ``parse_html`` turns the table into a rectangular grid of cells.
    * ``calculate_col_widths`` measures text to size the columns. When
      ``for_document`` is set, it also wraps long cells and shrinks the font.
    * ``get_row_heights`` sizes the rows.
    * ``print_table`` draws everything on a matplotlib ``Figure`` and returns
      the cropped PNG.

    Cells produced by ``parse_into_rows`` are lists
    ``[text, bold, text_align, background_color, color, rowspan, colspan]``.
    ``parse_html`` trims them to the first five entries.
    """

    def __init__(
        self,
        fontsize=14,
        encode_base64=True,
        limit_crop=True,
        for_document=True,
        savefig_dpi=None,
    ):
        """Configure the renderer.

        Parameters
        ----------
        fontsize : starting font size in points. ``calculate_col_widths`` may
            reduce it for wide tables when ``for_document`` is true.
        encode_base64 : if true, ``print_table`` returns base64 text instead
            of raw PNG bytes.
        limit_crop : stored, but not read by any method of this class.
        for_document : use a 20-inch-wide canvas and wrap/shrink tables that
            do not fit in it.
        savefig_dpi : forwarded as ``dpi`` to ``Figure.savefig``.
        """
        self.original_fontsize = fontsize
        self.encode_base64 = encode_base64
        self.limit_crop = limit_crop  # currently unused by this class
        self.for_document = for_document
        # NOTE(review): with for_document=False the canvas is 1 inch wide,
        # while column fractions always sum to 1. Wide tables may therefore
        # be squeezed/clipped; confirm this is intended.
        self.figwidth = 1
        self.figheight = 1
        self.wrap_length = 30
        if self.for_document:
            self.figwidth = 20
            self.figheight = 4
            self.wrap_length = 10
        self.dpi = 100
        self.savefig_dpi = savefig_dpi

    def parse_html(self, html):
        """Parse ``html`` into a rectangular grid, expanding row/col spans.

        Returns ``(rows, num_header_rows)``. Every row holds one
        ``[text, bold, text_align, bg, fg]`` cell per table column. A cell
        spanning several rows or columns is repeated (as separate list copies)
        in each slot it covers. ``<br>`` tags become newlines.
        """
        html = html.replace("<br>", "\n")
        rows, num_header_rows = self.parse_into_rows(html)
        # The first row cannot be covered by an earlier rowspan, so its
        # colspans add up to the table width.
        num_cols = sum(cell[-1] for cell in rows[0])
        new_rows = []
        active_rowspans = {}  # column index -> cell still spanning downwards
        for row in rows:
            new_row = []
            col_loc = 0
            j = 0  # index of the next unconsumed cell in this source row
            while col_loc < num_cols:
                if col_loc in active_rowspans:
                    cell = active_rowspans[col_loc]
                    cell[-2] -= 1  # one fewer row left for this span to cover
                    if cell[-2] == 1:
                        del active_rowspans[col_loc]
                else:
                    cell = row[j]
                    j += 1
                    if cell[-2] > 1:  # a new rowspan starts here
                        active_rowspans[col_loc] = cell
                colspan = cell[-1]
                for _ in range(colspan):
                    new_row.append(cell[:5])
                col_loc += colspan
            new_rows.append(new_row)
        return new_rows, num_header_rows

    def get_text_align(self, element):
        """Return ``element``'s inline ``text-align``, or None.

        The result is "left", "right" or "center"; anything else gives None.
        ``element`` only needs a dict-like ``get`` (e.g. a bs4 Tag).
        """
        style = element.get("style", "").lower()
        idx = style.find("text-align")
        if idx == -1:
            return None
        after = style[idx + len("text-align"):]
        if ":" not in after:
            # Malformed declaration; previously raised IndexError.
            return None
        # The value is the text after the property's colon, up to the next colon.
        text_align = after.split(":")[1].strip()
        for val in ("left", "right", "center"):
            if text_align.startswith(val):
                return val
        return None

    def parse_into_rows(self, html):
        """Split the table in ``html`` into rows of raw cells.

        Each cell is ``[text, bold, text_align, bg, fg, rowspan, colspan]``.
        ``bold`` is true for ``<th>`` cells. ``text_align`` falls back to the
        row's own inline alignment.

        Returns ``(rows, num_header_rows)``. Header rows are those found
        under ``<thead>``; the count is 0 when there is no ``<thead>``.
        """
        # CSS rules used to resolve per-cell colours by element id.
        # Stylesheet parsing (formerly via cssutils) is disabled, so this is
        # empty and id lookups yield None, which print_table treats as
        # "default". The name used to be undefined, which made any cell with
        # an ``id`` raise NameError.
        sheet = []

        def get_property(selector, property_name):
            # Look up ``property_name`` in the first rule listing ``selector``.
            for rule in sheet:
                selectors = rule.selectorText.replace(" ", "").split(",")
                if selector in selectors:
                    for style_property in rule.style:
                        if style_property.name == property_name:
                            return style_property.value
            return None

        def parse_row(row):
            # Convert one <tr> (or a bare <thead>) into a list of raw cells.
            values = []
            row_align = self.get_text_align(row)
            for el in row.find_all(["td", "th"]):
                bold = el.name == "th"
                colspan = int(el.attrs.get("colspan", 1))
                rowspan = int(el.attrs.get("rowspan", 1))
                text_align = self.get_text_align(el) or row_align
                text = el.get_text()
                if "id" in el.attrs:
                    selector = "#" + el.attrs["id"]
                    bg = get_property(selector, "background-color")
                    fg = get_property(selector, "color")
                else:
                    # No explicit background. A white fill here used to be
                    # drawn over print_table's zebra stripes, hiding them.
                    bg, fg = None, "#000000"
                values.append([text, bold, text_align, bg, fg, rowspan, colspan])
            return values

        soup = BeautifulSoup(html, features="lxml")
        thead = soup.find("thead")
        tbody = soup.find("tbody")
        rows = []
        if thead:
            head_rows = thead.find_all("tr")
            if head_rows:
                rows.extend(parse_row(row) for row in head_rows)
            else:
                rows.append(parse_row(thead))
        num_header_rows = len(rows)
        if tbody:
            rows.extend(parse_row(row) for row in tbody.find_all("tr"))
        if not thead and not tbody:
            rows.extend(parse_row(row) for row in soup.find_all("tr"))
        return rows, num_header_rows

    def get_text_width(self, text, weight=None):
        """Return the rendered width of ``text``, in pixels, at the current font size."""
        t = self.text_fig.text(0, 0, text, size=self.fontsize, weight=weight)
        try:
            return t.get_window_extent(renderer=self.renderer).width
        finally:
            # The artist exists only for measurement. Remove it so the scratch
            # figure does not accumulate one Text per call.
            t.remove()

    def get_all_text_widths(self, rows):
        """Return an ``(n_rows, n_cols)`` array of padded cell widths in pixels.

        A multi-line cell is as wide as its widest line. Bold cells are
        measured in bold, matching how ``print_table`` draws them.
        """
        pad = 15  # pixels of horizontal padding added to every cell
        all_text_widths = []
        for row in rows:
            row_widths = []
            for cell in row:
                weight = "bold" if cell[1] else None
                row_widths.append(
                    max(self.get_text_width(line, weight) for line in cell[0].split("\n"))
                )
            all_text_widths.append(row_widths)
        return np.array(all_text_widths) + pad

    def calculate_col_widths(self):
        """Return each column's share of the table width (fractions summing to 1).

        Only when ``for_document`` is set and the natural width exceeds the
        ``figwidth * dpi`` pixel budget:

        * Long cells are wrapped with a shrinking multiplier (0.9, 0.81, ...,
          while above 0.5), widest columns first, stopping as soon as the
          table fits.
        * If wrapping alone cannot make it fit and the font is above 12pt,
          the font is reduced by 10% and the computation restarts.

        May update ``self.rows`` (wrapped text) and ``self.fontsize``.
        """
        all_text_widths = self.get_all_text_widths(self.rows)
        max_col_widths = all_text_widths.max(axis=0)
        budget = self.figwidth * self.dpi
        if self.for_document and sum(max_col_widths) >= budget:
            mult = 1
            rows = self.rows
            fits = False
            while mult > 0.5 and not fits:
                mult *= 0.9
                for idx in np.argsort(-max_col_widths):
                    rows = self.wrap_col(idx, all_text_widths[:, idx], mult)
                    all_text_widths = self.get_all_text_widths(rows)
                    max_col_widths = all_text_widths.max(axis=0)
                    if sum(max_col_widths) < budget:
                        # Stop both loops. Previously only the inner loop
                        # ended, so wrapping continued and the font was
                        # always shrunk.
                        fits = True
                        break
            if not fits and self.fontsize > 12:
                self.fontsize *= 0.9
                return self.calculate_col_widths()
            self.rows = rows
        total_width = sum(max_col_widths)
        return [width / total_width for width in max_col_widths]

    def wrap_col(self, idx, col_widths, mult):
        """Re-wrap over-long cells in column ``idx`` to narrow that column.

        ``col_widths`` holds the current padded pixel width of every cell in
        the column. A cell is wrapped when both hold:

        * it is wider than ``mult`` times the widest cell, and
        * it is longer than ``wrap_length`` characters.

        The line length is about ``len(text) * mult`` characters, never fewer
        than ``wrap_length``. Wraps are applied only if the widest wrapped
        cell ends up narrower than the current column.

        The returned list is a shallow copy of ``self.rows`` whose cell lists
        are shared, so applied wraps are also visible through ``self.rows``.
        ``calculate_col_widths`` relies on this to accumulate wraps across
        columns.
        """
        rows = self.rows.copy()
        max_width = max(col_widths)
        new_texts = []
        new_max_width = 0
        for row, cell_width in zip(self.rows, col_widths):
            cell = row[idx]
            text = cell[0]
            if cell_width > mult * max_width and len(text) > self.wrap_length:
                width = max(self.wrap_length, int(len(text) * mult))
                text = textwrap.fill(text, width, break_long_words=False)
                weight = "bold" if cell[1] else None
                new_max_width = max(new_max_width, self.get_text_width(text, weight))
            new_texts.append(text)
        if new_max_width < max_width:
            for row, text in zip(rows, new_texts):
                row[idx][0] = text
        return rows

    def get_row_heights(self):
        """Return each row's height in inches.

        A row is as tall as its tallest cell (in lines) plus one extra line
        of padding, at ``fontsize`` points per line (72 points per inch).
        """
        row_heights = []
        for row in self.rows:
            n_lines = max(cell[0].count("\n") + 1 for cell in row)
            row_heights.append((n_lines + 1) * self.fontsize / 72)
        return row_heights

    def create_figure(self):
        """Return a ``figwidth`` x (sum of row heights) inch ``Figure`` at ``self.dpi``."""
        figheight = sum(self.row_heights)
        return Figure(dpi=self.dpi, figsize=(self.figwidth, figheight))

    def print_table(self):
        """Draw ``self.rows`` onto ``self.fig`` and return the cropped PNG.

        Returns raw PNG bytes, or a base64 ``str`` when ``encode_base64`` is set.
        """
        stripe_color = "#f5f5f5"
        # Horizontal text padding of half a font size, in figure-fraction units.
        padding = self.fontsize / (self.figwidth * self.dpi) * 0.5
        total_width = sum(self.col_widths)
        figheight = self.fig.get_figheight()
        row_fracs = [height / figheight for height in self.row_heights]
        header_text_align = [cell[2] for cell in self.rows[0]]
        x0 = (1 - total_width) / 2  # centre the table horizontally
        y = 1  # figure-fraction y of the current row's top, walking downwards
        for i, (yd, row) in enumerate(zip(row_fracs, self.rows)):
            y -= yd
            # Zebra stripes on every other body row, starting with the first.
            body_idx = i - self.num_header_rows
            if body_idx >= 0 and body_idx % 2 == 0:
                stripe = mpatches.Rectangle(
                    (x0, y),
                    width=total_width,
                    height=yd,
                    fill=True,
                    color=stripe_color,
                    transform=self.fig.transFigure,
                )
                self.fig.add_artist(stripe)
            x = x0
            for j, (xd, cell) in enumerate(zip(self.col_widths, row)):
                text, bold, align, bg, fg = cell[:5]
                ha = align or header_text_align[j] or "right"
                if bg:
                    rect_bg = mpatches.Rectangle(
                        (x, y),
                        width=xd,
                        height=yd,
                        fill=True,
                        color=bg,
                        transform=self.fig.transFigure,
                    )
                    self.fig.add_artist(rect_bg)
                x_pos = {
                    "right": x + xd - padding,
                    "center": x + xd / 2,
                    "left": x + padding,
                }[ha]
                self.fig.text(
                    x_pos,
                    y + yd / 2,
                    text,
                    size=self.fontsize,
                    ha=ha,
                    va="center",
                    weight="bold" if bold else None,
                    color=fg or "#000000",
                )
                x += xd
            if i == self.num_header_rows - 1:
                # Horizontal rule under the last header row.
                line = mlines.Line2D([x0, x0 + total_width], [y, y], color="black")
                self.fig.add_artist(line)
        # Crop (in inches) from just left of the table to just right of it,
        # and from the bottom of the last row to the top of the figure.
        _, h = self.fig.get_size_inches()
        start = self.figwidth * min(x0, 0.1)
        end = self.figwidth - start
        bbox = Bbox([[start - 0.1, y * h], [end + 0.1, h]])
        buffer = io.BytesIO()
        self.fig.savefig(buffer, bbox_inches=bbox, dpi=self.savefig_dpi)
        img_str = buffer.getvalue()
        if self.encode_base64:
            img_str = base64.b64encode(img_str).decode()
        return img_str

    def run(self, df, filename: str):
        """Render DataFrame ``df`` as a table image and write it to ``filename``.

        The file receives raw PNG bytes, or ASCII base64 text when
        ``encode_base64`` is set.
        """
        # Original author's note: `notebook` controls whether over-long cell
        # contents are truncated.
        html = df.to_html(notebook=False)
        self.fontsize = self.original_fontsize
        self.text_fig = Figure(dpi=self.dpi)  # scratch figure for measuring text
        # Renderer used by get_text_width to measure text extents.
        self.renderer = RendererAgg(self.figwidth, self.figheight, self.dpi)
        self.rows, self.num_header_rows = self.parse_html(html)
        self.col_widths = self.calculate_col_widths()
        self.row_heights = self.get_row_heights()
        self.fig = self.create_figure()
        img = self.print_table()
        if isinstance(img, str):
            # base64 output is text, but the file is opened in binary mode;
            # writing the str directly raised TypeError.
            img = img.encode("ascii")
        with open(filename, "wb") as f:
            f.write(img)
|