model_factory.py 596 B

12345678910111213141516171819202122232425
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/12/3
  5. @desc: 模型工厂
  6. """
  7. from typing import Type
  8. from commom import GeneralException
  9. from enums import ModelEnum, ResultCodesEnum
  10. from model import ModelBase
  11. from .model_lr import ModelLr
  12. model_map = {
  13. ModelEnum.LR.value: ModelLr
  14. }
  15. class ModelFactory():
  16. @staticmethod
  17. def get_model(model_type: str) -> Type[ModelBase]:
  18. if model_type not in model_map.keys():
  19. raise GeneralException(ResultCodesEnum.ILLEGAL_PARAMS, message=f"模型【{model_type}】没有实现")
  20. return model_map.get(model_type)