# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Sequence, Union
import torch
from monai.inferers.utils import sliding_window_inference
from monai.utils import BlendMode, PytorchPadMode
class Inferer(ABC):
    """
    A base class for model inference.

    Extend this class to support operations during inference, e.g. a sliding window method.
    Concrete subclasses must override ``__call__``; because it is declared abstract,
    this class cannot be instantiated directly.
    """

    @abstractmethod
    def __call__(self, inputs: torch.Tensor, network: torch.nn.Module):
        """
        Run inference on `inputs` with the `network` model.

        Args:
            inputs: input of the model inference.
            network: model for inference.

        Raises:
            NotImplementedError: When the subclass does not override this method.

        """
        # Reached only when a subclass explicitly delegates to this base implementation.
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
class SimpleInferer(Inferer):
    """
    SimpleInferer is the normal inference method that runs the model's forward() directly.
    """

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, inputs: torch.Tensor, network: torch.nn.Module):
        """Unified callable function API of Inferers.

        Args:
            inputs: model input data for inference.
            network: target model to execute inference.

        Returns:
            The value produced by calling ``network(inputs)``, passed through unchanged.

        """
        return network(inputs)
class SlidingWindowInferer(Inferer):
    """
    Sliding window method for model inference,
    with `sw_batch_size` windows for every model.forward().

    Args:
        roi_size: the window size to execute SlidingWindow evaluation.
            If it has non-positive components, the corresponding `inputs` size will be used.
            For example, `roi_size=(32, -1)` will be adapted to `(32, 64)` if the second
            spatial dimension size of `inputs` is `64`.
        sw_batch_size: the batch size to run window slices.
        overlap: Amount of overlap between scans.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant"``: gives equal weight to all predictions.
            - ``"gaussian"``: gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
            See also: https://pytorch.org/docs/stable/nn.functional.html#pad
        cval: fill value for 'constant' padding mode. Default: 0

    Note:
        ``sw_batch_size`` denotes the max number of windows per network inference iteration,
        not the batch size of inputs.

    """

    def __init__(
        self,
        roi_size: Union[Sequence[int], int],
        sw_batch_size: int = 1,
        overlap: float = 0.25,
        mode: Union[BlendMode, str] = BlendMode.CONSTANT,
        sigma_scale: Union[Sequence[float], float] = 0.125,
        padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
        cval: float = 0.0,
    ) -> None:
        super().__init__()
        self.roi_size = roi_size
        self.sw_batch_size = sw_batch_size
        self.overlap = overlap
        # Normalise eagerly so `mode` is always stored as a BlendMode member.
        self.mode: BlendMode = BlendMode(mode)
        self.sigma_scale = sigma_scale
        # Stored exactly as given (enum member or string) and forwarded untouched.
        self.padding_mode = padding_mode
        self.cval = cval

    def __call__(self, inputs: torch.Tensor, network: torch.nn.Module) -> torch.Tensor:
        """
        Unified callable function API of Inferers.

        Args:
            inputs: model input data for inference.
            network: target model to execute inference.

        Returns:
            The result of ``sliding_window_inference`` run with ``network`` as the
            predictor and the settings stored on this instance.

        """
        return sliding_window_inference(
            inputs=inputs,
            roi_size=self.roi_size,
            sw_batch_size=self.sw_batch_size,
            predictor=network,
            overlap=self.overlap,
            mode=self.mode,
            sigma_scale=self.sigma_scale,
            padding_mode=self.padding_mode,
            cval=self.cval,
        )