modify: xgb code optimization

yq, 4 days ago
parent revision de62450c54

+ 1 - 0
__init__.py

@@ -4,6 +4,7 @@
 @time: 2024/10/31
 @desc:
 """
+
 import sys
 from os.path import dirname, realpath
 

+ 5 - 3
enums/context_enum.py

@@ -8,12 +8,14 @@ from enum import Enum
 
 
 class ContextEnum(Enum):
-    PARAM_OPTIMIZED = "param_optimized"
-    BIN_INFO_FILTERED = "bin_info_filtered"
-    HOMO_BIN_INFO_NUMERIC_SET = "homo_bin_info_numeric_set"
+    # binning info
     WOEBIN = "woebin"
+
     FILTER_FAST = "filter_fast"
     FILTER_NUMERIC = "filter_numeric"
     FILTER_CORR = "filter_corr"
     FILTER_VIF = "filter_vif"
     FILTER_IVTOP = "filter_ivtop"
+
+    XGB_COLUMNS_STR = "xgb_columns_str"
+    XGB_COLUMNS_NUM = "xgb_columns_num"
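
For reference, these two keys form a producer/consumer pair: strategy_norm.py writes them (next hunk) and ModelXgb.train reads them back. A minimal sketch of the round trip, assuming the context.set/context.get API used in those hunks (the column names are hypothetical):

from enums import ContextEnum
from init import context

# producer: record which raw columns are categorical vs. numeric
context.set(ContextEnum.XGB_COLUMNS_STR, ["city", "channel"])
context.set(ContextEnum.XGB_COLUMNS_NUM, ["age", "income"])

# consumer: ModelXgb.train rebuilds its preprocessing mapper from these lists
str_columns = context.get(ContextEnum.XGB_COLUMNS_STR)
num_columns = context.get(ContextEnum.XGB_COLUMNS_NUM)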

+ 3 - 0
feature/bin/strategy_norm.py

@@ -132,6 +132,9 @@ class StrategyNorm(FeatureStrategyBase):
                                 f"筛选前变量数量:{len(x_columns)}\n{x_columns}\n"
                                 f"快速筛选剔除变量数量:{len(x_columns) - len(x_columns_filter)}", detail=df_importance)
 
+        context.set(ContextEnum.XGB_COLUMNS_STR, str_columns)
+        context.set(ContextEnum.XGB_COLUMNS_NUM, num_columns)
+
         return x_columns_filter
 
     def feature_search(self, data: DataSplitEntity, *args, **kwargs):

+ 7 - 0
init/__init__.py

@@ -4,6 +4,7 @@
 @time: 2024/10/31
 @desc: some resource initialization
 """
+import os
 import sys
 import threading
 
@@ -12,6 +13,12 @@ from contextvars import ContextVar
 
 from config import BaseConfig
 
+# make the JVM visible to sklearn2pmml, which shells out to java for PMML export
+if BaseConfig.java_home is not None:
+    java_home = BaseConfig.java_home
+    if os.path.basename(java_home) != "bin":
+        java_home = os.path.join(java_home, 'bin')
+    # os.pathsep keeps this portable (':' on POSIX, ';' on Windows)
+    os.environ['PATH'] = f"{os.environ['PATH']}{os.pathsep}{java_home}"
+
 matplotlib.use('Agg')
 
 import matplotlib.pyplot as plt
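
Worth noting: sklearn2pmml launches a Java subprocess to build the PMML file, which is why java/bin has to be resolvable from PATH. A quick sanity check, a sketch using only the standard library:

import shutil
import subprocess

# confirm the JVM configured above is actually reachable before any PMML export
if shutil.which("java") is None:
    raise RuntimeError("java not found on PATH; check BaseConfig.java_home")
subprocess.run(["java", "-version"], check=True)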

+ 116 - 29
model/model_xgb.py

@@ -9,27 +9,85 @@ import os.path
 from os.path import dirname, realpath
 from typing import Dict
 
+import joblib
 import numpy as np
 import pandas as pd
 import scorecardpy as sc
 import xgboost as xgb
-from sklearn2pmml import sklearn2pmml, make_pmml_pipeline
+from pandas import DataFrame, Series
+from sklearn.preprocessing import OneHotEncoder
+from sklearn2pmml import sklearn2pmml, PMMLPipeline
+from sklearn2pmml.preprocessing import CutTransformer
+from sklearn_pandas import DataFrameMapper
 
 from commom import f_df_to_image, f_display_images_by_side, GeneralException, f_display_title, \
     f_image_crop_white_borders
-from config import BaseConfig
 from entitys import MetricFucResultEntity, DataSplitEntity, DataFeatureEntity
-from enums import ResultCodesEnum, ConstantEnum, FileEnum
+from enums import ResultCodesEnum, ConstantEnum, FileEnum, ContextEnum
+from init import context
 from .model_base import ModelBase
 from .model_utils import f_stress_test, f_calcu_model_ks, f_get_model_score_bin, f_calcu_model_psi
 
 
+class PMMLPipelineXgb(PMMLPipeline):
+    """PMMLPipeline variant whose fit is split in two stages, so xgboost
+    eval sets can be pre-transformed into the classifier's feature space."""
+
+    def __init__(self, steps):
+        super().__init__(steps=steps)
+
+    def _filter_column_names(self, X):
+        return np.asarray(X).astype(str)
+
+    def _get_column_names(self, X):
+        if isinstance(X, DataFrame):
+            return self._filter_column_names(X.columns.values)
+        elif isinstance(X, Series):
+            return self._filter_column_names(X.name)
+        # elif isinstance(X, H2OFrame)
+        elif hasattr(X, "names"):
+            return self._filter_column_names(X.names)
+        else:
+            return None
+
+    def Xtransformer_fit(self, X, y=None):
+        # Collect feature name(s)
+        active_fields = self._get_column_names(X)
+        if active_fields is not None:
+            self.active_fields = active_fields
+        # Collect label name(s)
+        target_fields = self._get_column_names(y)
+        if target_fields is not None:
+            self.target_fields = target_fields
+
+        self.steps = list(self.steps)
+        self._validate_steps()
+
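+        # fit each transformer step in place; the final estimator stays untrained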
+        for (step_idx, name, transformer) in self._iter(with_final=False, filter_passthrough=False):
+            transformer.fit(X)
+            self.steps[step_idx] = (name, transformer)
+
+    def Xtransform(self, X):
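+        # run X through every transformer step, skipping the final estimator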
+        Xt = X
+        for name, transform in self.steps[:-1]:
+            if transform is not None:
+                Xt = transform.transform(Xt)
+        return Xt
+
+    def fit(self, X, y=None, **fit_params):
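+        # transformer steps are assumed already fitted via Xtransformer_fit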
+        fit_params_steps = self._check_fit_params(**fit_params)
+        Xt = self.Xtransform(X)
+        if self._final_estimator != 'passthrough':
+            fit_params_last_step = fit_params_steps[self.steps[-1][0]]
+            self._final_estimator.fit(Xt, y, **fit_params_last_step)
+        return self
+
+
 class ModelXgb(ModelBase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         # report template
         self._template_path = os.path.join(dirname(dirname(realpath(__file__))), "./template/模型开发报告模板_xgb.docx")
-        self.model = None
+        self.pipeline: PMMLPipelineXgb = None
+        self.model: xgb.XGBClassifier = None
 
     def get_report_template_path(self):
         return self._template_path
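
Why the subclass: xgboost's early stopping needs eval_set in the same transformed feature space the classifier is trained on, but the stock PMMLPipeline.fit transforms and fits in one pass. Splitting the stages lets train() (next hunks) pre-transform the eval sets. A minimal usage sketch, where df_all, df_train, df_test, mapper and clf are hypothetical stand-ins:

# fit the transformer steps once, on the full dataset
pipeline = PMMLPipelineXgb([("mapper", mapper), ("classifier", clf)])
pipeline.Xtransformer_fit(df_all, df_all["y"])
# then fit only the classifier, feeding pre-transformed eval sets for early stopping
pipeline.fit(df_train, df_train["y"],
             classifier__eval_set=[(pipeline.Xtransform(df_test), df_test["y"])])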
@@ -37,7 +95,15 @@ class ModelXgb(ModelBase):
     def train(self, train_data: DataFeatureEntity, test_data: DataFeatureEntity, *args, **kwargs):
         print(f"{'-' * 50}开始训练{'-' * 50}")
         params_xgb = self.ml_config.params_xgb
+        y_column = self._ml_config.y_column
+        str_columns = context.get(ContextEnum.XGB_COLUMNS_STR)
+        num_columns = context.get(ContextEnum.XGB_COLUMNS_NUM)
 
+        data: DataSplitEntity = kwargs["data"]
+        train_data_raw = data.train_data
+        test_data_raw = data.test_data
+
+        # training via the native xgb interface (kept for reference)
         # dtrain = xgb.DMatrix(data=train_data.data_x, label=train_data.data_y)
         # dtest = xgb.DMatrix(data=test_data.data_x, label=test_data.data_y)
         # self.model = xgb.train(
@@ -49,6 +115,7 @@ class ModelXgb(ModelBase):
         #     verbose_eval=params_xgb.get("verbose_eval")
         # )
 
+        # xgb wrapped as a sklearn-style interface
         self.model = xgb.XGBClassifier(objective=params_xgb.get("objective"),
                                        n_estimators=params_xgb.get("num_boost_round"),
                                        max_depth=params_xgb.get("max_depth"),
@@ -60,25 +127,43 @@ class ModelXgb(ModelBase):
                                        importance_type='weight'
                                        )
 
-        self.model.fit(X=train_data.data_x, y=train_data.data_y,
-                       eval_set=[(train_data.data_x, train_data.data_y), (test_data.data_x, test_data.data_y)],
-                       eval_metric=params_xgb.get("eval_metric"),
-                       early_stopping_rounds=params_xgb.get("early_stopping_rounds"),
-                       verbose=params_xgb.get("verbose_eval"),
-                       )
-
-        if params_xgb.get("trees_print"):
-            trees = self.model.get_booster().get_dump()
-            for i, tree in enumerate(trees):
-                if i < self.model.best_ntree_limit:
-                    print(f"Tree {i}:")
-                    print(tree)
-
-        self._train_score = self.prob(train_data.data_x)
-        self._test_score = self.prob(test_data.data_x)
+        # self.model.fit(X=train_data.data_x, y=train_data.data_y,
+        #                eval_set=[(train_data.data_x, train_data.data_y), (test_data.data_x, test_data.data_y)],
+        #                eval_metric=params_xgb.get("eval_metric"),
+        #                early_stopping_rounds=params_xgb.get("early_stopping_rounds"),
+        #                verbose=params_xgb.get("verbose_eval"),
+        #                )
+
+        # if params_xgb.get("trees_print"):
+        #     trees = self.model.get_booster().get_dump()
+        #     for i, tree in enumerate(trees):
+        #         if i < self.model.best_ntree_limit:
+        #             print(f"Tree {i}:")
+        #             print(tree)
+
+        # one-hot encode all categorical columns as a single block
+        mapper = [(str_columns, OneHotEncoder())]
+        # for column in str_columns:
+        #     mapper.append((column, OneHotEncoder()))
+        # NOTE: the same hard-coded cut points are applied to every numeric column
+        for column in num_columns:
+            mapper.append(
+                (column, CutTransformer([-np.inf, 10, 20, 30, +np.inf], labels=[1, 2, 3, 4])))
+        mapper = DataFrameMapper(mapper)
+
+        self.pipeline = PMMLPipelineXgb([("mapper", mapper), ("classifier", self.model)])
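+        # transformers are fitted on the full dataset so every category level is seen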
+        self.pipeline.Xtransformer_fit(data.data, data.data[y_column])
+        self.pipeline.fit(train_data_raw, train_data_raw[y_column],
+                          classifier__eval_set=[
+                              (self.pipeline.Xtransform(train_data_raw), train_data_raw[y_column]),
+                              (self.pipeline.Xtransform(test_data_raw), test_data_raw[y_column])
+                          ],
+                          classifier__eval_metric=params_xgb.get("eval_metric"),
+                          classifier__early_stopping_rounds=params_xgb.get("early_stopping_rounds"),
+                          classifier__verbose=params_xgb.get("verbose_eval"),
+                          )
 
     def prob(self, x: pd.DataFrame, *args, **kwargs) -> np.array:
-        prob = self.model.predict_proba(x)[:, 1]
+        # prob = self.model.predict_proba(x)[:, 1]
+        prob = self.pipeline.predict_proba(x)[:, 1]
         return prob
 
     def score(self, x: pd.DataFrame, *args, **kwargs) -> np.array:
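
To make the preprocessing concrete, a standalone sketch of what the mapper built above produces, with hypothetical columns and toy data (note that train() applies the same hard-coded cut points to every numeric column, which looks like a placeholder to revisit):

import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder
from sklearn2pmml.preprocessing import CutTransformer
from sklearn_pandas import DataFrameMapper

df = pd.DataFrame({"city": ["a", "b", "a"], "age": [5, 15, 35]})
mapper = DataFrameMapper([
    # sparse=False keeps the stacked output dense (sparse_output=False on sklearn>=1.2)
    (["city"], OneHotEncoder(sparse=False)),
    ("age", CutTransformer([-np.inf, 10, 20, 30, np.inf], labels=[1, 2, 3, 4])),
])
print(mapper.fit_transform(df))  # one-hot dummies for city plus the age bin label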
@@ -88,16 +173,17 @@ class ModelXgb(ModelBase):
         pass
 
     def model_save(self):
-        if self.model is None:
+        if self.pipeline is None:
             raise GeneralException(ResultCodesEnum.NOT_FOUND, message="model does not exist")
 
         path_model = self.ml_config.f_get_save_path(FileEnum.MODEL.value)
-        self.model.save_model(path_model)
+        # self.model.save_model(path_model)
+        joblib.dump(self.pipeline, path_model)
         print(f"model save to【{path_model}】success. ")
 
         path_pmml = self.ml_config.f_get_save_path(FileEnum.PMML.value)
-        pipeline = make_pmml_pipeline(self.model)
-        sklearn2pmml(pipeline, path_pmml, with_repr=True, java_home=BaseConfig.java_home)
+        # pipeline = make_pmml_pipeline(self.model)
+        sklearn2pmml(self.pipeline, path_pmml, with_repr=True)
         print(f"model saved to【{path_pmml}】successfully.")
 
     def model_load(self, path: str, *args, **kwargs):
@@ -107,8 +193,9 @@ class ModelXgb(ModelBase):
         if not os.path.isfile(path_model):
             raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"model file【{path_model}】does not exist")
 
-        self.model = xgb.XGBClassifier()
-        self.model.load_model(path_model)
+        # self.model = xgb.XGBClassifier()
+        # self.model.load_model(path_model)
+        self.pipeline = joblib.load(path_model)
 
         print(f"model load from【{path_model}】success.")
 
@@ -127,8 +214,8 @@ class ModelXgb(ModelBase):
             # 模型ks auc
             img_path_auc_ks = []
 
-            train_score = self._train_score
-            test_score = self._test_score
+            train_score = self.prob(train_data)
+            test_score = self.prob(test_data)
 
             train_auc, train_ks, path = _get_auc_ks(train_data[y_column], train_score, f"train")
             img_path_auc_ks.append(path)

+ 5 - 2
pipeline/pipeline.py

@@ -40,13 +40,16 @@ class Pipeline():
         train_data = DataFeatureEntity(data_x=train_data, data_y=self._data.train_data[self._ml_config.y_column])
         test_data = self._feature_strategy.feature_generate(self._data.test_data)
         test_data = DataFeatureEntity(data_x=test_data, data_y=self._data.test_data[self._ml_config.y_column])
-        self._model.train(train_data, test_data)
+        self._model.train(train_data, test_data, data=self._data)
         metric_model = self._model.train_report(self._data)
 
         self.metric_value_dict = {**metric_feature, **metric_model}
 
     def prob(self, data: pd.DataFrame):
-        feature = self._feature_strategy.feature_generate(data)
+        if self._ml_config.model_type == ModelEnum.XGB.value:
+            feature = data
+        else:
+            feature = self._feature_strategy.feature_generate(data)
         prob = self._model.prob(feature)
         return prob
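
Because the XGB pipeline embeds its own DataFrameMapper, scoring takes the raw columns directly, while other model types still need the generated WOE features first. Note this branch assumes ModelEnum is already imported in pipeline.py, which the hunk does not show. A sketch of the resulting call, with a hypothetical raw frame:

# XGB path: raw columns go straight in; the mapper inside the model encodes them
raw_df = pd.DataFrame({"city": ["a"], "age": [15]})
prob = ml_pipeline.prob(raw_df)  # ml_pipeline: an instance of this Pipeline class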
 

+ 1 - 3
requirements-analysis.txt

@@ -21,7 +21,5 @@ pypmml==0.9.0
 #dataframe_image==0.1.14
 #thrift-sasl==0.4.3
 #pyhive==0.7.0
-#sklearn2pmml==0.103.3
+#sklearn2pmml==0.65.0
 #sklearn-pandas==2.2.0
-#dill==0.3.4
-