# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/1
@desc: Model training pipeline.
"""
from entitys import DataFeatureEntity
from model import ModelBase


class TrainPipeline:
    """Train a model on one dataset and score it on another.

    The wrapped ``model`` must provide ``train`` and ``predict_prob``;
    these are the only two methods this class calls on it.
    """

    def __init__(self, model: ModelBase):
        """
        Args:
            model: the model to be trained and evaluated.
        """
        self.model = model
        # Results of the most recent ``train`` call; None until it has run.
        self.metric_train = None
        self.test_prob = None

    def train(self, train_data: DataFeatureEntity, test_data: DataFeatureEntity):
        """Train ``self.model`` on ``train_data`` and predict on ``test_data``.

        Args:
            train_data: dataset passed directly to ``self.model.train``.
            test_data: dataset whose ``get_Xdata()`` result is scored with
                ``self.model.predict_prob``.

        Returns:
            Tuple ``(metric_train, test_prob)`` holding whatever
            ``model.train`` and ``model.predict_prob`` returned. Both values
            are also stored on the instance as ``metric_train`` and
            ``test_prob`` so later steps (e.g. a report) can use them.
        """
        metric_train = self.model.train(train_data)
        test_prob = self.model.predict_prob(test_data.get_Xdata())
        # Keep the results instead of discarding them, so the pipeline's
        # output is reachable after training.
        self.metric_train = metric_train
        self.test_prob = test_prob
        return metric_train, test_prob

    def generate_report(self):
        """Generate a report for the training run.

        Not implemented yet: currently a no-op that returns None.
        """
        # TODO: build the report from self.metric_train / self.test_prob.
        pass


if __name__ == "__main__":
    pass