2 commits 5f36d1e476 ... de62450c54

Author  SHA1        Message                        Date
yq      de62450c54  modify: xgb code optimization  1 month ago
yq      0cc78a917e  modify: code optimization      1 month ago
8 files changed, 154 insertions(+) and 55 deletions(-)
  1. +1   -0   __init__.py
  2. +5   -3   enums/context_enum.py
  3. +3   -0   feature/bin/strategy_norm.py
  4. +16  -18  feature/woe/strategy_woe.py
  5. +7   -0   init/__init__.py
  6. +116 -29  model/model_xgb.py
  7. +5   -2   pipeline/pipeline.py
  8. +1   -3   requirements-analysis.txt

+ 1 - 0
__init__.py

@@ -4,6 +4,7 @@
 @time: 2024/10/31
 @desc:
 """
+
 import sys
 from os.path import dirname, realpath
 

+ 5 - 3
enums/context_enum.py

@@ -8,12 +8,14 @@ from enum import Enum
 
 
 class ContextEnum(Enum):
-    PARAM_OPTIMIZED = "param_optimized"
-    BIN_INFO_FILTERED = "bin_info_filtered"
-    HOMO_BIN_INFO_NUMERIC_SET = "homo_bin_info_numeric_set"
+    # binning info
     WOEBIN = "woebin"
+
     FILTER_FAST = "filter_fast"
     FILTER_NUMERIC = "filter_numeric"
     FILTER_CORR = "filter_corr"
     FILTER_VIF = "filter_vif"
     FILTER_IVTOP = "filter_ivtop"
+
+    XGB_COLUMNS_STR = "xgb_columns_str"
+    XGB_COLUMNS_NUM = "xgb_columns_num"

+ 3 - 0
feature/bin/strategy_norm.py

@@ -132,6 +132,9 @@ class StrategyNorm(FeatureStrategyBase):
                                 f"筛选前变量数量:{len(x_columns)}\n{x_columns}\n"
                                 f"筛选前变量数量:{len(x_columns)}\n{x_columns}\n"
                                 f"快速筛选剔除变量数量:{len(x_columns) - len(x_columns_filter)}", detail=df_importance)
                                 f"快速筛选剔除变量数量:{len(x_columns) - len(x_columns_filter)}", detail=df_importance)
 
 
+        context.set(ContextEnum.XGB_COLUMNS_STR, str_columns)
+        context.set(ContextEnum.XGB_COLUMNS_NUM, num_columns)
+
         return x_columns_filter
         return x_columns_filter
 
 
     def feature_search(self, data: DataSplitEntity, *args, **kwargs):
     def feature_search(self, data: DataSplitEntity, *args, **kwargs):
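
The two `context.set` calls added above are the producer half of a hand-off that `model/model_xgb.py` consumes further down. A minimal sketch of the pattern, using a hypothetical stand-in for the project's `context` helper and made-up column names:

from enum import Enum

class ContextEnum(Enum):
    XGB_COLUMNS_STR = "xgb_columns_str"
    XGB_COLUMNS_NUM = "xgb_columns_num"

class Context:
    # hypothetical stand-in for the shared context from init: a plain key-value store
    def __init__(self):
        self._store = {}

    def set(self, key: Enum, value):
        self._store[key.value] = value

    def get(self, key: Enum):
        return self._store.get(key.value)

context = Context()
# producer (feature/bin/strategy_norm.py): record which columns are categorical vs numeric
context.set(ContextEnum.XGB_COLUMNS_STR, ["city", "channel"])  # example names only
context.set(ContextEnum.XGB_COLUMNS_NUM, ["age", "income"])
# consumer (model/model_xgb.py): read the same lists to configure the preprocessing mapper
str_columns = context.get(ContextEnum.XGB_COLUMNS_STR)
num_columns = context.get(ContextEnum.XGB_COLUMNS_NUM)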

+ 16 - 18
feature/woe/strategy_woe.py

@@ -33,6 +33,8 @@ class StrategyWoe(FeatureStrategyBase):
         super().__init__(*args, **kwargs)
         # binning info needed for woe encoding; reuses the scorecardpy format
         self.sc_woebin = None
+        self._bin_info_filtered: Dict[str, BinInfo]
+        self._homo_bin_info_numeric_set: Dict[str, HomologousBinInfo]
 
     def _f_get_img_corr(self, train_woe) -> Union[str, None]:
         if len(train_woe.columns.to_list()) <= 1:
@@ -422,17 +424,17 @@ class StrategyWoe(FeatureStrategyBase):
         bin_info_filtered = self._f_vif_filter(data, bin_info_filtered)
         bin_info_filtered = BinInfo.ivTopN(bin_info_filtered, max_feature_num)
         self.sc_woebin = self._f_get_sc_woebin(data.train_data, bin_info_filtered)
-        context.set(ContextEnum.BIN_INFO_FILTERED, bin_info_filtered)
         context.set(ContextEnum.WOEBIN, self.sc_woebin)
+        return bin_info_filtered
 
     def feature_search(self, data: DataSplitEntity, *args, **kwargs):
         # coarse filtering
         bin_info_fast = self._f_fast_filter(data)
         x_columns = list(bin_info_fast.keys())
 
-        bin_info_filtered: Dict[str, BinInfo] = {}
+        self._bin_info_filtered: Dict[str, BinInfo] = {}
         # intermediate results of multiple binning methods for numeric variables
-        homo_bin_info_numeric_set: Dict[str, HomologousBinInfo] = {}
+        self._homo_bin_info_numeric_set: Dict[str, HomologousBinInfo] = {}
         filter_numeric_overview = ""
         filter_numeric_detail = []
         for x_column in tqdm(x_columns):
@@ -440,22 +442,21 @@ class StrategyWoe(FeatureStrategyBase):
                 # numeric variable filtering
                 homo_bin_info_numeric: HomologousBinInfo = self._handle_numeric(data, x_column)
                 if homo_bin_info_numeric.is_auto_bins:
-                    homo_bin_info_numeric_set[x_column] = homo_bin_info_numeric
+                    self._homo_bin_info_numeric_set[x_column] = homo_bin_info_numeric
                 # filter on iv, psi, variable monotonicity and trend consistency
                 bin_info: Optional[BinInfo] = homo_bin_info_numeric.filter()
                 if bin_info is not None:
-                    bin_info_filtered[x_column] = bin_info
+                    self._bin_info_filtered[x_column] = bin_info
                 else:
                     # dropped for failing the requirements
                     filter_numeric_overview = f"{filter_numeric_overview}{x_column} {homo_bin_info_numeric.drop_reason()}\n"
                     filter_numeric_detail.append(x_column)
             else:
                 # categorical variables are handled with scorecardpy for now
-                bin_info_filtered[x_column] = bin_info_fast[x_column]
+                self._bin_info_filtered[x_column] = bin_info_fast[x_column]
 
-        self.post_filter(data, bin_info_filtered)
+        self._bin_info_filtered = self.post_filter(data, self._bin_info_filtered)
 
-        context.set(ContextEnum.HOMO_BIN_INFO_NUMERIC_SET, homo_bin_info_numeric_set)
         context.set_filter_info(ContextEnum.FILTER_NUMERIC, filter_numeric_overview, filter_numeric_detail)
 
     def variable_analyse(self, data: DataSplitEntity, column: str, format_bin=None, *args, **kwargs):
@@ -497,7 +498,6 @@ class StrategyWoe(FeatureStrategyBase):
         train_data = data.train_data
         test_data = data.test_data
         # intermediate results are used across modules, so they are fetched from the context
-        bin_info_filtered: Dict[str, BinInfo] = context.get(ContextEnum.BIN_INFO_FILTERED)
 
         metric_value_dict = {}
         # sample distribution
@@ -505,15 +505,15 @@ class StrategyWoe(FeatureStrategyBase):
                                                           table_cell_width=3)
 
         # variable correlation
-        sc_woebin_train = self._f_get_sc_woebin(train_data, bin_info_filtered)
+        sc_woebin_train = self._f_get_sc_woebin(train_data, self._bin_info_filtered)
         train_woe = sc.woebin_ply(train_data[x_columns], sc_woebin_train, print_info=False)
         img_path_corr = self._f_get_img_corr(train_woe)
         metric_value_dict["变量相关性"] = MetricFucResultEntity(image_path=img_path_corr)
 
         # variable iv, psi, vif
         df_iv_psi_vif = pd.DataFrame()
-        train_iv = [bin_info_filtered[column].train_iv for column in x_columns]
-        psi = [bin_info_filtered[column].psi for column in x_columns]
+        train_iv = [self._bin_info_filtered[column].train_iv for column in x_columns]
+        psi = [self._bin_info_filtered[column].psi for column in x_columns]
         anns = [columns_anns.get(column, "-") for column in x_columns]
         df_iv_psi_vif["变量"] = x_columns
         df_iv_psi_vif["iv"] = train_iv
@@ -535,7 +535,7 @@ class StrategyWoe(FeatureStrategyBase):
         metric_value_dict["变量趋势-训练集"] = MetricFucResultEntity(image_path=imgs_path_trend_train, image_size=4)
         metric_value_dict["变量趋势-训练集"] = MetricFucResultEntity(image_path=imgs_path_trend_train, image_size=4)
 
 
         # 变量趋势-测试集
         # 变量趋势-测试集
-        sc_woebin_test = self._f_get_sc_woebin(test_data, bin_info_filtered)
+        sc_woebin_test = self._f_get_sc_woebin(test_data, self._bin_info_filtered)
         imgs_path_trend_test = self._f_get_img_trend(sc_woebin_test, x_columns, "test")
         imgs_path_trend_test = self._f_get_img_trend(sc_woebin_test, x_columns, "test")
         metric_value_dict["变量趋势-测试集"] = MetricFucResultEntity(image_path=imgs_path_trend_test, image_size=4)
         metric_value_dict["变量趋势-测试集"] = MetricFucResultEntity(image_path=imgs_path_trend_test, image_size=4)
 
 
@@ -554,7 +554,7 @@ class StrategyWoe(FeatureStrategyBase):
                 detail = [detail]
             if isinstance(detail, list):
                 for column in detail:
-                    homo_bin_info_numeric = homo_bin_info_numeric_set.get(column)
+                    homo_bin_info_numeric = self._homo_bin_info_numeric_set.get(column)
                     if homo_bin_info_numeric is None:
                         continue
                     self._f_best_bins_print(display, data, column, homo_bin_info_numeric)
@@ -572,8 +572,6 @@ class StrategyWoe(FeatureStrategyBase):
             if detail is not None and self.ml_config.bin_detail_print:
                 detail_print(detail)
 
-        bin_info_filtered: Dict[str, BinInfo] = context.get(ContextEnum.BIN_INFO_FILTERED)
-        homo_bin_info_numeric_set: Dict[str, HomologousBinInfo] = context.get(ContextEnum.HOMO_BIN_INFO_NUMERIC_SET)
         filter_fast = context.get(ContextEnum.FILTER_FAST)
         filter_numeric = context.get(ContextEnum.FILTER_NUMERIC)
         filter_corr = context.get(ContextEnum.FILTER_CORR)
@@ -597,11 +595,11 @@ class StrategyWoe(FeatureStrategyBase):
                                  title2="测试集")
                                  title2="测试集")
 
 
         # 打印breaks_list
         # 打印breaks_list
-        breaks_list = {column: bin_info.points for column, bin_info in bin_info_filtered.items()}
+        breaks_list = {column: bin_info.points for column, bin_info in self._bin_info_filtered.items()}
         print("变量切分点:")
         print("变量切分点:")
         print(json.dumps(breaks_list, ensure_ascii=False, indent=2, cls=NumpyEncoder))
         print(json.dumps(breaks_list, ensure_ascii=False, indent=2, cls=NumpyEncoder))
         print("选中变量不同分箱数下变量的推荐切分点:")
         print("选中变量不同分箱数下变量的推荐切分点:")
-        detail_print(list(bin_info_filtered.keys()))
+        detail_print(list(self._bin_info_filtered.keys()))
 
 
         # 打印fast_filter筛选情况
         # 打印fast_filter筛选情况
         filter_print(filter_fast, "快速筛选过程", "剔除train_iv小于阈值")
         filter_print(filter_fast, "快速筛选过程", "剔除train_iv小于阈值")
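
Net effect of the changes in this file: `bin_info_filtered` and `homo_bin_info_numeric_set` no longer round-trip through the shared context (their keys were deleted from `ContextEnum` above) but live on the strategy instance, and `post_filter` now returns the filtered dict. A minimal sketch of the pattern, with a hypothetical class rather than the project's real one:

class Strategy:
    def feature_search(self, data):
        # was: a local dict published via context.set(ContextEnum.BIN_INFO_FILTERED, ...)
        self._bin_info_filtered = {}
        # ... fill the dict while iterating over candidate columns ...
        self._bin_info_filtered = self.post_filter(data, self._bin_info_filtered)

    def post_filter(self, data, bin_info_filtered):
        # ... corr / vif / iv-top-n filtering ...
        return bin_info_filtered  # was: pushed to context, nothing returned

    def feature_report(self, data):
        # was: context.get(ContextEnum.BIN_INFO_FILTERED)
        return self._bin_info_filtered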

+ 7 - 0
init/__init__.py

@@ -4,6 +4,7 @@
 @time: 2024/10/31
 @desc: some resource initialization
 """
+import os
 import sys
 import threading
 
@@ -12,6 +13,12 @@ from contextvars import ContextVar
 
 from config import BaseConfig
 
+if BaseConfig.java_home is not None:
+    java_home = BaseConfig.java_home
+    if os.path.basename(java_home) != "bin":
+        java_home = os.path.join(java_home, 'bin')
+    os.environ['PATH'] = f"{os.environ['PATH']}:{java_home}"
+
 matplotlib.use('Agg')
 
 import matplotlib.pyplot as plt
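
The added block appends `<java_home>/bin` to PATH, presumably so the PMML export can find the `java` executable (sklearn2pmml converts pipelines by running a Java program), which matches dropping the explicit `java_home=` argument in `model/model_xgb.py` below. A quick sanity check, assuming a POSIX-style PATH; note that `os.pathsep` would be the portable separator:

import os
import shutil

os.environ['PATH'] = f"{os.environ['PATH']}{os.pathsep}/opt/jdk/bin"  # example JDK path, not the project's
# sklearn2pmml shells out to java, so it should now resolve on PATH
print(shutil.which("java"))  # a filesystem path if a JDK really lives there, else None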

+ 116 - 29
model/model_xgb.py

@@ -9,27 +9,85 @@ import os.path
 from os.path import dirname, realpath
 from typing import Dict
 
+import joblib
+import numpy
 import numpy as np
 import pandas as pd
 import scorecardpy as sc
 import xgboost as xgb
-from sklearn2pmml import sklearn2pmml, make_pmml_pipeline
+from pandas import DataFrame, Series
+from sklearn.preprocessing import OneHotEncoder
+from sklearn2pmml import sklearn2pmml, PMMLPipeline
+from sklearn2pmml.preprocessing import CutTransformer
+from sklearn_pandas import DataFrameMapper
 
 from commom import f_df_to_image, f_display_images_by_side, GeneralException, f_display_title, \
     f_image_crop_white_borders
-from config import BaseConfig
 from entitys import MetricFucResultEntity, DataSplitEntity, DataFeatureEntity
-from enums import ResultCodesEnum, ConstantEnum, FileEnum
+from enums import ResultCodesEnum, ConstantEnum, FileEnum, ContextEnum
+from init import context
 from .model_base import ModelBase
 from .model_utils import f_stress_test, f_calcu_model_ks, f_get_model_score_bin, f_calcu_model_psi
 
 
+class PMMLPipelineXgb(PMMLPipeline):
+    def __init__(self, steps, ):
+        super().__init__(steps=steps)
+
+    def _filter_column_names(self, X):
+        return (numpy.asarray(X)).astype(str)
+
+    def _get_column_names(self, X):
+        if isinstance(X, DataFrame):
+            return self._filter_column_names(X.columns.values)
+        elif isinstance(X, Series):
+            return self._filter_column_names(X.name)
+        # elif isinstance(X, H2OFrame)
+        elif hasattr(X, "names"):
+            return self._filter_column_names(X.names)
+        else:
+            return None
+
+    def Xtransformer_fit(self, X, y=None):
+        # Collect feature name(s)
+        active_fields = self._get_column_names(X)
+        if active_fields is not None:
+            self.active_fields = active_fields
+        # Collect label name(s)
+        target_fields = self._get_column_names(y)
+        if target_fields is not None:
+            self.target_fields = target_fields
+
+        self.steps = list(self.steps)
+        self._validate_steps()
+
+        for (step_idx, name, transformer) in self._iter(with_final=False, filter_passthrough=False):
+            transformer.fit(X)
+            self.steps[step_idx] = (name, transformer)
+
+    def Xtransform(self, X):
+        Xt = X
+        for name, transform in self.steps[:-1]:
+            if transform is not None:
+                Xt = transform.transform(Xt)
+        return Xt
+
+    def fit(self, X, y=None, **fit_params):
+        fit_params_steps = self._check_fit_params(**fit_params)
+        Xt = self.Xtransform(X)
+        if self._final_estimator != 'passthrough':
+            fit_params_last_step = fit_params_steps[self.steps[-1][0]]
+            self._final_estimator.fit(Xt, y, **fit_params_last_step)
+        return self
+
+
 class ModelXgb(ModelBase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         # report template
         self._template_path = os.path.join(dirname(dirname(realpath(__file__))), "./template/模型开发报告模板_xgb.docx")
-        self.model = None
+        self.pipeline: PMMLPipelineXgb
+        self.model = xgb.XGBClassifier
 
     def get_report_template_path(self):
         return self._template_path
@@ -37,7 +95,15 @@ class ModelXgb(ModelBase):
     def train(self, train_data: DataFeatureEntity, test_data: DataFeatureEntity, *args, **kwargs):
         print(f"{'-' * 50}开始训练{'-' * 50}")
         params_xgb = self.ml_config.params_xgb
+        y_column = self._ml_config.y_column
+        str_columns = context.get(ContextEnum.XGB_COLUMNS_STR)
+        num_columns = context.get(ContextEnum.XGB_COLUMNS_NUM)
 
+        data: DataSplitEntity = kwargs["data"]
+        train_data_raw = data.train_data
+        test_data_raw = data.test_data
+
+        # training with the native xgb interface
         # dtrain = xgb.DMatrix(data=train_data.data_x, label=train_data.data_y)
         # dtest = xgb.DMatrix(data=test_data.data_x, label=test_data.data_y)
         # self.model = xgb.train(
@@ -49,6 +115,7 @@ class ModelXgb(ModelBase):
         #     verbose_eval=params_xgb.get("verbose_eval")
         # )
 
+        # xgb re-wrapped as a sklearn-style interface
         self.model = xgb.XGBClassifier(objective=params_xgb.get("objective"),
                                        n_estimators=params_xgb.get("num_boost_round"),
                                        max_depth=params_xgb.get("max_depth"),
@@ -60,25 +127,43 @@ class ModelXgb(ModelBase):
                                        importance_type='weight'
                                        )
 
-        self.model.fit(X=train_data.data_x, y=train_data.data_y,
-                       eval_set=[(train_data.data_x, train_data.data_y), (test_data.data_x, test_data.data_y)],
-                       eval_metric=params_xgb.get("eval_metric"),
-                       early_stopping_rounds=params_xgb.get("early_stopping_rounds"),
-                       verbose=params_xgb.get("verbose_eval"),
-                       )
-
-        if params_xgb.get("trees_print"):
-            trees = self.model.get_booster().get_dump()
-            for i, tree in enumerate(trees):
-                if i < self.model.best_ntree_limit:
-                    print(f"Tree {i}:")
-                    print(tree)
-
-        self._train_score = self.prob(train_data.data_x)
-        self._test_score = self.prob(test_data.data_x)
+        # self.model.fit(X=train_data.data_x, y=train_data.data_y,
+        #                eval_set=[(train_data.data_x, train_data.data_y), (test_data.data_x, test_data.data_y)],
+        #                eval_metric=params_xgb.get("eval_metric"),
+        #                early_stopping_rounds=params_xgb.get("early_stopping_rounds"),
+        #                verbose=params_xgb.get("verbose_eval"),
+        #                )
+
+        # if params_xgb.get("trees_print"):
+        #     trees = self.model.get_booster().get_dump()
+        #     for i, tree in enumerate(trees):
+        #         if i < self.model.best_ntree_limit:
+        #             print(f"Tree {i}:")
+        #             print(tree)
+
+        mapper = [(str_columns, OneHotEncoder())]
+        # for column in str_columns:
+        #     mapper.append((column, OneHotEncoder()))
+        for column in num_columns:
+            mapper.append(
+                (column, CutTransformer([-np.inf, 10, 20, 30, +np.inf], labels=[1, 2, 3, 4])))
+        mapper = DataFrameMapper(mapper)
+
+        self.pipeline = PMMLPipelineXgb([("mapper", mapper), ("classifier", self.model)])
+        self.pipeline.Xtransformer_fit(data.data, data.data[y_column])
+        self.pipeline.fit(train_data_raw, train_data_raw[y_column],
+                          classifier__eval_set=[
+                              (self.pipeline.Xtransform(train_data_raw), train_data_raw[y_column]),
+                              (self.pipeline.Xtransform(test_data_raw), test_data_raw[y_column])
+                          ],
+                          classifier__eval_metric=params_xgb.get("eval_metric"),
+                          classifier__early_stopping_rounds=params_xgb.get("early_stopping_rounds"),
+                          classifier__verbose=params_xgb.get("verbose_eval"),
+                          )
 
     def prob(self, x: pd.DataFrame, *args, **kwargs) -> np.array:
-        prob = self.model.predict_proba(x)[:, 1]
+        # prob = self.model.predict_proba(x)[:, 1]
+        prob = self.pipeline.predict_proba(x)[:, 1]
         return prob
 
     def score(self, x: pd.DataFrame, *args, **kwargs) -> np.array:
@@ -88,16 +173,17 @@ class ModelXgb(ModelBase):
         pass
 
     def model_save(self):
-        if self.model is None:
+        if self.pipeline is None:
             GeneralException(ResultCodesEnum.NOT_FOUND, message=f"模型不存在")
 
         path_model = self.ml_config.f_get_save_path(FileEnum.MODEL.value)
-        self.model.save_model(path_model)
+        # self.model.save_model(path_model)
+        joblib.dump(self.pipeline, path_model)
         print(f"model save to【{path_model}】success. ")
 
         path_pmml = self.ml_config.f_get_save_path(FileEnum.PMML.value)
-        pipeline = make_pmml_pipeline(self.model)
-        sklearn2pmml(pipeline, path_pmml, with_repr=True, java_home=BaseConfig.java_home)
+        # pipeline = make_pmml_pipeline(self.model)
+        sklearn2pmml(self.pipeline, path_pmml, with_repr=True, )
         print(f"model save to【{path_pmml}】success. ")
 
     def model_load(self, path: str, *args, **kwargs):
@@ -107,8 +193,9 @@ class ModelXgb(ModelBase):
         if not os.path.isfile(path_model):
             raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"模型文件【{path_model}】不存在")
 
-        self.model = xgb.XGBClassifier()
-        self.model.load_model(path_model)
+        # self.model = xgb.XGBClassifier()
+        # self.model.load_model(path_model)
+        self.pipeline = joblib.load(path_model)
 
         print(f"model load from【{path_model}】success.")
 
@@ -127,8 +214,8 @@ class ModelXgb(ModelBase):
             # model ks / auc
             img_path_auc_ks = []
 
-            train_score = self._train_score
-            test_score = self._test_score
+            train_score = self.prob(train_data)
+            test_score = self.prob(test_data)
 
             train_auc, train_ks, path = _get_auc_ks(train_data[y_column], train_score, f"train")
             img_path_auc_ks.append(path)
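
For reference, a condensed, self-contained sketch of the new training flow with toy data and made-up column names ("city" categorical, "age" numeric); the real code pulls its column lists from the context and wraps the mapper and classifier in the PMMLPipelineXgb subclass above so the whole thing can be exported with sklearn2pmml:

import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.preprocessing import OneHotEncoder
from sklearn2pmml.preprocessing import CutTransformer
from sklearn_pandas import DataFrameMapper

# toy frame standing in for data.train_data
df = pd.DataFrame({
    "city": np.random.choice(["a", "b", "c"], 200),
    "age": np.random.randint(18, 60, 200),
    "y": np.random.randint(0, 2, 200),
})

# same mapper shape as in train(): one-hot encoding for the categorical columns,
# CutTransformer bins for the numeric ones (the fixed cut points read like placeholders)
mapper = DataFrameMapper([
    (["city"], OneHotEncoder()),
    ("age", CutTransformer([-np.inf, 10, 20, 30, np.inf], labels=[1, 2, 3, 4])),
])

# mirrors Xtransformer_fit / Xtransform: fit the mapper once, transform up front,
# then hand the transformed matrix to the classifier
X = mapper.fit_transform(df).astype(float)
clf = xgb.XGBClassifier(n_estimators=20, max_depth=3)
clf.fit(X, df["y"])
print(clf.predict_proba(mapper.transform(df).astype(float))[:, 1][:5])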

+ 5 - 2
pipeline/pipeline.py

@@ -40,13 +40,16 @@ class Pipeline():
         train_data = DataFeatureEntity(data_x=train_data, data_y=self._data.train_data[self._ml_config.y_column])
         test_data = self._feature_strategy.feature_generate(self._data.test_data)
         test_data = DataFeatureEntity(data_x=test_data, data_y=self._data.test_data[self._ml_config.y_column])
-        self._model.train(train_data, test_data)
+        self._model.train(train_data, test_data, data=self._data)
         metric_model = self._model.train_report(self._data)
 
         self.metric_value_dict = {**metric_feature, **metric_model}
 
     def prob(self, data: pd.DataFrame):
-        feature = self._feature_strategy.feature_generate(data)
+        if self._ml_config.model_type == ModelEnum.XGB.value:
+            feature = data
+        else:
+            feature = self._feature_strategy.feature_generate(data)
         prob = self._model.prob(feature)
         return prob
 

+ 1 - 3
requirements-analysis.txt

@@ -21,7 +21,5 @@ pypmml==0.9.0
 #dataframe_image==0.1.14
 #thrift-sasl==0.4.3
 #pyhive==0.7.0
-#sklearn2pmml==0.103.3
+#sklearn2pmml==0.65.0
 #sklearn-pandas==2.2.0
-#dill==0.3.4
-