Jelajahi Sumber

add: MetricTrainEntity

yq 5 bulan lalu
induk
melakukan
0fe6f8e038
5 mengubah file dengan 30 tambahan dan 6 penghapusan
  1. 2 1
      entitys/__init__.py
  2. 22 0
      entitys/metric_entity.py
  3. 2 2
      model/model_base.py
  4. 3 2
      model/model_lr.py
  5. 1 1
      trainer/train.py

+ 2 - 1
entitys/__init__.py

@@ -6,8 +6,9 @@
 """
 from .data_feaure_entity import DataFeatureEntity
 from .db_config_entity import DbConfigEntity
+from .metric_entity import MetricTrainEntity
 
-__all__ = ['DataFeatureEntity', 'DbConfigEntity']
+__all__ = ['DataFeatureEntity', 'DbConfigEntity', 'MetricTrainEntity']
 
 if __name__ == "__main__":
     pass

+ 22 - 0
entitys/metric_entity.py

@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/11/1
+@desc: 
+"""
+
+class MetricTrainEntity():
+    def __init__(self, auc: float, ks: float):
+        self._auc = auc
+        self._ks = ks
+
+    @property
+    def auc(self):
+        return self._auc
+
+    @property
+    def ks(self):
+        return self._ks
+
+if __name__ == "__main__":
+    pass

+ 2 - 2
model/model_base.py

@@ -8,13 +8,13 @@ import abc
 
 import pandas as pd
 
-from entitys import DataFeatureEntity
+from entitys import DataFeatureEntity, MetricTrainEntity
 
 
 class ModelBase(metaclass=abc.ABCMeta):
 
     @abc.abstractmethod
-    def train(self, data: DataFeatureEntity, *args, **kwargs):
+    def train(self, data: DataFeatureEntity, *args, **kwargs) -> MetricTrainEntity:
         pass
 
     @abc.abstractmethod

+ 3 - 2
model/model_lr.py

@@ -7,7 +7,7 @@
 import pandas as pd
 from sklearn.linear_model import LogisticRegression
 
-from entitys import DataFeatureEntity
+from entitys import DataFeatureEntity, MetricTrainEntity
 from model_base import ModelBase
 
 
@@ -15,8 +15,9 @@ class ModelLr(ModelBase):
     def __init__(self, ):
         self.lr = LogisticRegression(penalty='l1', C=0.9, solver='saga', n_jobs=-1)
 
-    def train(self, data: DataFeatureEntity, *args, **kwargs):
+    def train(self, data: DataFeatureEntity, *args, **kwargs) -> MetricTrainEntity:
         self.lr.fit(data.get_Xdata(), data.get_Ydata())
+        return MetricTrainEntity(0.7, 0.4)
 
     def predict_prob(self, x: pd.DataFrame, *args, **kwargs):
         return self.lr.predict_proba(x)[:, 1]

+ 1 - 1
trainer/train.py

@@ -13,7 +13,7 @@ class TrainPipeline():
         self.model = model
 
     def train(self, train_data: DataFeatureEntity, test_data: DataFeatureEntity):
-        self.model.train(train_data)
+        metric_train = self.model.train(train_data)
         self.model.predict_prob(test_data.get_Xdata())
 
     def generate_report(self):