123456789101112131415161718192021222324 |
- # -*- coding: utf-8 -*-
- """
- @author: yq
- @time: 2024/11/1
- @desc: 模型训练管道
- """
- from entitys import DataFeatureEntity
- from model import ModelBase
class TrainPipeline:
    """Model training pipeline.

    Drives a single model: ``train`` fits it on training data and then scores
    the test data's features. ``generate_report`` is a placeholder that does
    nothing yet.
    """

    def __init__(self, model: ModelBase):
        """Store the model this pipeline operates on.

        Args:
            model: Model object; this class calls its ``train`` and
                ``predict_prob`` methods.
        """
        self.model = model
        # Result of the most recent ``predict_prob`` call on test features;
        # stays None until ``train`` has been called.
        self.test_prob = None

    def train(self, train_data: DataFeatureEntity, test_data: DataFeatureEntity):
        """Train the model on ``train_data`` and score ``test_data``.

        Args:
            train_data: Passed unchanged to ``self.model.train``.
            test_data: Its ``get_Xdata()`` output is passed to
                ``self.model.predict_prob``.

        Returns:
            Whatever ``self.model.predict_prob`` returns for the test
            features. The same object is also stored on ``self.test_prob``.
        """
        self.model.train(train_data)
        # Keep the predictions rather than discarding them, so callers (and a
        # future report) can use them without re-running inference.
        self.test_prob = self.model.predict_prob(test_data.get_Xdata())
        return self.test_prob

    def generate_report(self):
        """Placeholder for report generation; currently does nothing."""
        # TODO: implement reporting (presumably based on self.test_prob).
        pass
if __name__ == "__main__":
    # No standalone behaviour yet: this module is meant to be imported, so
    # running it directly intentionally does nothing.
    ...
|