# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/1
@desc: Model training pipeline
"""
from entitys import DataFeatureEntity
from model import ModelBase


class TrainPipeline:
    """Drive a model through a training step followed by a prediction step.

    The pipeline holds a single model object and delegates all work to it.
    """

    def __init__(self, model: ModelBase) -> None:
        """Store the model the pipeline operates on.

        Args:
            model: Object whose ``train`` and ``predict_prob`` methods are
                invoked by :meth:`train`.
        """
        self.model = model

    def train(self, train_data: DataFeatureEntity, test_data: DataFeatureEntity) -> None:
        """Train the model on ``train_data``, then run it on ``test_data``.

        Args:
            train_data: Passed unchanged to ``self.model.train``.
            test_data: Its ``get_Xdata()`` result is passed to
                ``self.model.predict_prob``.

        Returns:
            None. The results of both model calls are discarded here.
        """
        # NOTE(review): the return value of ``model.train`` is presumably a
        # training metric -- confirm against ModelBase. It is not consumed by
        # the pipeline, so it is intentionally not bound to a local name.
        self.model.train(train_data)
        self.model.predict_prob(test_data.get_Xdata())

    def generate_report(self) -> None:
        """Placeholder for report generation; currently does nothing."""


if __name__ == "__main__":
    # No script behavior is defined: running this module directly does nothing.
    pass