train.py 520 B

123456789101112131415161718192021222324
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc: 模型训练管道
  6. """
  7. from entitys import DataFeatureEntity
  8. from model import ModelBase
  9. class TrainPipeline():
  10. def __init__(self, model: ModelBase):
  11. self.model = model
  12. def train(self, train_data: DataFeatureEntity, test_data: DataFeatureEntity):
  13. metric_train = self.model.train(train_data)
  14. self.model.predict_prob(test_data.get_Xdata())
  15. def generate_report(self):
  16. pass
  17. if __name__ == "__main__":
  18. pass