12345678910111213141516171819202122232425 |
- # -*- coding: utf-8 -*-
- """
- @author: yq
- @time: 2024/12/3
- @desc: 模型工厂
- """
- from typing import Type
- from commom import GeneralException
- from enums import ModelEnum, ResultCodesEnum
- from model import ModelBase
- from .model_lr import ModelLr
# Registry of implemented models: maps a model-type string (the ``value`` of a
# ``ModelEnum`` member) to the class implementing that model. The classes are
# presumably ``ModelBase`` subclasses, per ``ModelFactory.get_model``'s return
# annotation. Register each new model with one assignment below.
model_map = dict()
model_map[ModelEnum.LR.value] = ModelLr
class ModelFactory:
    """Factory that resolves a model-type string to its registered model class."""

    @staticmethod
    def get_model(model_type: str) -> Type[ModelBase]:
        """Return the model class registered under ``model_type``.

        Args:
            model_type: Key into the module-level ``model_map``. The registered
                keys are ``ModelEnum`` member values (e.g. ``ModelEnum.LR.value``).

        Returns:
            The class itself (not an instance) registered for ``model_type``.

        Raises:
            GeneralException: With ``ResultCodesEnum.ILLEGAL_PARAMS`` when no
                model is registered for ``model_type``.
        """
        # One lookup instead of a `.keys()` membership test followed by a
        # second `.get()`. Registry values are classes, so `None` means "absent".
        model_cls = model_map.get(model_type)
        if model_cls is None:
            # Message reads: "model [<model_type>] is not implemented".
            raise GeneralException(ResultCodesEnum.ILLEGAL_PARAMS, message=f"模型【{model_type}】没有实现")
        return model_cls
|