# matplotlib_table.py
import base64
import io
import textwrap

import numpy as np
from bs4 import BeautifulSoup
from matplotlib import lines as mlines
from matplotlib import patches as mpatches
from matplotlib.backends.backend_agg import RendererAgg
from matplotlib.figure import Figure
from matplotlib.transforms import Bbox
  11. class TableMaker:
  12. def __init__(
  13. self,
  14. fontsize=14,
  15. encode_base64=True,
  16. limit_crop=True,
  17. for_document=True,
  18. savefig_dpi=None,
  19. ):
  20. self.original_fontsize = fontsize
  21. self.encode_base64 = encode_base64
  22. self.limit_crop = limit_crop
  23. self.for_document = for_document
  24. self.figwidth = 1
  25. self.figheight = 1
  26. self.wrap_length = 30
  27. if self.for_document:
  28. self.figwidth = 20
  29. self.figheight = 4
  30. self.wrap_length = 10
  31. self.dpi = 100
  32. self.savefig_dpi = savefig_dpi
  33. def parse_html(self, html):
  34. html = html.replace("<br>", "\n")
  35. rows, num_header_rows = self.parse_into_rows(html)
  36. num_cols = sum(val[-1] for val in rows[0])
  37. new_rows = []
  38. rowspan = {}
  39. # deal with muti-row or multi col cells
  40. for i, row in enumerate(rows):
  41. new_row = []
  42. col_loc = 0
  43. j = 0
  44. while col_loc < num_cols:
  45. if col_loc in rowspan:
  46. val = rowspan[col_loc]
  47. val[-2] -= 1
  48. cur_col_loc = 0
  49. for _ in range(val[-1]):
  50. cur_col_loc += 1
  51. new_row.append(val[:5])
  52. if val[-2] == 1:
  53. del rowspan[col_loc]
  54. col_loc += cur_col_loc
  55. else:
  56. val = row[j]
  57. if val[-2] > 1: # new rowspan detected
  58. rowspan[col_loc] = val
  59. col_loc += val[-1] # usually 1
  60. for _ in range(val[-1]):
  61. new_row.append(val[:5])
  62. j += 1
  63. new_rows.append(new_row)
  64. return new_rows, num_header_rows
  65. def get_text_align(self, element):
  66. style = element.get("style", "").lower()
  67. if "text-align" in style:
  68. idx = style.find("text-align")
  69. text_align = style[idx + 10:].split(":")[1].strip()
  70. for val in ("left", "right", "center"):
  71. if text_align.startswith(val):
  72. return val
  73. def parse_into_rows(self, html):
  74. def get_property(class_name, property_name):
  75. for rule in sheet:
  76. selectors = rule.selectorText.replace(" ", "").split(",")
  77. if class_name in selectors:
  78. for style_property in rule.style:
  79. if style_property.name == property_name:
  80. return style_property.value
  81. def parse_row(row):
  82. values = []
  83. rowspan_dict = {}
  84. colspan_total = 0
  85. row_align = self.get_text_align(row)
  86. for el in row.find_all(["td", "th"]):
  87. bold = el.name == "th"
  88. colspan = int(el.attrs.get("colspan", 1))
  89. rowspan = int(el.attrs.get("rowspan", 1))
  90. text_align = self.get_text_align(el) or row_align
  91. text = el.get_text()
  92. if "id" in el.attrs:
  93. values.append(
  94. [
  95. text,
  96. bold,
  97. text_align,
  98. get_property("#" + el.attrs["id"], "background-color"),
  99. get_property("#" + el.attrs["id"], "color"),
  100. rowspan,
  101. colspan,
  102. ]
  103. )
  104. else:
  105. values.append(
  106. [text, bold, text_align, "#ffffff", "#000000", rowspan, colspan]
  107. )
  108. return values
  109. soup = BeautifulSoup(html, features="lxml")
  110. style = soup.find("style")
  111. # if style:
  112. # sheet = cssutils.parseString(style.text)
  113. # else:
  114. # sheet = []
  115. # get number of columns from first row
  116. # num_cols = sum(int(el.get('colspan', 1)) for el in soup.find('tr').find_all(['td', 'th']))
  117. thead = soup.find("thead")
  118. tbody = soup.find("tbody")
  119. rows = []
  120. if thead:
  121. head_rows = thead.find_all("tr")
  122. if head_rows:
  123. for row in head_rows:
  124. rows.append(parse_row(row))
  125. else:
  126. rows.append(parse_row(thead))
  127. num_header_rows = len(rows)
  128. if tbody:
  129. for row in tbody.find_all("tr"):
  130. rows.append(parse_row(row))
  131. if not thead and not tbody:
  132. for row in soup.find_all("tr"):
  133. rows.append(parse_row(row))
  134. return rows, num_header_rows
  135. def get_text_width(self, text, weight=None):
  136. fig = self.text_fig
  137. t = fig.text(0, 0, text, size=self.fontsize, weight=weight)
  138. bbox = t.get_window_extent(renderer=self.renderer)
  139. return bbox.width
  140. def get_all_text_widths(self, rows):
  141. all_text_widths = []
  142. for i, row in enumerate(rows):
  143. row_widths = []
  144. for vals in row:
  145. cell_max_width = 0
  146. for text in vals[0].split("\n"):
  147. weight = "bold" if i == 0 else None
  148. cell_max_width = max(
  149. cell_max_width, self.get_text_width(text, weight)
  150. )
  151. row_widths.append(cell_max_width)
  152. all_text_widths.append(row_widths)
  153. pad = 10 # number of pixels to pad columns with
  154. return np.array(all_text_widths) + 15
  155. def calculate_col_widths(self):
  156. all_text_widths = self.get_all_text_widths(self.rows)
  157. max_col_widths = all_text_widths.max(axis=0)
  158. mult = 1
  159. total_width = self.figwidth * self.dpi
  160. if self.for_document and sum(max_col_widths) >= total_width:
  161. while mult > 0.5:
  162. mult *= 0.9
  163. for idx in np.argsort(-max_col_widths):
  164. col_widths = all_text_widths[:, idx]
  165. rows = self.wrap_col(idx, col_widths, mult)
  166. all_text_widths = self.get_all_text_widths(rows)
  167. max_col_widths = all_text_widths.max(axis=0)
  168. if sum(max_col_widths) < total_width:
  169. break
  170. if mult <= 0.5 and self.fontsize > 12:
  171. self.fontsize *= 0.9
  172. return self.calculate_col_widths()
  173. else:
  174. self.rows = rows
  175. total_width = sum(max_col_widths)
  176. col_prop = [width / total_width for width in max_col_widths]
  177. return col_prop
  178. def wrap_col(self, idx, col_widths, mult):
  179. rows = self.rows.copy()
  180. max_width = max(col_widths)
  181. texts = [row[idx][0] for row in self.rows]
  182. new_texts = []
  183. new_max_width = 0
  184. for i, (text, col_width) in enumerate(zip(texts, col_widths)):
  185. if col_width > mult * max_width and len(text) > self.wrap_length:
  186. width = max(self.wrap_length, int(len(text) * mult))
  187. new_text = textwrap.fill(text, width, break_long_words=False)
  188. new_texts.append(new_text)
  189. new_max_width = max(new_max_width, self.get_text_width(new_text))
  190. else:
  191. new_texts.append(text)
  192. if new_max_width < max_width:
  193. for row, text in zip(rows, new_texts):
  194. row[idx][0] = text
  195. return rows
  196. def get_row_heights(self):
  197. row_heights = []
  198. for row in self.rows:
  199. row_count = max([val[0].count("\n") + 1 for val in row])
  200. height = (row_count + 1) * self.fontsize / 72
  201. row_heights.append(height)
  202. return row_heights
  203. def create_figure(self):
  204. figheight = sum(self.row_heights)
  205. fig = Figure(dpi=self.dpi, figsize=(self.figwidth, figheight))
  206. return fig
  207. def print_table(self):
  208. row_colors = ["#f5f5f5", "#ffffff"]
  209. # padding 0.5 em
  210. padding = self.fontsize / (self.figwidth * self.dpi) * 0.5
  211. total_width = sum(self.col_widths)
  212. figheight = self.fig.get_figheight()
  213. row_locs = [height / figheight for height in self.row_heights]
  214. header_text_align = [vals[2] for vals in self.rows[0]]
  215. x0 = (1 - total_width) / 2
  216. x = x0
  217. yd = row_locs[0]
  218. y = 1
  219. for i, (yd, row) in enumerate(zip(row_locs, self.rows)):
  220. x = x0
  221. y -= yd
  222. # table zebra stripes
  223. diff = i - self.num_header_rows
  224. if diff >= 0 and diff % 2 == 0:
  225. p = mpatches.Rectangle(
  226. (x0, y),
  227. width=total_width,
  228. height=yd,
  229. fill=True,
  230. color=row_colors[0],
  231. transform=self.fig.transFigure,
  232. )
  233. self.fig.add_artist(p)
  234. for j, (xd, val) in enumerate(zip(self.col_widths, row)):
  235. text = val[0]
  236. weight = "bold" if val[1] else None
  237. ha = val[2] or header_text_align[j] or "right"
  238. fg = val[4] if val[4] else "#000000"
  239. bg = val[3] if val[3] else None
  240. if bg:
  241. rect_bg = mpatches.Rectangle(
  242. (x, y),
  243. width=xd,
  244. height=yd,
  245. fill=True,
  246. color=bg,
  247. transform=self.fig.transFigure,
  248. )
  249. self.fig.add_artist(rect_bg)
  250. if ha == "right":
  251. x_pos = x + xd - padding
  252. elif ha == "center":
  253. x_pos = x + xd / 2
  254. elif ha == "left":
  255. x_pos = x + padding
  256. self.fig.text(
  257. x_pos,
  258. y + yd / 2,
  259. text,
  260. size=self.fontsize,
  261. ha=ha,
  262. va="center",
  263. weight=weight,
  264. color=fg,
  265. # backgroundcolor=bg
  266. )
  267. x += xd
  268. if i == self.num_header_rows - 1:
  269. line = mlines.Line2D([x0, x0 + total_width], [y, y], color="black")
  270. self.fig.add_artist(line)
  271. w, h = self.fig.get_size_inches()
  272. start = self.figwidth * min(x0, 0.1)
  273. end = self.figwidth - start
  274. bbox = Bbox([[start - 0.1, y * h], [end + 0.1, h]])
  275. buffer = io.BytesIO()
  276. self.fig.savefig(buffer, bbox_inches=bbox, dpi=self.savefig_dpi)
  277. img_str = buffer.getvalue()
  278. if self.encode_base64:
  279. img_str = base64.b64encode(img_str).decode()
  280. return img_str
  281. def run(self, df, filename: str):
  282. html = df.to_html(notebook=False) # notebook 控制单元内容超长后是否截断
  283. self.fontsize = self.original_fontsize
  284. self.text_fig = Figure(dpi=self.dpi)
  285. self.renderer = RendererAgg(self.figwidth, self.figheight, self.dpi)
  286. self.rows, self.num_header_rows = self.parse_html(html)
  287. self.col_widths = self.calculate_col_widths()
  288. self.row_heights = self.get_row_heights()
  289. self.fig = self.create_figure()
  290. img_str = self.print_table()
  291. with open(filename, "wb") as f:
  292. f.write(img_str)