Prechádzať zdrojové kódy

add: 模型模块化

yq 5 mesiacov pred
rodič
commit
fb1ac1de59

+ 1 - 1
data/__init__.py

@@ -2,7 +2,7 @@
 """
 @author: yq
 @time: 2024/10/30
-@desc: 
+@desc: 数据加载、加工相关
 """
 from .loader.data_loader_base import DataLoaderBase
 from .loader.data_loader_mysql import DataLoaderMysql

+ 9 - 0
data/insight/__init__.py

@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/11/1
+@desc:  数据统计分析
+"""
+
+if __name__ == "__main__":
+    pass

+ 1 - 1
data/loader/__init__.py

@@ -2,7 +2,7 @@
 """
 @author: yq
 @time: 2024/11/1
-@desc: 
+@desc:  数据加载相关
 """
 
 if __name__ == "__main__":

+ 5 - 8
data/loader/data_loader_mysql.py

@@ -8,25 +8,22 @@ import pandas as pd
 import pymysql
 
 from commom.logger import get_logger
+from entitys import DbConfigEntity
 from .data_loader_base import DataLoaderBase
 
 logger = get_logger()
 
 
 class DataLoaderMysql(DataLoaderBase):
-    def __init__(self, host: str, port: int, user: str, passwd: str, db: str):
-        self.host = host
-        self.port = port
-        self.user = user
-        self.passwd = passwd
-        self.db = db
+    def __init__(self, db_config: DbConfigEntity):
+        self.db_config = db_config
         self.conn = None
 
     def get_connect(self):
         #  TODO 后续改成线程池
         if self.conn == None:
-            self.conn = pymysql.connect(host=self.host, port=self.port, user=self.user, passwd=self.passwd,
-                                        db=self.db)
+            self.conn = pymysql.connect(host=self.db_config.host, port=self.db_config.port, user=self.db_config.user,
+                                        passwd=self.db_config.passwd, db=self.db_config.db)
         return self.conn
 
     def close_connect(self):

+ 9 - 0
data/process/__init__.py

@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/11/1
+@desc:  数据处理
+"""
+
+if __name__ == "__main__":
+    pass

+ 13 - 0
entitys/__init__.py

@@ -0,0 +1,13 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/10/30
+@desc: 数据实体类
+"""
+from .data_feaure_entity import DataFeatureEntity
+from .db_config_entity import DbConfigEntity
+
+__all__ = ['DataFeatureEntity', 'DbConfigEntity']
+
+if __name__ == "__main__":
+    pass

+ 35 - 0
entitys/data_feaure_entity.py

@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/11/1
+@desc: 
+"""
+import pandas as pd
+
+
+class DataFeatureEntity():
+    def __init__(self, data: pd.DataFrame, x_columns: list, y_column: str):
+        self._data = data
+        self._x_columns = x_columns
+        self._y_column = y_column
+
+    @property
+    def data(self):
+        return self._data
+
+    @property
+    def x_columns(self):
+        return self._x_columns
+
+    @property
+    def y_column(self):
+        return self._y_column
+
+    def get_Xdata(self):
+        return self._data[self._x_columns]
+
+    def get_Ydata(self):
+        return self._data[self._y_column]
+
+if __name__ == "__main__":
+    pass

+ 37 - 0
entitys/db_config_entity.py

@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/11/1
+@desc: 数据库配置类
+"""
+
+class DbConfigEntity():
+    def __init__(self, host: str, port: int, user: str, passwd: str, db: str):
+        self._host = host
+        self._port = port
+        self._user = user
+        self._passwd = passwd
+        self._db = db
+
+    @property
+    def host(self):
+        return self._host
+
+    @property
+    def port(self):
+        return self._port
+
+    @property
+    def user(self):
+        return self._user
+
+    @property
+    def passwd(self):
+        return self._passwd
+
+    @property
+    def db(self):
+        return self._db
+
+if __name__ == "__main__":
+    pass

+ 9 - 0
feature/__init__.py

@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/11/1
+@desc: 特征挖掘
+"""
+
+if __name__ == "__main__":
+    pass

+ 4 - 2
test.py → metric_test.py

@@ -5,10 +5,12 @@
 @desc: 
 """
 from data import DataLoaderMysql
-from metric import MetricBySqlGeneral
+from entitys import DbConfigEntity
+from metrics import MetricBySqlGeneral
 
 if __name__ == "__main__":
-    data_loader = DataLoaderMysql(host="101.126.81.2", port=18001, user="root", passwd="Cqrcb2024", db="test")
+    db_config = DbConfigEntity(host="101.126.81.2", port=18001, user="root", passwd="Cqrcb2024", db="test")
+    data_loader = DataLoaderMysql(db_config)
     metric_clzz = MetricBySqlGeneral(metric_name="auc", metric_code="auc")
     metric = metric_clzz.calculate(data_loader, "select * from test.t1")
     print(metric.head(5))

+ 0 - 0
metric/__init__.py → metrics/__init__.py


+ 0 - 0
metric/metric_base.py → metrics/metric_base.py


+ 0 - 0
metric/metric_by_sql_general.py → metrics/metric_by_sql_general.py


+ 12 - 4
model/model_base.py

@@ -4,19 +4,27 @@
 @time: 2024/1/2
 @desc: 模型基类
 """
-import pandas as pd
 import abc
 
+import pandas as pd
+
+from entitys import DataFeatureEntity
+
+
 class ModelBase(metaclass=abc.ABCMeta):
 
     @abc.abstractmethod
-    def train(self, data: pd.DataFrame):
+    def train(self, data: DataFeatureEntity, *args, **kwargs):
+        pass
+
+    @abc.abstractmethod
+    def predict_prob(self, x: pd.DataFrame, *args, **kwargs):
         pass
 
     @abc.abstractmethod
-    def predict_prob(self, x: pd.DataFrame):
+    def predict(self, x: pd.DataFrame, *args, **kwargs):
         pass
 
     @abc.abstractmethod
-    def predict(self, x: pd.DataFrame):
+    def export_model_file(self, ):
         pass

+ 32 - 0
model/model_lr.py

@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/11/1
+@desc: 
+"""
+import pandas as pd
+from sklearn.linear_model import LogisticRegression
+
+from entitys import DataFeatureEntity
+from model_base import ModelBase
+
+
+class ModelLr(ModelBase):
+    def __init__(self, ):
+        self.lr = LogisticRegression(penalty='l1', C=0.9, solver='saga', n_jobs=-1)
+
+    def train(self, data: DataFeatureEntity, *args, **kwargs):
+        self.lr.fit(data.get_Xdata(), data.get_Ydata())
+
+    def predict_prob(self, x: pd.DataFrame, *args, **kwargs):
+        return self.lr.predict_proba(x)[:, 1]
+
+    def predict(self, x: pd.DataFrame, *args, **kwargs):
+        pass
+
+    def export_model_file(self):
+        pass
+
+
+if __name__ == "__main__":
+    pass

+ 0 - 15
model/model_route.py

@@ -1,15 +0,0 @@
-# -*- coding:utf-8 -*-
-"""
-@author: yq
-@time: 2024/1/2
-@desc: 
-"""
-from config.base_config import BaseConfig
-
-model_route_map = {"lr": ModelQwen}
-
-model_clazz = model_route_map.get(BaseConfig.classify_model_name, None)
-
-assert model_clazz is not None, f"模型【{BaseConfig.classify_model_name}】不存在"
-
-model_classify = model_clazz()

+ 4 - 0
trainer/__init__.py

@@ -5,5 +5,9 @@
 @desc: 
 """
 
+from .train import TrainPipeline
+
+__all__ = ['TrainPipeline']
+
 if __name__ == "__main__":
     pass

+ 24 - 0
trainer/train.py

@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/11/1
+@desc: 模型训练管道
+"""
+from entitys import DataFeatureEntity
+from model import ModelBase
+
+
+class TrainPipeline():
+    def __init__(self, model: ModelBase):
+        self.model = model
+
+    def train(self, train_data: DataFeatureEntity, test_data: DataFeatureEntity):
+        self.model.train(train_data)
+        self.model.predict_prob(test_data.data)
+
+    def generate_report(self):
+        pass
+
+
+if __name__ == "__main__":
+    pass