Browse Source

add: xgb模型加载及norm特征加载

yq 17 hours ago
parent
commit
6ef9509dce
3 changed files with 26 additions and 8 deletions
  1. 1 0
      enums/file_enum.py
  2. 24 8
      feature/bin/strategy_norm.py
  3. 1 0
      model/model_xgb.py

+ 1 - 0
enums/file_enum.py

@@ -11,6 +11,7 @@ class FileEnum(Enum):
     MLCFG = "mlcfg.json"
     OLCFG = "olcfg.json"
     FEATURE = "feature.csv"
+    FEATURE_PKL = "feature.pkl"
     CARD = "card.csv"
     CARD_CFG = "card.cfg"
     COEF = "coef.json"

+ 24 - 8
feature/bin/strategy_norm.py

@@ -4,9 +4,10 @@
 @time: 2025/4/3
 @desc: 值标准化,类似于分箱
 """
-
+import os
 from typing import Dict, List
 
+import joblib
 import pandas as pd
 import xgboost as xgb
 from pandas.core.dtypes.common import is_numeric_dtype
@@ -14,7 +15,7 @@ from pandas.core.dtypes.common import is_numeric_dtype
 from commom import GeneralException, f_display_title
 from data import DataExplore
 from entitys import DataSplitEntity, MetricFucResultEntity
-from enums import ResultCodesEnum, ContextEnum
+from enums import ResultCodesEnum, ContextEnum, FileEnum
 from feature.feature_strategy_base import FeatureStrategyBase
 from init import context
 from .utils import f_format_value, OneHot, f_format_bin
@@ -158,14 +159,29 @@ class StrategyNorm(FeatureStrategyBase):
         return df[model_columns]
 
     def feature_save(self, *args, **kwargs):
-        self.x_columns = None
-        self.one_hot_encoder_dict: Dict[str, OneHot] = {}
-        self.points_dict: Dict[str, List[float]] = {}
-
-        pass
+        """Dump the current feature state to the ``FileEnum.FEATURE_PKL`` file.

+        The target path comes from ``self.ml_config.f_get_save_path``; the dumped
+        dict holds ``x_columns``, ``one_hot_encoder_dict`` and ``points_dict``.
+        Extra positional/keyword arguments are accepted but not used.

+        Raises:
+            GeneralException: with ``ResultCodesEnum.NOT_FOUND`` when
+                ``self.x_columns`` is None, i.e. there is nothing to save.
+        """
+        if self.x_columns is None:
+            # Raise the exception instead of only constructing it, so a missing
+            # feature state aborts the save rather than being written to disk.
+            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"feature不存在")

+        path = self.ml_config.f_get_save_path(FileEnum.FEATURE_PKL.value)
+        feature_info = {
+            "x_columns": self.x_columns,
+            "one_hot_encoder_dict": self.one_hot_encoder_dict,
+            "points_dict": self.points_dict,
+        }
+        joblib.dump(feature_info, path)
+        print(f"feature save to【{path}】success. ")
 
     def feature_load(self, path: str, *args, **kwargs):
-        pass
+        """Load ``x_columns``, ``one_hot_encoder_dict`` and ``points_dict`` from disk.

+        Args:
+            path: Either the feature pickle itself or a directory; for a
+                directory the ``FileEnum.FEATURE_PKL`` file name is appended.

+        Raises:
+            GeneralException: with ``ResultCodesEnum.NOT_FOUND`` when the resolved
+                path is not an existing file or does not contain the
+                ``FileEnum.FEATURE_PKL`` file name.
+        """
+        pkl_name = FileEnum.FEATURE_PKL.value
+        if os.path.isdir(path):
+            path = os.path.join(path, pkl_name)

+        is_feature_file = os.path.isfile(path) and pkl_name in path
+        if not is_feature_file:
+            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"特征信息【{FileEnum.FEATURE_PKL.value}】不存在")

+        # NOTE(review): joblib.load unpickles the file; only load trusted files.
+        loaded = joblib.load(path)
+        self.x_columns = loaded["x_columns"]
+        self.one_hot_encoder_dict = loaded["one_hot_encoder_dict"]
+        self.points_dict = loaded["points_dict"]
+        print(f"feature load from【{path}】success.")
 
     def feature_report(self, data: DataSplitEntity, *args, **kwargs) -> Dict[str, MetricFucResultEntity]:
 

+ 1 - 0
model/model_xgb.py

@@ -98,6 +98,7 @@ class ModelXgb(ModelBase):
         path_pmml = self.ml_config.f_get_save_path(FileEnum.PMML.value)
         pipeline = make_pmml_pipeline(self.model)
         sklearn2pmml(pipeline, path_pmml, with_repr=True, java_home=BaseConfig.java_home)
+        print(f"model save to【{path_pmml}】success. ")
 
     def model_load(self, path: str, *args, **kwargs):
         if not os.path.isdir(path):