Browse Source

add: 保存配置和图片裁边

yq 3 months ago
parent
commit
8954604e8c
6 changed files with 65 additions and 8 deletions
  1. 2 2
      commom/__init__.py
  2. 34 0
      commom/utils.py
  3. 12 1
      entitys/ml_config_entity.py
  4. 3 1
      feature/woe/strategy_woe.py
  5. 3 1
      model/model_lr.py
  6. 11 3
      pipeline/pipeline.py

+ 2 - 2
commom/__init__.py

@@ -8,8 +8,8 @@ from .logger import get_logger
 from .placeholder_func import f_fill_placeholder
 from .placeholder_func import f_fill_placeholder
 from .user_exceptions import GeneralException
 from .user_exceptions import GeneralException
 from .utils import f_get_clazz_in_module, f_clazz_to_json, f_get_date, f_get_datetime, f_save_train_df, f_format_float, \
 from .utils import f_get_clazz_in_module, f_clazz_to_json, f_get_date, f_get_datetime, f_save_train_df, f_format_float, \
-    f_df_to_image, f_display_images_by_side, NumpyEncoder, f_display_title
+    f_df_to_image, f_display_images_by_side, NumpyEncoder, f_display_title, f_image_crop_white_borders
 
 
 __all__ = ['f_get_clazz_in_module', 'f_clazz_to_json', 'GeneralException', 'get_logger', 'f_fill_placeholder',
 __all__ = ['f_get_clazz_in_module', 'f_clazz_to_json', 'GeneralException', 'get_logger', 'f_fill_placeholder',
            'f_get_date', 'f_get_datetime', 'f_save_train_df', 'f_format_float', 'f_df_to_image',
            'f_get_date', 'f_get_datetime', 'f_save_train_df', 'f_format_float', 'f_df_to_image',
-           'f_display_images_by_side', 'f_display_title', 'NumpyEncoder']
+           'f_display_images_by_side', 'f_display_title', 'NumpyEncoder', 'f_image_crop_white_borders']

+ 34 - 0
commom/utils.py

@@ -15,6 +15,7 @@ from typing import Union
 import numpy as np
 import numpy as np
 import pandas as pd
 import pandas as pd
 import pytz
 import pytz
+from PIL import Image
 
 
 from config import BaseConfig
 from config import BaseConfig
 from .matplotlib_table import TableMaker
 from .matplotlib_table import TableMaker
@@ -93,6 +94,39 @@ def _f_image_to_base64(image_path):
         return img_str.decode("utf-8")
         return img_str.decode("utf-8")
 
 
 
 
+def f_image_crop_white_borders(image_path, output_path):
+    """Crop the pure-white margin around an image and save the result.
+
+    The image is converted to grayscale and every pixel whose gray value is
+    below 255 counts as content. The image is cropped to the smallest box
+    that contains all content pixels and written to ``output_path`` (which
+    may be the same file as ``image_path``).
+
+    If the image has no content at all (every pixel is pure white), it is
+    saved unchanged instead of being cropped with an inverted box.
+
+    :param image_path: path of the image to read
+    :param output_path: path the cropped image is written to
+    """
+    with Image.open(image_path) as image:
+        # 255 marks content pixels and 0 marks pure white, so getbbox()
+        # (which bounds the non-zero region) yields the content box directly.
+        content_mask = image.convert('L').point(lambda p: 255 if p < 255 else 0)
+        # (left, upper, right, lower) with right/lower exclusive, or None
+        # when the mask is entirely zero.
+        bbox = content_mask.getbbox()
+        cropped_image = image.copy() if bbox is None else image.crop(bbox)
+
+    # Save only after the source file has been closed, so that writing back
+    # to the same path is safe.
+    cropped_image.save(output_path)
+
+
 def f_display_images_by_side(display, image_path_list, title: str = "", width: int = 500,
 def f_display_images_by_side(display, image_path_list, title: str = "", width: int = 500,
                              image_path_list2: Union[list, None] = None, title2: str = "", ):
                              image_path_list2: Union[list, None] = None, title2: str = "", ):
     if isinstance(image_path_list, str):
     if isinstance(image_path_list, str):

+ 12 - 1
entitys/ml_config_entity.py

@@ -124,7 +124,6 @@ class MlConfigEntity():
         # 加减分规则
         # 加减分规则
         self._rules = rules
         self._rules = rules
 
 
-
         if self._project_name is None or len(self._project_name) == 0:
         if self._project_name is None or len(self._project_name) == 0:
             self._base_dir = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
             self._base_dir = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
         else:
         else:
@@ -133,6 +132,7 @@ class MlConfigEntity():
         self._include = columns_include + list(self.breaks_list.keys())
         self._include = columns_include + list(self.breaks_list.keys())
 
 
         os.makedirs(self._base_dir, exist_ok=True)
         os.makedirs(self._base_dir, exist_ok=True)
+        print(f"项目路径:【{self._base_dir}】")
 
 
         if self._jupyter_print:
         if self._jupyter_print:
             warning_ignore()
             warning_ignore()
@@ -294,6 +294,9 @@ class MlConfigEntity():
         """
         """
         从配置文件生成实体类
         从配置文件生成实体类
         """
         """
+        if os.path.isdir(config_path):
+            config_path = os.path.join(config_path, "mlcfg.json")
+
         if os.path.exists(config_path):
         if os.path.exists(config_path):
             with open(config_path, mode="r", encoding="utf-8") as f:
             with open(config_path, mode="r", encoding="utf-8") as f:
                 j = json.loads(f.read())
                 j = json.loads(f.read())
@@ -302,6 +305,14 @@ class MlConfigEntity():
 
 
         return MlConfigEntity(**j)
         return MlConfigEntity(**j)
 
 
+    def config_save(self):
+        """Write this entity's attributes to ``mlcfg.json`` in the save directory.
+
+        Every entry of ``self.__dict__`` is exported, with leading underscores
+        stripped from its key, and dumped as JSON (non-ASCII text kept as-is).
+        The target path comes from ``self.f_get_save_path("mlcfg.json")``.
+
+        :raises TypeError: if an attribute value is not JSON-serialisable; in
+            that case the existing file, if any, is left untouched.
+        """
+        path = self.f_get_save_path("mlcfg.json")
+        # Serialise before opening the file: "w" mode truncates it, so a
+        # json.dumps failure would otherwise leave an empty mlcfg.json behind.
+        # NOTE(review): from_config feeds the loaded dict to
+        # MlConfigEntity(**j) - confirm every key saved here is an accepted
+        # constructor argument.
+        j = {k.lstrip("_"): v for k, v in self.__dict__.items()}
+        j = json.dumps(j, ensure_ascii=False)
+        with open(path, mode="w", encoding="utf-8") as f:
+            f.write(j)
+        print(f"mlcfg save to【{path}】success. ")
+
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     pass
     pass

+ 3 - 1
feature/woe/strategy_woe.py

@@ -17,7 +17,8 @@ import seaborn as sns
 from pandas.core.dtypes.common import is_numeric_dtype
 from pandas.core.dtypes.common import is_numeric_dtype
 from tqdm import tqdm
 from tqdm import tqdm
 
 
-from commom import f_display_images_by_side, NumpyEncoder, GeneralException, f_df_to_image, f_display_title
+from commom import f_display_images_by_side, NumpyEncoder, GeneralException, f_df_to_image, f_display_title, \
+    f_image_crop_white_borders
 from entitys import DataSplitEntity, MetricFucResultEntity
 from entitys import DataSplitEntity, MetricFucResultEntity
 from enums import ContextEnum, ResultCodesEnum
 from enums import ContextEnum, ResultCodesEnum
 from feature.feature_strategy_base import FeatureStrategyBase
 from feature.feature_strategy_base import FeatureStrategyBase
@@ -44,6 +45,7 @@ class StrategyWoe(FeatureStrategyBase):
         plt.xticks(rotation=90)
         plt.xticks(rotation=90)
         img_path = self.ml_config.f_get_save_path(f"corr.png")
         img_path = self.ml_config.f_get_save_path(f"corr.png")
         plt.savefig(img_path)
         plt.savefig(img_path)
+        f_image_crop_white_borders(img_path, img_path)
         return img_path
         return img_path
 
 
     def _f_get_img_trend(self, sc_woebin, x_columns, prefix):
     def _f_get_img_trend(self, sc_woebin, x_columns, prefix):

+ 3 - 1
model/model_lr.py

@@ -14,7 +14,8 @@ import pandas as pd
 import scorecardpy as sc
 import scorecardpy as sc
 import statsmodels.api as sm
 import statsmodels.api as sm
 
 
-from commom import f_df_to_image, f_display_images_by_side, GeneralException, f_display_title
+from commom import f_df_to_image, f_display_images_by_side, GeneralException, f_display_title, \
+    f_image_crop_white_borders
 from entitys import MetricFucResultEntity, DataSplitEntity, DataFeatureEntity
 from entitys import MetricFucResultEntity, DataSplitEntity, DataFeatureEntity
 from enums import ContextEnum, ResultCodesEnum, ConstantEnum
 from enums import ContextEnum, ResultCodesEnum, ConstantEnum
 from init import context
 from init import context
@@ -107,6 +108,7 @@ class ModelLr(ModelBase):
             perf["pic"].savefig(path)
             perf["pic"].savefig(path)
             auc = perf["AUC"]
             auc = perf["AUC"]
             ks = perf["KS"]
             ks = perf["KS"]
+            f_image_crop_white_borders(path, path)
             return auc, ks, path
             return auc, ks, path
 
 
         def _get_perf(perf_rule=False):
         def _get_perf(perf_rule=False):

+ 11 - 3
pipeline/pipeline.py

@@ -52,18 +52,26 @@ class Pipeline():
     def score(self, data: pd.DataFrame):
     def score(self, data: pd.DataFrame):
         return self._model.score(data)
         return self._model.score(data)
 
 
+    def score_rule(self, data: pd.DataFrame):
+        """Delegate to the wrapped model's ``score_rule`` and return its result."""
+        model = self._model
+        return model.score_rule(data)
+
     def report(self, ):
     def report(self, ):
         save_path = self._ml_config.f_get_save_path("模型报告.docx")
         save_path = self._ml_config.f_get_save_path("模型报告.docx")
         Report.generate_report(self.metric_value_dict, self._model.get_report_template_path(), save_path=save_path)
         Report.generate_report(self.metric_value_dict, self._model.get_report_template_path(), save_path=save_path)
         print(f"模型报告文件储存路径:{save_path}")
         print(f"模型报告文件储存路径:{save_path}")
 
 
     def save(self):
     def save(self):
+        self._ml_config.config_save()
         self._feature_strategy.feature_save()
         self._feature_strategy.feature_save()
         self._model.model_save()
         self._model.model_save()
 
 
-    def load(self, path: str):
-        self._feature_strategy.feature_load(path)
-        self._model.model_load(path)
+    @staticmethod
+    def load(path: str):
+        """Rebuild a pipeline from a previously saved location.
+
+        The configuration is read with ``MlConfigEntity.from_config(path)``, a
+        new ``Pipeline`` is constructed from it, and the same ``path`` is then
+        handed to the feature strategy's ``feature_load`` and the model's
+        ``model_load``.
+
+        :param path: location passed unchanged to the config, feature and
+            model loaders
+        :return: the newly constructed ``Pipeline``
+        """
+        restored = Pipeline(ml_config=MlConfigEntity.from_config(path))
+        restored._feature_strategy.feature_load(path)
+        restored._model.model_load(path)
+        return restored
 
 
     def variable_analyse(self, column: str, format_bin=None):
     def variable_analyse(self, column: str, format_bin=None):
         self._feature_strategy.variable_analyse(self._data, column, format_bin)
         self._feature_strategy.variable_analyse(self._data, column, format_bin)