浏览代码

modify: xgb代码优化

yq 4 天之前
父节点
当前提交
e458dc68b7
共有 2 个文件被更改，包括 68 次插入和 55 次删除
  1. +8 -55
      model/model_xgb.py
  2. +60 -0
      model/pipeline_xgb_util.py

+ 8 - 55
model/model_xgb.py

@@ -10,12 +10,10 @@ from os.path import dirname, realpath
 from typing import Dict
 from typing import Dict
 
 
 import joblib
 import joblib
-import numpy
 import numpy as np
 import numpy as np
 import pandas as pd
 import pandas as pd
 import scorecardpy as sc
 import scorecardpy as sc
 import xgboost as xgb
 import xgboost as xgb
-from pandas import DataFrame, Series
 from sklearn.preprocessing import OneHotEncoder
 from sklearn.preprocessing import OneHotEncoder
 from sklearn2pmml import sklearn2pmml, PMMLPipeline
 from sklearn2pmml import sklearn2pmml, PMMLPipeline
 from sklearn2pmml.preprocessing import CutTransformer
 from sklearn2pmml.preprocessing import CutTransformer
@@ -28,65 +26,20 @@ from enums import ResultCodesEnum, ConstantEnum, FileEnum, ContextEnum
 from init import context
 from init import context
 from .model_base import ModelBase
 from .model_base import ModelBase
 from .model_utils import f_stress_test, f_calcu_model_ks, f_get_model_score_bin, f_calcu_model_psi
 from .model_utils import f_stress_test, f_calcu_model_ks, f_get_model_score_bin, f_calcu_model_psi
-
-
-class PMMLPipelineXgb(PMMLPipeline):
-    def __init__(self, steps, ):
-        super().__init__(steps=steps)
-
-    def _filter_column_names(self, X):
-        return (numpy.asarray(X)).astype(str)
-
-    def _get_column_names(self, X):
-        if isinstance(X, DataFrame):
-            return self._filter_column_names(X.columns.values)
-        elif isinstance(X, Series):
-            return self._filter_column_names(X.name)
-        # elif isinstance(X, H2OFrame)
-        elif hasattr(X, "names"):
-            return self._filter_column_names(X.names)
-        else:
-            return None
-
-    def Xtransformer_fit(self, X, y=None):
-        # Collect feature name(s)
-        active_fields = self._get_column_names(X)
-        if active_fields is not None:
-            self.active_fields = active_fields
-        # Collect label name(s)
-        target_fields = self._get_column_names(y)
-        if target_fields is not None:
-            self.target_fields = target_fields
-
-        self.steps = list(self.steps)
-        self._validate_steps()
-
-        for (step_idx, name, transformer) in self._iter(with_final=False, filter_passthrough=False):
-            transformer.fit(X)
-            self.steps[step_idx] = (name, transformer)
-
-    def Xtransform(self, X):
-        Xt = X
-        for name, transform in self.steps[:-1]:
-            if transform is not None:
-                Xt = transform.transform(Xt)
-        return Xt
-
-    def fit(self, X, y=None, **fit_params):
-        fit_params_steps = self._check_fit_params(**fit_params)
-        Xt = self.Xtransform(X)
-        if self._final_estimator != 'passthrough':
-            fit_params_last_step = fit_params_steps[self.steps[-1][0]]
-            self._final_estimator.fit(Xt, y, **fit_params_last_step)
-        return self
+from .pipeline_xgb_util import fit, Xtransform, Xtransformer_fit
 
 
 
 
 class ModelXgb(ModelBase):
 class ModelXgb(ModelBase):
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
+        # 覆写方法
+        PMMLPipeline.Xtransformer_fit = Xtransformer_fit
+        PMMLPipeline.Xtransform = Xtransform
+        PMMLPipeline.fit = fit
+
         # 报告模板
         # 报告模板
         self._template_path = os.path.join(dirname(dirname(realpath(__file__))), "./template/模型开发报告模板_xgb.docx")
         self._template_path = os.path.join(dirname(dirname(realpath(__file__))), "./template/模型开发报告模板_xgb.docx")
-        self.pipeline: PMMLPipelineXgb
+        self.pipeline: PMMLPipeline
         self.model = xgb.XGBClassifier
         self.model = xgb.XGBClassifier
 
 
     def get_report_template_path(self):
     def get_report_template_path(self):
@@ -149,7 +102,7 @@ class ModelXgb(ModelBase):
                 (column, CutTransformer([-np.inf, 10, 20, 30, +np.inf], labels=[1, 2, 3, 4])))
                 (column, CutTransformer([-np.inf, 10, 20, 30, +np.inf], labels=[1, 2, 3, 4])))
         mapper = DataFrameMapper(mapper)
         mapper = DataFrameMapper(mapper)
 
 
-        self.pipeline = PMMLPipelineXgb([("mapper", mapper), ("classifier", self.model)])
+        self.pipeline = PMMLPipeline([("mapper", mapper), ("classifier", self.model)])
         self.pipeline.Xtransformer_fit(data.data, data.data[y_column])
         self.pipeline.Xtransformer_fit(data.data, data.data[y_column])
         self.pipeline.fit(train_data_raw, train_data_raw[y_column],
         self.pipeline.fit(train_data_raw, train_data_raw[y_column],
                           classifier__eval_set=[
                           classifier__eval_set=[

+ 60 - 0
model/pipeline_xgb_util.py

@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2025/4/24
+@desc: 
+"""
+
+import numpy
+from pandas import DataFrame, Series
+
+
+def _filter_column_names(X):
+    return (numpy.asarray(X)).astype(str)
+
+
+def _get_column_names(X):
+    if isinstance(X, DataFrame):
+        return _filter_column_names(X.columns.values)
+    elif isinstance(X, Series):
+        return _filter_column_names(X.name)
+    # elif isinstance(X, H2OFrame)
+    elif hasattr(X, "names"):
+        return _filter_column_names(X.names)
+    else:
+        return None
+
+
+def Xtransformer_fit(self, X, y=None):
+    # Collect feature name(s)
+    active_fields = _get_column_names(X)
+    if active_fields is not None:
+        self.active_fields = active_fields
+    # Collect label name(s)
+    target_fields = _get_column_names(y)
+    if target_fields is not None:
+        self.target_fields = target_fields
+
+    self.steps = list(self.steps)
+    self._validate_steps()
+
+    for (step_idx, name, transformer) in self._iter(with_final=False, filter_passthrough=False):
+        transformer.fit(X)
+        self.steps[step_idx] = (name, transformer)
+
+
+def Xtransform(self, X):
+    Xt = X
+    for name, transform in self.steps[:-1]:
+        if transform is not None:
+            Xt = transform.transform(Xt)
+    return Xt
+
+
+def fit(self, X, y=None, **fit_params):
+    fit_params_steps = self._check_fit_params(**fit_params)
+    Xt = self.Xtransform(X)
+    if self._final_estimator != 'passthrough':
+        fit_params_last_step = fit_params_steps[self.steps[-1][0]]
+        self._final_estimator.fit(Xt, y, **fit_params_last_step)
+    return self