train.py 648 B

1234567891011121314151617181920212223242526
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc: 模型训练管道 (model training pipeline)
  6. """
  7. from entitys import DataPreparedEntity, TrainConfigEntity
  8. from model import f_get_model
  9. class TrainPipeline():
  10. def __init__(self, train_config: TrainConfigEntity):
  11. self._train_config = train_config
  12. model_clazz = f_get_model(self._train_config.model_type)
  13. self.model = model_clazz(self._train_config)
  14. def train(self, data: DataPreparedEntity):
  15. metric_train = self.model.train(data)
  16. print(metric_train)
  17. def generate_report(self, data: DataPreparedEntity):
  18. pass
  19. if __name__ == "__main__":
  20. pass