yq 3 сар өмнө
parent
commit
16548ddb19

+ 4 - 4
__init__.py

@@ -7,16 +7,16 @@
 import sys
 from os.path import dirname, realpath
 
+sys.path.append(dirname(realpath(__file__)))
+
 from feature import FilterStrategyFactory
 from model import ModelFactory
 from trainer import TrainPipeline
 
-sys.path.append(dirname(realpath(__file__)))
-
 from data import DataLoaderMysql
-from entitys import DbConfigEntity
+from entitys import DbConfigEntity, DataSplitEntity
 from monitor import MonitorMetric
 from metrics import MetricBase
 
 __all__ = ['MonitorMetric', 'DataLoaderMysql', 'DbConfigEntity', 'MetricBase', 'FilterStrategyFactory', 'ModelFactory',
-           'TrainPipeline']
+           'TrainPipeline', 'DataSplitEntity']

+ 3 - 2
commom/__init__.py

@@ -8,7 +8,8 @@ from .logger import get_logger
 from .placeholder_func import f_fill_placeholder
 from .user_exceptions import GeneralException
 from .utils import f_get_clazz_in_module, f_clazz_to_json, f_get_date, f_get_datetime, f_save_train_df, f_format_float, \
-    f_df_to_image
+    f_df_to_image, f_display_images_by_side
 
 __all__ = ['f_get_clazz_in_module', 'f_clazz_to_json', 'GeneralException', 'get_logger', 'f_fill_placeholder',
-           'f_get_date', 'f_get_datetime', 'f_save_train_df', 'f_format_float', 'f_df_to_image']
+           'f_get_date', 'f_get_datetime', 'f_save_train_df', 'f_format_float', 'f_df_to_image',
+           'f_display_images_by_side']

+ 19 - 1
commom/utils.py

@@ -9,7 +9,7 @@ import datetime
 import inspect
 import os
 from json import JSONEncoder
-
+import base64
 import dataframe_image as dfi
 import pandas as pd
 import pytz
@@ -51,6 +51,24 @@ def f_df_to_image(df, filename, fontsize=12):
     dfi.export(obj=df, filename=filename, fontsize=fontsize, table_conversion='matplotlib')
 
 
+def _f_image_to_base64(image_path):
+    with open(image_path, "rb") as image_file:
+        img_str = base64.b64encode(image_file.read())
+        return img_str.decode("utf-8")
+
+
+def f_display_images_by_side(image_path_list, display, title: str = "", width: int = 500):
+    if isinstance(image_path_list, str):
+        image_path_list = [image_path_list]
+    html_str = '<div style="display:flex; justify-content:space-around;">'
+    if title != "":
+        html_str += '<h3>{}</h3>'.format(title)
+    for image_path in image_path_list:
+        html_str += f'<img src="data:image/png;base64,{_f_image_to_base64(image_path)}" style="width:{width}px;"/>'
+    html_str += '</div>'
+    display.display(display.HTML(html_str))
+
+
 class f_clazz_to_json(JSONEncoder):
     def default(self, o):
         return o.__dict__

+ 15 - 2
feature/strategy_iv.py

@@ -4,6 +4,7 @@
 @time: 2024/1/2
 @desc: iv值及单调性筛选类
 """
+import time
 from itertools import combinations_with_replacement
 from typing import List, Dict
 
@@ -15,6 +16,7 @@ import seaborn as sns
 from pandas.core.dtypes.common import is_numeric_dtype
 from tqdm import tqdm
 
+from commom import f_display_images_by_side
 from entitys import DataSplitEntity, CandidateFeatureEntity, DataPreparedEntity, DataFeatureEntity, MetricFucEntity
 from .feature_utils import f_judge_monto, f_get_corr, f_get_ivf, f_format_bin
 from .filter_strategy_base import FilterStrategyBase
@@ -350,8 +352,8 @@ class StrategyIv(FilterStrategyBase):
         return DataPreparedEntity(train_data_feature, val_data_feature, test_data_feature, bins=bins,
                                   data_split_original=data)
 
-    def feature_report(self, data: DataSplitEntity, candidate_dict: Dict[str, CandidateFeatureEntity], *args,
-                       **kwargs) -> Dict[str, MetricFucEntity]:
+    def feature_report(self, data: DataSplitEntity, candidate_dict: Dict[str, CandidateFeatureEntity], jupyter=False,
+                       *args, **kwargs) -> Dict[str, MetricFucEntity]:
         y_column = self.data_process_config.y_column
         x_columns_candidate = list(candidate_dict.keys())
         train_data = data.train_data
@@ -390,4 +392,15 @@ class StrategyIv(FilterStrategyBase):
         vif_df = f_get_ivf(train_woe)
         metric_value_dict["变量有效性"] = MetricFucEntity(image_path=var_corr_image_path, table=vif_df)
 
+        time.sleep(3)
+        if jupyter:
+            from IPython import display
+            display.display(metric_value_dict["样本分布"].table)
+            display.display(metric_value_dict["变量iv"].table)
+            f_display_images_by_side(metric_value_dict["变量有效性"].image_path, display, width=800)
+            f_display_images_by_side(metric_value_dict["变量趋势-训练集"].image_path, display, title="变量趋势训练集")
+            metric_test = metric_value_dict.get("变量趋势-测试集")
+            if metric_test is not None:
+                f_display_images_by_side(metric_test.image_path, display, title="变量趋势测试集")
+
         return metric_value_dict

+ 3 - 1
model/model_lr.py

@@ -4,6 +4,8 @@
 @time: 2024/11/1
 @desc: 
 """
+import os.path
+from os.path import dirname, realpath
 from typing import Dict
 
 import pandas as pd
@@ -20,7 +22,7 @@ class ModelLr(ModelBase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         # 报告模板
-        self._template_path = "./template/模型开发报告模板_lr.docx"
+        self._template_path = os.path.join(dirname(dirname(realpath(__file__))), "./template/模型开发报告模板_lr.docx")
         self.lr = LogisticRegression(penalty='l1', C=0.9, solver='saga', n_jobs=-1)
 
     def get_template_path(self):