Przeglądaj źródła

modify: 报告格式调整

yq 4 miesięcy temu
rodzic
commit
708b03455d

+ 2 - 2
commom/utils.py

@@ -16,8 +16,8 @@ import pytz
 from config import BaseConfig
 
 
-def f_format_float(num: float):
-    return f"{num: .3f}"
+def f_format_float(num: float, n=3):
+    return f"{num: .{n}f}"
 
 
 def f_get_date(offset: int = 0, connect: str = "-") -> str:

+ 5 - 5
entitys/data_feaure_entity.py

@@ -117,12 +117,12 @@ class DataSplitEntity():
 
         df["样本"] = ["训练集", "测试集", "合计"]
         df["样本数"] = [train_data_len, test_data_len, total]
-        df["样本占比"] = [f"{f_format_float(train_data_len / total * 100)}%",
-                      f"{f_format_float(test_data_len / total * 100)}%", "100%"]
+        df["样本占比"] = [f"{f_format_float(train_data_len / total * 100, 2)}%",
+                      f"{f_format_float(test_data_len / total * 100, 2)}%", "100%"]
         df["坏样本数"] = [train_bad_len, test_bad_len, bad_total]
-        df["坏样本比例"] = [f"{f_format_float(train_bad_len / train_data_len * 100)}%",
-                       f"{f_format_float(test_bad_len / test_data_len * 100)}%",
-                       f"{f_format_float(bad_total / total * 100)}%"]
+        df["坏样本比例"] = [f"{f_format_float(train_bad_len / train_data_len * 100, 2)}%",
+                       f"{f_format_float(test_bad_len / test_data_len * 100, 2)}%",
+                       f"{f_format_float(bad_total / total * 100, 2)}%"]
 
         return df
     

+ 1 - 9
entitys/data_process_config_entity.py

@@ -8,8 +8,7 @@ import json
 import os
 from typing import List, Union
 
-from commom import GeneralException, f_get_datetime
-from config import BaseConfig
+from commom import GeneralException
 from enums import ResultCodesEnum
 
 
@@ -19,9 +18,6 @@ class DataProcessConfigEntity():
                  iv_threshold: float = 0.03, iv_threshold_wide: float = 0.05, corr_threshold: float = 0.4,
                  sample_rate: float = 0.1, x_candidate_num: int = 10, special_values: Union[dict, list] = None):
 
-        self.save_path = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
-        os.makedirs(self.save_path, exist_ok=True)
-
         # 定义y变量
         self._y_column = y_column
 
@@ -125,10 +121,6 @@ class DataProcessConfigEntity():
 
         return DataProcessConfigEntity(**j)
 
-    def _get_save_path(self, file_name: str) -> str:
-        path = os.path.join(self.save_path, file_name)
-        return path
-
 
 if __name__ == "__main__":
     pass

+ 17 - 1
entitys/metric_entity.py

@@ -52,12 +52,28 @@ class MetricFucEntity():
     """
 
     def __init__(self, table: pd.DataFrame = None, value: str = None, image_path: Union[str, list] = None,
-                 image_size: int = 6):
+                 table_font_size=12, table_autofit=False, table_cell_width=None, image_size: int = 6):
         self._table = table
+        self._table_font_size = table_font_size
+        self._table_cell_width= table_cell_width
+        self._table_autofit = table_autofit
+
         self._value = value
         self._image_path = image_path
         self._image_size = image_size
 
+    @property
+    def table_cell_width(self):
+        return self._table_cell_width
+
+    @property
+    def table_autofit(self):
+        return self._table_autofit
+
+    @property
+    def table_font_size(self):
+        return self._table_font_size
+
     @property
     def table(self) -> pd.DataFrame:
         return self._table

+ 6 - 4
feature/strategy_iv.py

@@ -15,6 +15,7 @@ import seaborn as sns
 from pandas.core.dtypes.common import is_numeric_dtype
 
 from entitys import DataSplitEntity, CandidateFeatureEntity, DataPreparedEntity, DataFeatureEntity, MetricFucEntity
+from init import f_get_save_path
 from .feature_utils import f_judge_monto, f_get_corr
 from .filter_strategy_base import FilterStrategyBase
 
@@ -33,7 +34,7 @@ class StrategyIv(FilterStrategyBase):
             # bin_df["bin"] = bin_df["bin"].apply(lambda x: re.sub(r"(\d+\.\d+)",
             #                                                      lambda m: "{:.2f}".format(float(m.group(0))), x))
             sc.woebin_plot(bin_df)
-            path = self.data_process_config._get_save_path(f"{prefix}_{k}.png")
+            path = f_get_save_path(f"{prefix}_{k}.png")
             plt.savefig(path)
             image_path_list.append(path)
         return image_path_list
@@ -327,7 +328,8 @@ class StrategyIv(FilterStrategyBase):
 
         metric_value_dict = {}
         # 样本分布
-        metric_value_dict["样本分布"] = MetricFucEntity(table=data.get_distribution(y_column))
+        metric_value_dict["样本分布"] = MetricFucEntity(table=data.get_distribution(y_column), table_font_size=12,
+                                                    table_cell_width=3)
         # 变量iv及psi
         train_bins = self._f_get_bins_by_breaks(train_data, candidate_dict)
         train_iv = {key_: [round(value_['total_iv'].max(), 4)] for key_, value_ in train_bins.items()}
@@ -350,7 +352,7 @@ class StrategyIv(FilterStrategyBase):
             image_path_list = self._f_save_var_trend(test_bins, x_columns_candidate, "test")
             metric_value_dict["变量趋势-测试集"] = MetricFucEntity(image_path=image_path_list, image_size=4)
 
-        metric_value_dict["变量iv"] = MetricFucEntity(table=train_iv)
+        metric_value_dict["变量iv"] = MetricFucEntity(table=train_iv, table_font_size=12, table_cell_width=3)
         # 变量趋势-训练集
         image_path_list = self._f_save_var_trend(train_bins, x_columns_candidate, "train")
         metric_value_dict["变量趋势-训练集"] = MetricFucEntity(image_path=image_path_list, image_size=4)
@@ -362,7 +364,7 @@ class StrategyIv(FilterStrategyBase):
         plt.title('Variables Correlation', fontsize=15)
         plt.yticks(rotation=0)
         plt.xticks(rotation=90)
-        path = self.data_process_config._get_save_path(f"var_corr.png")
+        path = f_get_save_path(f"var_corr.png")
         plt.savefig(path)
         metric_value_dict["变量有效性"] = MetricFucEntity(image_path=path)
 

+ 17 - 1
init/__init__.py

@@ -2,8 +2,24 @@
 """
 @author: yq
 @time: 2024/10/31
-@desc: 模型及指标计算类初始化
+@desc: 一些资源初始化
 """
 
+import os
+
+from commom import f_get_datetime
+from config import BaseConfig
+
+__all__ = ['f_get_save_path']
+
+save_path = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
+os.makedirs(save_path, exist_ok=True)
+
+
+def f_get_save_path(file_name: str) -> str:
+    path = os.path.join(save_path, file_name)
+    return path
+
+
 if __name__ == "__main__":
     pass

+ 34 - 26
monitor/report_generate.py

@@ -12,7 +12,7 @@ from docx.enum.table import WD_ALIGN_VERTICAL
 from docx.enum.text import WD_ALIGN_PARAGRAPH
 from docx.oxml import OxmlElement
 from docx.oxml.ns import qn
-from docx.shared import Inches, Cm
+from docx.shared import Inches, Cm, Pt
 
 from commom import GeneralException, f_get_datetime
 from config import BaseConfig
@@ -23,32 +23,37 @@ from enums import ResultCodesEnum, PlaceholderPrefixEnum
 class Report():
 
     @staticmethod
-    def _set_cell_width(cell):
-        text = cell.text
-        if len(text) >= 10:
-            cell.width = Cm(2)
-        elif len(text) >= 15:
-            cell.width = Cm(2.5)
-        elif len(text) >= 25:
-            cell.width = Cm(3)
-        else:
-            cell.width = Cm(1.5)
+    def _set_cell_width(table, table_cell_width):
+        for column in table.columns:
+            if table_cell_width is not None:
+                column.width = Cm(table_cell_width)
+            # elif len(text) >= 10:
+            #     cell.width = Cm(2)
+            # elif len(text) >= 15:
+            #     cell.width = Cm(2.5)
+            # elif len(text) >= 25:
+            #     cell.width = Cm(3)
+            # else:
+            #     cell.width = Cm(1.5)
 
     @staticmethod
-    def _set_cell_format(cell):
-        cell.paragraphs[0].alignment = WD_ALIGN_PARAGRAPH.CENTER
+    def _set_cell_format(cell, font_size=None):
+        for paragraph in cell.paragraphs:
+            paragraph.alignment = WD_ALIGN_PARAGRAPH.CENTER
+            for run in paragraph.runs:
+                if font_size is not None:
+                    run.font.size = Pt(font_size)
         cell.vertical_alignment = WD_ALIGN_VERTICAL.CENTER
 
     @staticmethod
-    def _merge_cell_column(pre_cell, curr_cell):
+    def _merge_cell_column(pre_cell, curr_cell, table_font_size, table_cell_width):
         if curr_cell.text == pre_cell.text:
             column_name = curr_cell.text
             pre_cell.merge(curr_cell)
             pre_cell.text = column_name
             for run in pre_cell.paragraphs[0].runs:
                 run.bold = True
-            Report._set_cell_format(pre_cell)
-            Report._set_cell_width(pre_cell)
+            Report._set_cell_format(pre_cell, table_font_size)
 
     @staticmethod
     def _set_table_singleBoard(table):
@@ -115,6 +120,9 @@ class Report():
             for metric_code, metric_fuc_entity in metric_value_dict.items():
                 placeholder = Report._get_placeholder(PlaceholderPrefixEnum.TABLE, metric_code)
                 metric_table = metric_fuc_entity.table
+                table_font_size = metric_fuc_entity.table_font_size
+                table_autofit = metric_fuc_entity.table_autofit
+                table_cell_width = metric_fuc_entity.table_cell_width
                 if metric_table is None:
                     continue
                 if not placeholder in paragraph.text:
@@ -131,26 +139,26 @@ class Report():
                     cell.text = str(column_name)
                     for run in cell.paragraphs[0].runs:
                         run.bold = True
-                    Report._set_cell_format(cell)
-                    Report._set_cell_width(cell)
+                    Report._set_cell_format(cell, table_font_size)
                     # 合并相同的列名
                     if column_idx != 0 and BaseConfig.merge_table_column:
                         pre_cell = table.cell(0, column_idx - 1)
-                        Report._merge_cell_column(pre_cell, cell)
+                        Report._merge_cell_column(pre_cell, cell, table_font_size, table_cell_width)
                 # 值
                 for row_idx, row in metric_table.iterrows():
                     for column_idx, value in enumerate(row):
                         cell = table.cell(row_idx + 1, column_idx)
                         cell.text = str(value)
-                        Report._set_cell_format(cell)
-                        Report._set_cell_width(cell)
+                        Report._set_cell_format(cell, table_font_size)
                         # 合并第一行数据也为列的情况
                         if row_idx == 0:
-                            Report._merge_cell_column(table.cell(0, column_idx), cell)
+                            Report._merge_cell_column(table.cell(0, column_idx), cell, table_font_size,
+                                                      table_cell_width)
 
+                Report._set_cell_width(table, table_cell_width)
                 Report._set_table_singleBoard(table)
                 # 禁止自动调整表格
-                if len(metric_table.columns) <= 12:
+                if len(metric_table.columns) <= 12 or not table_autofit:
                     table.autofit = False
 
     @staticmethod
@@ -179,7 +187,7 @@ class Report():
                         run.add_picture(path, width=Inches(image_size))
 
     @staticmethod
-    def generate_report(metric_value_dict: Dict[str, MetricFucEntity], template_path: str, path=None):
+    def generate_report(metric_value_dict: Dict[str, MetricFucEntity], template_path: str, save_path=None):
         if os.path.exists(template_path):
             doc = Document(template_path)
         else:
@@ -189,8 +197,8 @@ class Report():
         Report._fill_table_placeholder(doc, metric_value_dict)
         Report._fill_image_placeholder(doc, metric_value_dict)
         new_path = template_path.replace(".docx", f"{f_get_datetime()}.docx")
-        if path is not None:
-            new_path = path
+        if save_path is not None:
+            new_path = save_path
         doc.save(f"./{new_path}")
 
 

BIN
template/模型开发报告模板_lr.docx


+ 2 - 2
trainer/train.py

@@ -7,6 +7,7 @@
 from typing import Dict
 
 from entitys import DataPreparedEntity, TrainConfigEntity, MetricFucEntity
+from init import f_get_save_path
 from model import f_get_model
 from monitor.report_generate import Report
 
@@ -22,8 +23,7 @@ class TrainPipeline():
         print(metric_train)
 
     def generate_report(self, metric_value_dict: Dict[str, MetricFucEntity]):
-        Report.generate_report(metric_value_dict, self._train_config.template_path, path="模型报告.docx")
-        pass
+        Report.generate_report(metric_value_dict, self._train_config.template_path, save_path=f_get_save_path("模型报告.docx"))
 
 
 if __name__ == "__main__":