浏览代码

modify: 代码优化

yq 1 月之前
父节点
当前提交
f2e4eb738c
共有 2 个文件被更改，包括 10 次插入和 3 次删除
  1. 6 1
      feature/bin/strategy_norm.py
  2. 4 2
      feature/bin/utils.py

+ 6 - 1
feature/bin/strategy_norm.py

@@ -70,7 +70,8 @@ class StrategyNorm(FeatureStrategyBase):
                     test_data[x_column] = test_data[x_column].apply(lambda x: f_format_value(points, x))
                     test_data[x_column] = test_data[x_column].apply(lambda x: f_format_value(points, x))
             else:
             else:
                 str_columns.append(x_column)
                 str_columns.append(x_column)
-                one_hot_encoder = OneHot(data.data, x_column)
+                one_hot_encoder = OneHot()
+                one_hot_encoder.fit(data.data, x_column)
                 one_hot_encoder.encoder(train_data)
                 one_hot_encoder.encoder(train_data)
                 one_hot_encoder.encoder(test_data)
                 one_hot_encoder.encoder(test_data)
                 model_columns.extend(one_hot_encoder.columns_onehot)
                 model_columns.extend(one_hot_encoder.columns_onehot)
@@ -157,6 +158,10 @@ class StrategyNorm(FeatureStrategyBase):
         return df[model_columns]
         return df[model_columns]
 
 
     def feature_save(self, *args, **kwargs):
     def feature_save(self, *args, **kwargs):
+        self.x_columns = None
+        self.one_hot_encoder_dict: Dict[str, OneHot] = {}
+        self.points_dict: Dict[str, List[float]] = {}
+
         pass
         pass
 
 
     def feature_load(self, path: str, *args, **kwargs):
     def feature_load(self, path: str, *args, **kwargs):

+ 4 - 2
feature/bin/utils.py

@@ -76,9 +76,11 @@ def f_format_value(points, raw_v):
 
 
 class OneHot():
 class OneHot():
 
 
-    def __init__(self, data: pd.DataFrame, x_column: str):
-        self._x_column = x_column
+    def __init__(self, ):
         self._one_hot_encoder = OneHotEncoder()
         self._one_hot_encoder = OneHotEncoder()
+
+    def fit(self, data: pd.DataFrame, x_column: str):
+        self._x_column = x_column
         self._one_hot_encoder.fit(data[x_column].to_numpy().reshape(-1, 1))
         self._one_hot_encoder.fit(data[x_column].to_numpy().reshape(-1, 1))
         self._columns_onehot = [re.sub(r"[\[\]<]", "", f"{x_column}({i})") for i in
         self._columns_onehot = [re.sub(r"[\[\]<]", "", f"{x_column}({i})") for i in
                                 self._one_hot_encoder.categories_[0]]
                                 self._one_hot_encoder.categories_[0]]