utils.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. # -*- coding:utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2023/12/28
  5. @desc: 各种工具类
  6. """
  7. import base64
  8. import datetime
  9. import inspect
  10. import os
  11. import sys
  12. from contextlib import contextmanager
  13. from json import JSONEncoder
  14. from typing import Union
  15. import numpy as np
  16. import pandas as pd
  17. import pytz
  18. from PIL import Image
  19. from config import BaseConfig
  20. from .matplotlib_table import TableMaker
  21. def f_is_number(s):
  22. try:
  23. float(s)
  24. return True
  25. except ValueError:
  26. return False
  27. def f_format_float(num: float, n=3):
  28. return f"{num: .{n}f}"
  29. def f_get_date(offset: int = 0, connect: str = "-") -> str:
  30. current_date = datetime.datetime.now(pytz.timezone("Asia/Shanghai")).date() + datetime.timedelta(days=offset)
  31. return current_date.strftime(f"%Y{connect}%m{connect}%d")
  32. def f_get_datetime(offset: int = 0, connect: str = "_") -> str:
  33. current_date = datetime.datetime.now(pytz.timezone("Asia/Shanghai")) + datetime.timedelta(days=offset)
  34. return current_date.strftime(f"%Y{connect}%m{connect}%d{connect}%H{connect}%M{connect}%S")
  35. def f_get_clazz_in_module(module):
  36. """
  37. 获取包下的所有类
  38. """
  39. classes = []
  40. for name, member in inspect.getmembers(module):
  41. if inspect.isclass(member):
  42. classes.append(member)
  43. return classes
  44. def f_save_train_df(file_name: str, df: pd.DataFrame):
  45. file_path = os.path.join(BaseConfig.train_path, file_name)
  46. df.to_excel(f"{file_path}.xlsx", index=False)
  47. def f_df_to_image(df: pd.DataFrame, filename, fontsize=12):
  48. converter = TableMaker(fontsize=fontsize, encode_base64=False, for_document=False)
  49. converter.run(df, filename)
  50. # if importlib.util.find_spec("dataframe_image"):
  51. # import dataframe_image as dfi
  52. #
  53. # dfi.export(obj=df, filename=filename, fontsize=fontsize, table_conversion='matplotlib')
  54. # elif importlib.util.find_spec("plotly"):
  55. # import plotly.graph_objects as go
  56. # import plotly.figure_factory as ff
  57. # import plotly.io as pio
  58. #
  59. # fig = ff.create_table(df)
  60. # fig.update_layout()
  61. # fig.write_image(filename)
  62. #
  63. # fig = go.Figure(data=go.Table(
  64. # header=dict(
  65. # values=df.columns.to_list(),
  66. # font=dict(color='black', size=fontsize),
  67. # fill_color="white",
  68. # line_color='black',
  69. # align="center"
  70. # ),
  71. # cells=dict(
  72. # values=[df[k].tolist() for k in df.columns],
  73. # font=dict(color='black', size=fontsize),
  74. # fill_color="white",
  75. # line_color='black',
  76. # align="center")
  77. # )).update_layout()
  78. # pio.write_image(fig, filename)
  79. # else:
  80. # raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"缺少画图依赖【dataframe_image】或者【plotly】")
  81. def _f_image_to_base64(image_path):
  82. with open(image_path, "rb") as image_file:
  83. img_str = base64.b64encode(image_file.read())
  84. return img_str.decode("utf-8")
  85. def f_image_crop_white_borders(image_path, output_path):
  86. # 打开图片
  87. image = Image.open(image_path)
  88. # 将图片转换为灰度图
  89. gray_image = image.convert('L')
  90. # 获取图片的宽度和高度
  91. width, height = gray_image.size
  92. # 初始化边界
  93. left, top, right, bottom = width, height, 0, 0
  94. # 遍历图片的每一行和每一列
  95. for y in range(height):
  96. for x in range(width):
  97. # 获取当前像素的灰度值
  98. pixel = gray_image.getpixel((x, y))
  99. # 如果像素不是白色(灰度值小于 255)
  100. if pixel < 255:
  101. # 更新边界
  102. if x < left:
  103. left = x
  104. if x > right:
  105. right = x
  106. if y < top:
  107. top = y
  108. if y > bottom:
  109. bottom = y
  110. # 裁剪图片
  111. cropped_image = image.crop((left, top, right + 1, bottom + 1))
  112. # 保存裁剪后的图片
  113. cropped_image.save(output_path)
  114. def f_display_images_by_side(display, image_path_list, title: str = "", width: int = 500,
  115. image_path_list2: Union[list, None] = None, title2: str = "", ):
  116. if isinstance(image_path_list, str):
  117. image_path_list = [image_path_list]
  118. # justify-content:space-around; 会导致某些情况下图片越界
  119. html_str = '<div style="display:flex;">'
  120. if title != "":
  121. html_str += '<div>{}</div>'.format(title)
  122. for image_path in image_path_list:
  123. html_str += f'<img src="data:image/png;base64,{_f_image_to_base64(image_path)}" style="width:{width}px;"/>'
  124. html_str += '</div>'
  125. if not (image_path_list2 is None or len(image_path_list2) == 0):
  126. html_str += '<div style="display:flex;">'
  127. if title2 != "":
  128. html_str += '<div>{}</div>'.format(title2)
  129. for image_path in image_path_list2:
  130. html_str += f'<img src="data:image/png;base64,{_f_image_to_base64(image_path)}" style="width:{width}px;"/>'
  131. html_str += '</div>'
  132. display.display(display.HTML(html_str))
  133. def f_display_title(display, title):
  134. html_str = f"<h2>{title}</h2>"
  135. display.display(display.HTML(html_str))
  136. class f_clazz_to_json(JSONEncoder):
  137. def default(self, o):
  138. return o.__dict__
  139. class NumpyEncoder(JSONEncoder):
  140. def default(self, obj):
  141. if isinstance(obj, np.integer):
  142. return int(obj)
  143. if isinstance(obj, np.floating):
  144. return float(obj)
  145. if isinstance(obj, np.ndarray):
  146. return obj.tolist()
  147. return super(NumpyEncoder, self).default(obj)
  148. @contextmanager
  149. def silent_print():
  150. original_stdout = sys.stdout
  151. class NullWriter:
  152. def write(self, text):
  153. pass
  154. null_writer = NullWriter()
  155. sys.stdout = null_writer
  156. try:
  157. yield
  158. finally:
  159. sys.stdout = original_stdout
  160. @contextmanager
  161. def df_print_nolimit():
  162. max_columns = pd.get_option('display.max_columns')
  163. max_rows = pd.get_option('display.max_rows')
  164. pd.set_option('display.max_columns', None)
  165. pd.set_option('display.max_rows', None)
  166. try:
  167. yield
  168. finally:
  169. pd.set_option('display.max_columns', max_columns)
  170. pd.set_option('display.max_rows', max_rows)