Source code for monai.deploy.core.models.torch_model

# Copyright 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os.path
from typing import Any
from zipfile import BadZipFile, ZipFile

from monai.deploy.utils.importutil import optional_import

from .model import Model

torch, _ = optional_import("torch")


class TorchScriptModel(Model):
    """Represents a TorchScript model.

    The TorchScript serialization format (TorchScript model file) is created by the
    torch.jit.save() method, and the serialized model (which usually has a .pt or .pth
    extension) is a ZIP archive containing many files.
    (https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md)

    We consider the model to be a TorchScript model if its unzipped archive contains
    files named 'data.pkl' and 'constants.pkl', and folders named 'code' and 'data'.

    When the predictor property is accessed or the object is called (__call__), the model
    is loaded in `evaluation mode` from the serialized model file (if it is not loaded yet)
    and is ready to be used.
    """

    model_type: str = "torch_script"

    @property
    def predictor(self) -> "torch.nn.Module":  # type: ignore
        """Get the model's predictor (torch.nn.Module).

        If the predictor is not loaded, load it from the model file in evaluation mode.
        (https://pytorch.org/docs/stable/generated/torch.jit.ScriptModule.html?highlight=eval#torch.jit.ScriptModule.eval)

        Returns:
            torch.nn.Module: the model's predictor
        """
        if self._predictor is None:
            # Use a device to dynamically remap, depending on the GPU availability.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self._predictor = torch.jit.load(self.path, map_location=device).eval()
        return self._predictor

    @predictor.setter
    def predictor(self, predictor: Any):
        self._predictor = predictor
    def eval(self) -> "TorchScriptModel":
        """Set the model in evaluation mode.

        This is a proxy method for torch.jit.ScriptModule.eval().
        See https://pytorch.org/docs/stable/generated/torch.jit.ScriptModule.html?highlight=eval#torch.jit.ScriptModule.eval

        Returns:
            self
        """
        self.predictor.eval()
        return self
    def train(self, mode: bool = True) -> "TorchScriptModel":
        """Set the model in training mode.

        This is a proxy method for torch.jit.ScriptModule.train().
        See https://pytorch.org/docs/stable/generated/torch.jit.ScriptModule.html?highlight=train#torch.jit.ScriptModule.train

        Args:
            mode (bool): whether the model is in training mode

        Returns:
            self
        """
        self.predictor.train(mode)
        return self
    @classmethod
    def accept(cls, path: str):
        """Check whether the given file is a TorchScript model archive.

        Returns a tuple of (is_accepted, model_type).
        """
        prefix_code = False
        prefix_data = False
        prefix_constants_pkl = False
        prefix_data_pkl = False

        if not os.path.isfile(path):
            return False, None

        try:
            zip_file = ZipFile(path)
            # Look for 'code/', 'data/', 'constants.pkl', and 'data.pkl' entries
            # directly under the archive's top-level folder.
            for data in zip_file.filelist:
                file_name = data.filename
                pivot = file_name.find("/")
                if pivot != -1 and not prefix_code and file_name[pivot:].startswith("/code/"):
                    prefix_code = True
                if pivot != -1 and not prefix_data and file_name[pivot:].startswith("/data/"):
                    prefix_data = True
                if pivot != -1 and not prefix_constants_pkl and file_name[pivot:] == "/constants.pkl":
                    prefix_constants_pkl = True
                if pivot != -1 and not prefix_data_pkl and file_name[pivot:] == "/data.pkl":
                    prefix_data_pkl = True
        except BadZipFile:
            return False, None

        if prefix_code and prefix_data and prefix_constants_pkl and prefix_data_pkl:
            return True, cls.model_type

        return False, None
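
As a quick sanity check of the behavior described in the class docstring, the sketch below creates a minimal scripted module, saves it with torch.jit.save() (producing the ZIP archive with data.pkl, constants.pkl, code/ and data/ entries), and then exercises accept() and the lazily loaded predictor. It assumes a TorchScriptModel(path, name=...) constructor inherited from the base Model class, which is defined outside this file; treat it as an illustrative usage sketch rather than a verbatim part of the module.

import tempfile

import torch

from monai.deploy.core.models.torch_model import TorchScriptModel


class AddOne(torch.nn.Module):
    """A trivial module used only to produce a TorchScript file for the example."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1.0


with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = f"{tmp_dir}/add_one.pt"

    # torch.jit.save() writes the TorchScript ZIP archive described in the class docstring.
    torch.jit.save(torch.jit.script(AddOne()), model_path)

    # accept() inspects the archive contents and reports the model type.
    accepted, model_type = TorchScriptModel.accept(model_path)
    print(accepted, model_type)  # True torch_script

    # Accessing `predictor` lazily loads the scripted module in evaluation mode.
    # The (path, name) constructor arguments are assumed from the base Model class.
    model = TorchScriptModel(model_path, name="add_one")
    result = model.predictor(torch.zeros(3))
    print(result)  # tensor([1., 1., 1.])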