فهرست منبع

add: 保存配置和图片裁边

yq 1 ماه پیش
والد
کامیت
8954604e8c
6 فایل تغییر یافته به همراه 65 افزوده شده و 8 حذف شده
  1. 2 2
      commom/__init__.py
  2. 34 0
      commom/utils.py
  3. 12 1
      entitys/ml_config_entity.py
  4. 3 1
      feature/woe/strategy_woe.py
  5. 3 1
      model/model_lr.py
  6. 11 3
      pipeline/pipeline.py

+ 2 - 2
commom/__init__.py

@@ -8,8 +8,8 @@ from .logger import get_logger
 from .placeholder_func import f_fill_placeholder
 from .user_exceptions import GeneralException
 from .utils import f_get_clazz_in_module, f_clazz_to_json, f_get_date, f_get_datetime, f_save_train_df, f_format_float, \
-    f_df_to_image, f_display_images_by_side, NumpyEncoder, f_display_title
+    f_df_to_image, f_display_images_by_side, NumpyEncoder, f_display_title, f_image_crop_white_borders
 
 __all__ = ['f_get_clazz_in_module', 'f_clazz_to_json', 'GeneralException', 'get_logger', 'f_fill_placeholder',
            'f_get_date', 'f_get_datetime', 'f_save_train_df', 'f_format_float', 'f_df_to_image',
-           'f_display_images_by_side', 'f_display_title', 'NumpyEncoder']
+           'f_display_images_by_side', 'f_display_title', 'NumpyEncoder', 'f_image_crop_white_borders']

+ 34 - 0
commom/utils.py

@@ -15,6 +15,7 @@ from typing import Union
 import numpy as np
 import pandas as pd
 import pytz
+from PIL import Image
 
 from config import BaseConfig
 from .matplotlib_table import TableMaker
@@ -93,6 +94,39 @@ def _f_image_to_base64(image_path):
         return img_str.decode("utf-8")
 
 
+def f_image_crop_white_borders(image_path, output_path):
+    """Crop the pure-white margin around an image and save the result.
+
+    The image is judged on its grayscale ("L") rendition: every pixel whose
+    gray value is below 255 counts as content, and the saved image is the
+    smallest rectangle containing all such pixels. ``output_path`` may be the
+    same file as ``image_path``.
+
+    If the image has no non-white pixel at all there is nothing to crop
+    against, so it is saved unchanged instead of producing an inverted
+    crop box.
+
+    :param image_path: path of the image to read
+    :param output_path: path the cropped image is written to
+    """
+    # The context manager releases the file handle even if processing fails.
+    with Image.open(image_path) as image:
+        gray_image = image.convert('L')
+        # Mark content pixels as non-zero so getbbox() can locate them; this
+        # replaces a per-pixel Python loop with a single pass inside Pillow.
+        content_mask = gray_image.point(lambda value: 255 if value < 255 else 0)
+        # getbbox() returns (left, upper, right, lower) with exclusive
+        # right/lower edges, or None when every pixel is zero (all white).
+        bbox = content_mask.getbbox()
+        cropped_image = image.copy() if bbox is None else image.crop(bbox)
+
+    # Save once the source handle is released so overwriting in place is safe.
+    cropped_image.save(output_path)
+
+
 def f_display_images_by_side(display, image_path_list, title: str = "", width: int = 500,
                              image_path_list2: Union[list, None] = None, title2: str = "", ):
     if isinstance(image_path_list, str):

+ 12 - 1
entitys/ml_config_entity.py

@@ -124,7 +124,6 @@ class MlConfigEntity():
         # 加减分规则
         self._rules = rules
 
-
         if self._project_name is None or len(self._project_name) == 0:
             self._base_dir = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
         else:
@@ -133,6 +132,7 @@ class MlConfigEntity():
         self._include = columns_include + list(self.breaks_list.keys())
 
         os.makedirs(self._base_dir, exist_ok=True)
+        print(f"项目路径:【{self._base_dir}】")
 
         if self._jupyter_print:
             warning_ignore()
@@ -294,6 +294,9 @@ class MlConfigEntity():
         """
         从配置文件生成实体类
         """
+        if os.path.isdir(config_path):
+            config_path = os.path.join(config_path, "mlcfg.json")
+
         if os.path.exists(config_path):
             with open(config_path, mode="r", encoding="utf-8") as f:
                 j = json.loads(f.read())
@@ -302,6 +305,14 @@ class MlConfigEntity():
 
         return MlConfigEntity(**j)
 
+    def config_save(self):
+        """Write this configuration as JSON to the ``mlcfg.json`` save path.
+
+        Every instance attribute is exported, with leading underscores
+        stripped from its name (``_rules`` is written as ``rules``). The
+        target path is whatever ``f_get_save_path("mlcfg.json")`` returns.
+
+        :raises TypeError: if an attribute value is not JSON serializable
+        """
+        path = self.f_get_save_path("mlcfg.json")
+        # Serialize before opening the file: a serialization error is then
+        # raised without leaving a truncated mlcfg.json behind.
+        j = {k.lstrip("_"): v for k, v in self.__dict__.items()}
+        j = json.dumps(j, ensure_ascii=False)
+        with open(path, mode="w", encoding="utf-8") as f:
+            f.write(j)
+        print(f"mlcfg save to【{path}】success. ")
+
 
 if __name__ == "__main__":
     pass

+ 3 - 1
feature/woe/strategy_woe.py

@@ -17,7 +17,8 @@ import seaborn as sns
 from pandas.core.dtypes.common import is_numeric_dtype
 from tqdm import tqdm
 
-from commom import f_display_images_by_side, NumpyEncoder, GeneralException, f_df_to_image, f_display_title
+from commom import f_display_images_by_side, NumpyEncoder, GeneralException, f_df_to_image, f_display_title, \
+    f_image_crop_white_borders
 from entitys import DataSplitEntity, MetricFucResultEntity
 from enums import ContextEnum, ResultCodesEnum
 from feature.feature_strategy_base import FeatureStrategyBase
@@ -44,6 +45,7 @@ class StrategyWoe(FeatureStrategyBase):
         plt.xticks(rotation=90)
         img_path = self.ml_config.f_get_save_path(f"corr.png")
         plt.savefig(img_path)
+        f_image_crop_white_borders(img_path, img_path)
         return img_path
 
     def _f_get_img_trend(self, sc_woebin, x_columns, prefix):

+ 3 - 1
model/model_lr.py

@@ -14,7 +14,8 @@ import pandas as pd
 import scorecardpy as sc
 import statsmodels.api as sm
 
-from commom import f_df_to_image, f_display_images_by_side, GeneralException, f_display_title
+from commom import f_df_to_image, f_display_images_by_side, GeneralException, f_display_title, \
+    f_image_crop_white_borders
 from entitys import MetricFucResultEntity, DataSplitEntity, DataFeatureEntity
 from enums import ContextEnum, ResultCodesEnum, ConstantEnum
 from init import context
@@ -107,6 +108,7 @@ class ModelLr(ModelBase):
             perf["pic"].savefig(path)
             auc = perf["AUC"]
             ks = perf["KS"]
+            f_image_crop_white_borders(path, path)
             return auc, ks, path
 
         def _get_perf(perf_rule=False):

+ 11 - 3
pipeline/pipeline.py

@@ -52,18 +52,26 @@ class Pipeline():
     def score(self, data: pd.DataFrame):
         return self._model.score(data)
 
+    def score_rule(self, data: pd.DataFrame):
+        """Return the underlying model's ``score_rule`` result for ``data``."""
+        model = self._model
+        return model.score_rule(data)
+
     def report(self, ):
         save_path = self._ml_config.f_get_save_path("模型报告.docx")
         Report.generate_report(self.metric_value_dict, self._model.get_report_template_path(), save_path=save_path)
         print(f"模型报告文件储存路径:{save_path}")
 
     def save(self):
+        self._ml_config.config_save()
         self._feature_strategy.feature_save()
         self._model.model_save()
 
-    def load(self, path: str):
-        self._feature_strategy.feature_load(path)
-        self._model.model_load(path)
+    @staticmethod
+    def load(path: str):
+        """Build a pipeline from the artifacts stored under ``path``.
+
+        The configuration is read with ``MlConfigEntity.from_config``; the new
+        pipeline's feature strategy and model then load their own state from
+        the same ``path``.
+
+        :param path: location handed to the config, feature and model loaders
+        :return: the newly constructed ``Pipeline``
+        """
+        restored = Pipeline(ml_config=MlConfigEntity.from_config(path))
+        restored._feature_strategy.feature_load(path)
+        restored._model.model_load(path)
+        return restored
 
     def variable_analyse(self, column: str, format_bin=None):
         self._feature_strategy.variable_analyse(self._data, column, format_bin)