Переглянути джерело

modify: 变量值类型统计

yq 1 місяць тому
батько
коміт
b849c4d67b

+ 2 - 2
commom/__init__.py

@@ -8,8 +8,8 @@ from .logger import get_logger
 from .placeholder_func import f_fill_placeholder
 from .user_exceptions import GeneralException
 from .utils import f_get_clazz_in_module, f_clazz_to_json, f_get_date, f_get_datetime, f_save_train_df, f_format_float, \
-    f_df_to_image, f_display_images_by_side, NumpyEncoder, f_display_title, f_image_crop_white_borders
+    f_df_to_image, f_display_images_by_side, NumpyEncoder, f_display_title, f_image_crop_white_borders, f_is_number
 
 __all__ = ['f_get_clazz_in_module', 'f_clazz_to_json', 'GeneralException', 'get_logger', 'f_fill_placeholder',
            'f_get_date', 'f_get_datetime', 'f_save_train_df', 'f_format_float', 'f_df_to_image',
-           'f_display_images_by_side', 'f_display_title', 'NumpyEncoder', 'f_image_crop_white_borders']
+           'f_display_images_by_side', 'f_display_title', 'NumpyEncoder', 'f_image_crop_white_borders', 'f_is_number']

+ 8 - 0
commom/utils.py

@@ -21,6 +21,14 @@ from config import BaseConfig
 from .matplotlib_table import TableMaker
 
 
+def f_is_number(s):
+    try:
+        float(s)
+        return True
+    except ValueError:
+        return False
+
+
 def f_format_float(num: float, n=3):
     return f"{num: .{n}f}"
 

+ 7 - 1
data/insight/data_explore.py

@@ -9,6 +9,7 @@ import numbers
 import pandas as pd
 from pandas.core.dtypes.common import is_numeric_dtype
 
+from commom import f_is_number
 from enums import ConstantEnum
 
 
@@ -28,15 +29,20 @@ class DataExplore():
                 cnt_str = 0
                 cnt_number = 0
                 cnt_other = 0
+                cnt_strnum = 0
                 for value in values:
                     if isinstance(value, numbers.Number):
                         cnt_number += 1
                     elif isinstance(value, str):
                         cnt_str += 1
+                        if f_is_number(value):
+                            cnt_strnum += 1
                     else:
                         cnt_other += 1
+
                 if len(values) != cnt_str:
-                    check_msg = f"{check_msg}【{column}】数值型数量{cnt_number} 字符型数量{cnt_str} 其它类型数量{cnt_other}\n"
+                    check_msg = f"{check_msg}【{column}】数值型数量{cnt_number} 字符型数量{cnt_str} 字符型数值数量{cnt_strnum} " \
+                                f"其它类型数量{cnt_other}\n"
         return check_msg
 
     @staticmethod

+ 5 - 0
entitys/data_feaure_entity.py

@@ -46,6 +46,11 @@ class DataSplitEntity():
     def __init__(self, train_data: pd.DataFrame, test_data: pd.DataFrame):
         self._train_data = train_data
         self._test_data = test_data
+        self._data = pd.concat((train_data, test_data))
+
+    @property
+    def data(self):
+        return self._data
 
     @property
     def train_data(self):

+ 5 - 4
feature/woe/strategy_woe.py

@@ -286,8 +286,6 @@ class StrategyWoe(FeatureStrategyBase):
 
     def _f_fast_filter(self, data: DataSplitEntity) -> Dict[str, BinInfo]:
         # 通过iv值粗筛变量
-        train_data = data.train_data
-        test_data = data.test_data
         y_column = self.ml_config.y_column
         x_columns = self.ml_config.x_columns
         columns_exclude = self.ml_config.columns_exclude
@@ -296,6 +294,10 @@ class StrategyWoe(FeatureStrategyBase):
         iv_threshold = self.ml_config.iv_threshold
         psi_threshold = self.ml_config.psi_threshold
 
+        train_data = data.train_data
+        test_data = data.test_data
+        data = data.data
+
         if len(x_columns) == 0:
             x_columns = train_data.columns.tolist()
         if y_column in x_columns:
@@ -303,8 +305,7 @@ class StrategyWoe(FeatureStrategyBase):
         for column in columns_exclude:
             if column in x_columns:
                 x_columns.remove(column)
-
-        check_msg = DataExplore.check_type(train_data[x_columns])
+        check_msg = DataExplore.check_type(data[x_columns])
         if check_msg != "":
             print(f"数据类型分析:\n{check_msg}\n同一变量请保持数据类型一致")
             raise GeneralException(ResultCodesEnum.ILLEGAL_PARAMS, message=f"数据类型错误.")

+ 4 - 11
online_learning/trainer.py

@@ -128,8 +128,7 @@ class OnlineLearningTrainer:
 
         train_data = self._data.train_data
         test_data = self._data.test_data
-        data = pd.concat((train_data, test_data))
-
+        data = self._data.data
         model = self._model_optimized
         if model_type != "新模型":
             model = self._model_original
@@ -150,10 +149,8 @@ class OnlineLearningTrainer:
         return MetricFucResultEntity(table=df_auc_ks, image_path=img_path_auc_ks, image_size=5, table_font_size=10)
 
     def _f_get_metric_trend(self, ):
-        train_data = self._data.train_data
-        test_data = self._data.test_data
         y_column = self._ol_config.y_column
-        data = pd.concat((train_data, test_data))
+        data = self._data.data
 
         # 建模样本变量趋势
         breaks_list = {}
@@ -187,10 +184,8 @@ class OnlineLearningTrainer:
         return MetricFucResultEntity(table=df, image_path=img_path_coef)
 
     def _f_get_metric_gain(self, model_type: str):
-        train_data = self._data.train_data
-        test_data = self._data.test_data
         y_column = self._ol_config.y_column
-        data = pd.concat((train_data, test_data))
+        data = self._data.data
 
         model = self._model_optimized
         if model_type != "新模型":
@@ -207,10 +202,8 @@ class OnlineLearningTrainer:
     def _f_get_stress_test(self, ):
         stress_sample_times = self._ol_config.stress_sample_times
         stress_bad_rate_list = self._ol_config.stress_bad_rate_list
-        train_data = self._data.train_data
-        test_data = self._data.test_data
         y_column = self._ol_config.y_column
-        data = pd.concat((train_data, test_data))
+        data = self._data.data
         score = self.prob(data, self._model_optimized)
         score_bin, _ = f_get_model_score_bin(data, score)
         df_stress = f_stress_test(score_bin, sample_times=stress_sample_times, bad_rate_list=stress_bad_rate_list,