Source code for monai.deploy.core.models.factory

# Copyright 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Optional, Tuple, Type, Union

from .model import Model


class ModelFactory:
    """Factory that builds a model object from a model path.

    Candidate model classes are obtained from ``Model.registered_models()``
    and each one is asked, through its ``accept`` method, whether it can
    handle the given path.
    """

    @staticmethod
    def create(path: Union[str, Path], name: str = "", model_type: str = "") -> Optional[Model]:
        """Creates a model object.

        Args:
            path (Union[str, Path]): A path to the model.
            name (str): A name of the model.
            model_type (str): A type of the model. When empty, every registered
                model class is considered.

        Returns:
            A model object built by the first registered model class that accepts
            ``path``. Returns None if no model type/class could be detected for
            the path (e.g. no registered class accepts it).
        """
        detected_type, model_cls = ModelFactory.detect_model_type(path, model_type)
        if not (detected_type and model_cls):
            return None
        # Model classes are constructed from a string path, not a Path object.
        return model_cls(str(path), name)

    @staticmethod
    def detect_model_type(path: Union[str, Path], model_type: str = "") -> Tuple[str, Optional[Type[Model]]]:
        """Detects the model type based on a model path.

        Args:
            path (Union[str, Path]): A path to the model file/folder.
            model_type (str): A model type. When non-empty, only model classes
                whose ``model_type`` attribute equals this value are tried.

        Returns:
            A tuple of the model type string and the model class. Returns
            ``("", None)`` if no registered model class accepts the path.
        """
        path = Path(path)

        for model_cls in Model.registered_models():
            # If a model_type is specified, check if it matches the model type.
            if model_type and model_cls.model_type != model_type:
                continue

            # Keep the detected type in its own variable: rebinding `model_type`
            # here would change (or drop) the caller's filter for the remaining
            # candidates once one class rejects the path.
            accepted, detected_type = model_cls.accept(path)
            if accepted:
                return detected_type, model_cls

        return "", None