# Source code for monai.inferers.inferer

# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Sequence, Union

import torch
import torch.nn as nn

from monai.inferers.utils import sliding_window_inference
from monai.utils import BlendMode, PytorchPadMode
from monai.visualize import CAM, GradCAM, GradCAMpp

# Public names exported by ``from monai.inferers.inferer import *``.
__all__ = ["Inferer", "SimpleInferer", "SlidingWindowInferer", "SaliencyInferer"]


class Inferer(ABC):
    """
    Abstract base for every model-inference strategy.

    Subclasses customise how a network is applied to its inputs, for instance by
    running it over sliding windows. Example code::

        device = torch.device("cuda:0")
        data = ToTensor()(LoadImage()(filename=img_path)).to(device)
        model = UNet(...).to(device)
        inferer = SlidingWindowInferer(...)

        model.eval()
        with torch.no_grad():
            pred = inferer(inputs=data, network=model)
        ...

    """

    @abstractmethod
    def __call__(
        self,
        inputs: torch.Tensor,
        network: Callable[..., torch.Tensor],
        *args: Any,
        **kwargs: Any,
    ):
        """
        Apply `network` to `inputs`; every concrete inferer must provide this.

        Args:
            inputs: data fed to the model.
            network: the model (or model-like callable) used for inference.
            args: extra positional arguments destined for ``network``.
            kwargs: extra keyword arguments destined for ``network``.

        Raises:
            NotImplementedError: if a subclass reaches this base implementation.

        """
        cls_name = type(self).__name__
        raise NotImplementedError(f"Subclass {cls_name} must implement this method.")
class SimpleInferer(Inferer):
    """
    Plain inference: hands the inputs straight to the network's forward call.

    See the :py:class:`monai.inferers.Inferer` base class for a usage example.

    """

    def __init__(self) -> None:
        Inferer.__init__(self)

    def __call__(
        self,
        inputs: torch.Tensor,
        network: Callable[..., torch.Tensor],
        *args: Any,
        **kwargs: Any,
    ):
        """Run ``network`` once on ``inputs`` and return whatever it produces.

        Args:
            inputs: data to run inference on.
            network: model to execute; any callable works, e.g.
                ``lambda x: my_torch_model(x, additional_config)``.
            args: extra positional arguments forwarded to ``network`` after ``inputs``.
            kwargs: extra keyword arguments forwarded to ``network``.

        """
        outputs = network(inputs, *args, **kwargs)
        return outputs
class SlidingWindowInferer(Inferer):
    """
    Sliding-window model inference that evaluates up to `sw_batch_size` windows per forward pass.

    See the :py:class:`monai.inferers.Inferer` base class for a usage example.

    Args:
        roi_size: spatial size of each window. Non-positive components are replaced by the
            matching spatial size of `inputs`; e.g. `roi_size=(32, -1)` becomes `(32, 64)`
            when the second spatial dimension of the input is `64`.
        sw_batch_size: number of windows evaluated per network call.
        overlap: amount of overlap between neighbouring windows.
        mode: {``"constant"``, ``"gaussian"``}
            How the outputs of overlapping windows are blended. Defaults to ``"constant"``.

            - ``"constant"``: every prediction gets the same weight.
            - ``"gaussian"``: predictions near window edges get less weight.

        sigma_scale: standard-deviation coefficient of the Gaussian window, used when `mode` is
            ``"gaussian"``. Default: 0.125. The effective sigma is ``sigma_scale`` * ``dim_size``.
            A sequence of floats gives one coefficient per spatial dimension.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode used when ``roi_size`` is larger than the inputs. Defaults to ``"constant"``.
            See also: https://pytorch.org/docs/stable/nn.functional.html#pad
        cval: fill value for ``"constant"`` padding. Default: 0
        sw_device: device holding the window data. Defaults to the device (and memory) of `inputs`.
            It should normally match the device that `network` lives on.
        device: device for the stitched output prediction. Defaults to the device (and memory) of
            `inputs`. Setting e.g. ``device=torch.device('cpu')`` lowers GPU memory use and makes it
            independent of the size of `inputs` and `roi_size`. The output is placed on `device`.

    Note:
        ``sw_batch_size`` is the maximum number of windows per network inference iteration,
        not the batch size of the inputs.

    """

    def __init__(
        self,
        roi_size: Union[Sequence[int], int],
        sw_batch_size: int = 1,
        overlap: float = 0.25,
        mode: Union[BlendMode, str] = BlendMode.CONSTANT,
        sigma_scale: Union[Sequence[float], float] = 0.125,
        padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
        cval: float = 0.0,
        sw_device: Union[torch.device, str, None] = None,
        device: Union[torch.device, str, None] = None,
    ) -> None:
        Inferer.__init__(self)
        # window geometry and batching
        self.roi_size = roi_size
        self.sw_batch_size = sw_batch_size
        self.overlap = overlap
        # blending: `mode` may arrive as a str or a BlendMode; it is normalised to BlendMode here
        self.mode: BlendMode = BlendMode(mode)
        self.sigma_scale = sigma_scale
        # padding settings are stored exactly as given
        self.padding_mode = padding_mode
        self.cval = cval
        # device placement for the windows and for the stitched output
        self.sw_device = sw_device
        self.device = device

    def __call__(
        self,
        inputs: torch.Tensor,
        network: Callable[..., torch.Tensor],
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor:
        """
        Run sliding-window inference of ``network`` over ``inputs``.

        Args:
            inputs: model input data for inference.
            network: target model to execute inference.
                supports callables such as ``lambda x: my_torch_model(x, additional_config)``
            args: optional args to be passed to ``network``.
            kwargs: optional keyword args to be passed to ``network``.

        """
        # The stored settings are passed positionally, in this exact order, with the caller's
        # extra ``*args`` appended after them; passing them by keyword could collide with those
        # extra positionals.
        window_settings = (
            self.overlap,
            self.mode,
            self.sigma_scale,
            self.padding_mode,
            self.cval,
            self.sw_device,
            self.device,
        )
        return sliding_window_inference(
            inputs, self.roi_size, self.sw_batch_size, network, *window_settings, *args, **kwargs
        )
class SaliencyInferer(Inferer):
    """
    Inference that returns class activation maps instead of raw network outputs.

    Args:
        cam_name: CAM method to use, one of "CAM", "GradCAM" or "GradCAMpp" (case-insensitive).
        target_layers: name of the model layer to generate the feature map.
        class_idx: index of the class to be visualized. if None, default to argmax(logits).
        args: other optional args to be passed to the `__init__` of cam.
        kwargs: other optional keyword args to be passed to `__init__` of cam.

    Raises:
        ValueError: when `cam_name` is not one of the supported methods.

    """

    def __init__(
        self,
        cam_name: str,
        target_layers: str,
        class_idx: Optional[int] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        Inferer.__init__(self)
        # normalise once; both validation and the stored name use the lower-cased form
        name = cam_name.lower()
        if name not in ("cam", "gradcam", "gradcampp"):
            raise ValueError("cam_name should be: 'CAM', 'GradCAM' or 'GradCAMpp'.")
        self.cam_name = name
        self.target_layers = target_layers
        self.class_idx = class_idx
        # extra constructor arguments, forwarded to the CAM class on every call
        self.args = args
        self.kwargs = kwargs

    def __call__(  # type: ignore
        self,
        inputs: torch.Tensor,
        network: nn.Module,
        *args: Any,
        **kwargs: Any,
    ):
        """Unified callable function API of Inferers.

        Args:
            inputs: model input data for inference.
            network: target model to execute inference. It must be an ``nn.Module`` that contains
                the layer named by `target_layers`; unlike the other inferers, arbitrary callables
                such as lambdas are not accepted.
            args: other optional args to be passed to the `__call__` of cam.
            kwargs: other optional keyword args to be passed to `__call__` of cam.

        """
        # NOTE(review): a fresh CAM wrapper is built on every call; if the CAM classes register
        # hooks on `network` at construction, repeated calls may accumulate them — confirm.
        cam_types = {"cam": CAM, "gradcam": GradCAM}
        # any other name (only "gradcampp" passes validation) falls back to GradCAMpp
        cam_cls = cam_types.get(self.cam_name, GradCAMpp)
        cam: Union[CAM, GradCAM, GradCAMpp] = cam_cls(network, self.target_layers, *self.args, **self.kwargs)
        return cam(inputs, self.class_idx, *args, **kwargs)