Source code for monai.deploy.core.models.torch_model

import os.path
from typing import Any
from zipfile import BadZipFile, ZipFile

from monai.deploy.utils.importutil import optional_import

from .model import Model

torch, _ = optional_import("torch")

class TorchScriptModel(Model):
    """Represents TorchScript model.

    TorchScript serialization format (TorchScript model file) is created by the
    torch.jit.save() method and the serialized model (which usually has .pt or .pth
    extension) is a ZIP archive containing many files (see the PyTorch JIT
    serialization documentation for the archive layout).

    We consider that the model is a torchscript model if its unzipped archive contains
    files named 'data.pkl' and 'constants.pkl', and folders named 'code' and 'data'.

    When predictor property is accessed or the object is called (__call__), the model
    is loaded in `evaluation mode` from the serialized model file (if it is not loaded
    yet) and the model is ready to be used.
    """

    model_type: str = "torch_script"

    @property
    def predictor(self) -> "torch.nn.Module":  # type: ignore
        """Get the model's predictor (torch.nn.Module)

        If the predictor is not loaded, load it from the model file in evaluation mode
        (see the PyTorch documentation for torch.jit.load() and
        torch.jit.ScriptModule.eval()).

        NOTE(review): `self._predictor` and `self.path` are not defined in this class;
        presumably they are initialized by the `Model` base class -- confirm.

        Returns:
            torch.nn.Module: the model's predictor
        """
        if self._predictor is None:
            # Use a device to dynamically remap, depending on the GPU availability.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self._predictor = torch.jit.load(self.path, map_location=device).eval()
        return self._predictor

    @predictor.setter
    def predictor(self, predictor: Any):
        """Replace the model's predictor with the given object."""
        self._predictor = predictor

    def eval(self) -> "TorchScriptModel":
        """Set the model in evaluation mode.

        This is a proxy method for torch.jit.ScriptModule.eval().
        See the PyTorch documentation for torch.jit.ScriptModule.eval().

        Accessing `self.predictor` loads the model first if it is not loaded yet.

        Returns:
            self
        """
        self.predictor.eval()
        return self

    def train(self, mode: bool = True) -> "TorchScriptModel":
        """Set the model in training mode.

        This is a proxy method for torch.jit.ScriptModule.train().
        See the PyTorch documentation for torch.jit.ScriptModule.train().

        Accessing `self.predictor` loads the model first if it is not loaded yet.

        Args:
            mode (bool): whether the model is in training mode

        Returns:
            self
        """
        self.predictor.train(mode)
        return self

    @classmethod
    def accept(cls, path: str):
        """Check whether the file at `path` looks like a TorchScript model archive.

        The file is accepted when it is a ZIP archive in which, below the first path
        component of the entry names, there is at least one entry under 'code/', at
        least one entry under 'data/', and entries named exactly 'constants.pkl' and
        'data.pkl'. Entries without any '/' in their name are ignored.

        Args:
            path (str): path to the candidate model file

        Returns:
            A tuple `(True, cls.model_type)` if the file is accepted, otherwise
            `(False, None)` (including when `path` is not a regular file or is not
            a valid ZIP archive).
        """
        if not os.path.isfile(path):
            return False, None

        has_code = False
        has_data = False
        has_constants_pkl = False
        has_data_pkl = False

        try:
            # The context manager guarantees the archive's file handle is released,
            # even if reading the entry list fails part-way.
            with ZipFile(path) as zip_file:
                for file_name in zip_file.namelist():
                    pivot = file_name.find("/")
                    if pivot == -1:
                        # Not inside the archive's top-level folder; irrelevant.
                        continue
                    # Part of the entry name after the top-level folder,
                    # e.g. "/code/__torch__.py" or "/data.pkl".
                    suffix = file_name[pivot:]
                    if suffix.startswith("/code/"):
                        has_code = True
                    elif suffix.startswith("/data/"):
                        has_data = True
                    elif suffix == "/constants.pkl":
                        has_constants_pkl = True
                    elif suffix == "/data.pkl":
                        has_data_pkl = True
        except BadZipFile:
            return False, None

        if has_code and has_data and has_constants_pkl and has_data_pkl:
            return True, cls.model_type
        return False, None