Inference methods

Inferers

class monai.inferers.Inferer

A base class for model inference. Extend this class to support operations during inference, e.g. a sliding window method.

Example code:

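import torch
from monai.inferers import SlidingWindowInferer
from monai.networks.nets import UNet
from monai.transforms import Compose, EnsureChannelFirst, LoadImage, ToTensor

device = torch.device("cuda:0")
# load the image first, then make it channel-first and convert it to a tensor
transform = Compose([LoadImage(image_only=True), EnsureChannelFirst(), ToTensor()])
data = transform(img_path).unsqueeze(0).to(device)  # add the batch dimension
model = UNet(...).to(device)
inferer = SlidingWindowInferer(...)

model.eval()
with torch.no_grad():
    pred = inferer(inputs=data, network=model)
...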
abstract __call__(inputs, network, *args, **kwargs)

Run inference on inputs with the network model.

Parameters:
  • inputs (Tensor) – input of the model inference.

  • network (Callable) – model for inference.

  • args (Any) – optional args to be passed to network.

  • kwargs (Any) – optional keyword args to be passed to network.

Raises:

NotImplementedError – When the subclass does not override this method.

Return type:

Any
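The contract above can be shown without MONAI or PyTorch. The sketch below assumes plain Python lists stand in for tensors and plain functions stand in for networks. ToyInferer, ToySimpleInferer, and ToyFlipInferer are hypothetical stand-ins, not MONAI classes. The last one shows the kind of extra operation a subclass might add during inference. Here that operation is averaging over a flipped copy of the input.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, List


class ToyInferer(ABC):
    """Hypothetical stand-in for the Inferer contract (not MONAI code)."""

    @abstractmethod
    def __call__(self, inputs: Any, network: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")


class ToySimpleInferer(ToyInferer):
    """Runs the network directly on the inputs, in the spirit of SimpleInferer."""

    def __call__(self, inputs, network, *args, **kwargs):
        return network(inputs, *args, **kwargs)


class ToyFlipInferer(ToyInferer):
    """Adds an operation during inference: averages the predictions for the
    input and for its reversed copy (a simple test-time augmentation)."""

    def __call__(self, inputs: List[float], network, *args, **kwargs) -> List[float]:
        pred = network(inputs, *args, **kwargs)
        # predict on the reversed input, then reverse the prediction back
        pred_flipped = network(inputs[::-1], *args, **kwargs)[::-1]
        return [(a + b) / 2 for a, b in zip(pred, pred_flipped)]


def scale(xs: List[float], factor: float = 2.0) -> List[float]:
    """A toy pointwise 'network'; extra kwargs are forwarded by the inferers."""
    return [factor * x for x in xs]


print(ToySimpleInferer()([1, 2, 3], scale))  # [2.0, 4.0, 6.0]
print(ToyFlipInferer()([1, 2, 3], scale, factor=3))  # [3.0, 6.0, 9.0]
```

Because the base class is an ABC, a subclass that forgets to override __call__ cannot be instantiated at all.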

PatchInferer

class monai.inferers.PatchInferer(splitter=None, merger_cls=<class 'monai.inferers.merger.AvgMerger'>, batch_size=1, preprocessing=None, postprocessing=None, output_keys=None, match_spatial_shape=True, buffer_size=0, **merger_kwargs)

Inference on patches instead of the whole image, based on a Splitter and a Merger. It splits the input image into patches, runs the network on them, and then merges the resulting patches.

Parameters:
  • splitter – a Splitter object that splits the inputs into patches. Defaults to None. If not provided or None, the inputs are considered to be already split into patches. In this case, the merged_shape and the optional cropped_shape of the output cannot be inferred and should be explicitly provided.

  • merger_cls – a Merger subclass that can be instantiated to merge patch outputs. It can also be a string that matches the name of a class inherited from the Merger class. Defaults to AvgMerger.

  • batch_size – batch size for patches. If the input tensor is already batched [BxCxWxH], this adds additional batching [(Bp*B)xCxWpxHp] for inference on patches. Defaults to 1.

  • preprocessing – a callable that processes patches before they are fed to the network. Defaults to None.

  • postprocessing – a callable that processes the output of the network. Defaults to None.

  • output_keys – if the network output is a dictionary, this defines the keys of the output dictionary to be used for merging. Defaults to None, where all the keys are used.

  • match_spatial_shape – whether to crop the output to match the input shape. Defaults to True.

  • buffer_size – number of patches to be held in the buffer with a separate thread for batch sampling. Defaults to 0.

  • merger_kwargs – arguments to be passed to merger_cls for instantiation. merged_shape is calculated automatically based on the input shape and the output patch shape unless it is passed here.

__call__(inputs, network, *args, **kwargs)
Parameters:
  • inputs – input data for inference: a torch.Tensor representing an image or a batch of images. However, if the data is already split, it can be fed as a list of (patch, location) tuples, or as a MetaTensor that has metadata for PatchKeys.LOCATION. In both cases, no splitter should be provided.

  • network – target model to execute inference. Supports callables such as lambda x: my_torch_model(x, additional_config).

  • args – optional args to be passed to network.

  • kwargs – optional keyword args to be passed to network.
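The split, batch, run, and merge flow described above can be sketched in plain Python on a 1D signal. This is a minimal illustration only. It assumes non-overlapping patches, zero padding at the end, and a list-based "network" that takes a batch of patches. split_1d and patch_infer are hypothetical helpers. The real PatchInferer works on tensors through a Splitter and a Merger.

```python
from typing import Callable, List, Sequence, Tuple

Patch = Tuple[List[float], int]  # (patch values, start location)


def split_1d(signal: Sequence[float], patch: int, pad_value: float = 0.0) -> List[Patch]:
    """Hypothetical splitter: non-overlapping patches, padding the tail with pad_value."""
    padded_len = -(-len(signal) // patch) * patch  # round up to a multiple of patch
    padded = list(signal) + [pad_value] * (padded_len - len(signal))
    return [(padded[i : i + patch], i) for i in range(0, padded_len, patch)]


def patch_infer(
    signal: Sequence[float],
    patch: int,
    network: Callable[[List[List[float]]], List[List[float]]],
    batch_size: int = 1,
) -> List[float]:
    patches = split_1d(signal, patch)
    merged = [0.0] * (len(patches) * patch)  # merged shape covers the padded input
    for b in range(0, len(patches), batch_size):
        batch = patches[b : b + batch_size]
        outputs = network([p for p, _ in batch])  # one network call per batch
        for out, (_, loc) in zip(outputs, batch):
            merged[loc : loc + patch] = out  # no overlap, so copying is enough
    # crop back to the input length, as match_spatial_shape=True does
    return merged[: len(signal)]


calls: List[int] = []


def double_net(batch: List[List[float]]) -> List[List[float]]:
    calls.append(len(batch))  # record the batch size of each call
    return [[2 * v for v in p] for p in batch]


print(patch_infer([1, 2, 3, 4, 5], patch=2, network=double_net, batch_size=2))  # [2, 4, 6, 8, 10]
print(calls)  # [2, 1]
```

With three patches and batch_size=2, the network is called twice: once with two patches and once with the remaining one.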

SimpleInferer

class monai.inferers.SimpleInferer

SimpleInferer is the normal inference method that runs the model's forward() directly. A usage example can be found in the monai.inferers.Inferer base class.

__call__(inputs, network, *args, **kwargs)

Unified callable function API of Inferers.

Parameters:
  • inputs (Tensor) – model input data for inference.

  • network (Callable[…, Tensor]) – target model to execute inference. Supports callables such as lambda x: my_torch_model(x, additional_config).

  • args (Any) – optional args to be passed to network.

  • kwargs (Any) – optional keyword args to be passed to network.

Return type:

Tensor

SlidingWindowInferer

class monai.inferers.SlidingWindowInferer(roi_size, sw_batch_size=1, overlap=0.25, mode=BlendMode.CONSTANT, sigma_scale=0.125, padding_mode=PytorchPadMode.CONSTANT, cval=0.0, sw_device=None, device=None, progress=False, cache_roi_weight_map=False, cpu_thresh=None, buffer_steps=None, buffer_dim=-1, with_coord=False)

Sliding window method for model inference, with sw_batch_size windows for every model.forward(). A usage example can be found in the monai.inferers.Inferer base class.

Parameters:
  • roi_size – the window size to execute SlidingWindow evaluation. If any of its components is non-positive, the corresponding spatial size of the input is used instead. For example, roi_size=(32, -1) will be adapted to (32, 64) if the second spatial dimension size of the input is 64.

  • sw_batch_size – the batch size to run window slices.

  • overlap – Amount of overlap between scans along each spatial dimension, defaults to 0.25.

  • mode

    {"constant", "gaussian"} How to blend output of overlapping windows. Defaults to "constant".

    • "constant": gives equal weight to all predictions.

    • "gaussian": gives less weight to predictions on edges of windows.

  • sigma_scale – the standard deviation coefficient of the Gaussian window when mode is "gaussian". Default: 0.125. Actual window sigma is sigma_scale * dim_size. When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding spatial dimensions.

  • padding_mode – {"constant", "reflect", "replicate", "circular"} Padding mode when roi_size is larger than inputs. Defaults to "constant". See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html

  • cval – fill value for ‘constant’ padding mode. Default: 0

  • sw_device – device for the window data. By default the device (and accordingly the memory) of the inputs is used. Normally sw_device should be consistent with the device where predictor is defined.

  • device – device for the stitched output prediction. By default the device (and accordingly the memory) of the inputs is used. For example, if set to device=torch.device(‘cpu’), GPU memory consumption is lower and independent of the inputs and roi_size. The output is placed on this device.

  • progress – whether to print a tqdm progress bar.

  • cache_roi_weight_map – whether to precompute the ROI weight map.

  • cpu_thresh – when provided, dynamically switch to stitching on CPU (to save GPU memory) when the input image volume is larger than this threshold (in pixels/voxels); otherwise, stitch on device. Thus, the output may end up on either CPU or GPU.

  • buffer_steps – the number of sliding window iterations along buffer_dim to be buffered on sw_device before writing to device. (Typically, sw_device is cuda and device is cpu.) Defaults to None (no buffering). Along buffer_dim, when the spatial size is divisible by buffer_steps*roi_size (i.e. the buffers do not overlap), non_blocking copy may be enabled automatically for efficiency.

  • buffer_dim – the spatial dimension along which the buffers are created. 0 indicates the first spatial dimension. Default is -1, the last spatial dimension.

  • with_coord – whether to pass the window coordinates to network. Defaults to False. If True, the network’s 2nd input argument should accept the window coordinates.

Note

sw_batch_size denotes the max number of windows per network inference iteration, not the batch size of inputs.
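To make roi_size, overlap, and sw_batch_size concrete, here is a minimal 1D sketch of how window start positions can be laid out and combined across dimensions. adapt_roi and window_starts are hypothetical helpers, not MONAI functions. Two rules are assumptions of this sketch: the step int(roi * (1 - overlap)), and shifting the last window back so it ends at the border. In N-D, the windows are the Cartesian product of the per-dimension starts. They are then fed to the network sw_batch_size at a time.

```python
import math
from itertools import product
from typing import List, Sequence, Tuple


def adapt_roi(roi_size: Sequence[int], image_size: Sequence[int]) -> Tuple[int, ...]:
    """Replace non-positive roi components with the corresponding image size."""
    return tuple(r if r > 0 else s for r, s in zip(roi_size, image_size))


def window_starts(image_len: int, roi: int, overlap: float) -> List[int]:
    """Start index of every window along one spatial dimension (sketch)."""
    if roi >= image_len:
        return [0]  # one window covers the (possibly padded) dimension
    interval = max(int(roi * (1 - overlap)), 1)
    starts = []
    start = 0
    while True:
        # shift the last window back so it ends exactly at the border
        starts.append(min(start, image_len - roi))
        if start + roi >= image_len:
            return starts
        start += interval


image_size = (100, 64)
roi = adapt_roi((64, -1), image_size)  # (64, 64)
windows = list(product(*(window_starts(s, r, 0.25) for s, r in zip(image_size, roi))))
print(windows)  # [(0, 0), (36, 0)]
sw_batch_size = 4
print(math.ceil(len(windows) / sw_batch_size))  # number of network calls: 1
```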

__call__(inputs, network, *args, **kwargs)
Parameters:
  • inputs – model input data for inference.

  • network – target model to execute inference. Supports callables such as lambda x: my_torch_model(x, additional_config).

  • args – optional args to be passed to network.

  • kwargs – optional keyword args to be passed to network.

SlidingWindowInfererAdapt

class monai.inferers.SlidingWindowInfererAdapt(roi_size, sw_batch_size=1, overlap=0.25, mode=BlendMode.CONSTANT, sigma_scale=0.125, padding_mode=PytorchPadMode.CONSTANT, cval=0.0, sw_device=None, device=None, progress=False, cache_roi_weight_map=False, cpu_thresh=None, buffer_steps=None, buffer_dim=-1, with_coord=False)

SlidingWindowInfererAdapt extends SlidingWindowInferer to automatically switch to buffered and then to CPU stitching when the GPU runs out of memory (OOM). It also records the size of such large images, so that CPU stitching is tried automatically for the next large image of a similar size. If the stitching ‘device’ input parameter is provided, automatic adaptation won’t be attempted; keep the default device=None for adaptive behavior. Note: the output might be on the CPU (even if the input was on the GPU) if GPU memory was not sufficient.
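The fallback behavior described above follows a common pattern. The fastest strategy is tried first. On an out-of-memory error, a cheaper strategy is tried, and the size that forced the last resort is remembered. Below is a hypothetical plain-Python sketch of that pattern, not MONAI's implementation. MemoryError stands in for a CUDA OOM error, the strategy names are made up, and run is a caller-supplied function.

```python
from typing import Callable, Optional, Sequence


class ToyAdaptiveStitcher:
    """Hypothetical sketch of OOM-driven fallback (not MONAI code)."""

    def __init__(self, strategies: Sequence[str] = ("gpu", "gpu-buffered", "cpu")):
        self.strategies = tuple(strategies)
        self.cpu_thresh: Optional[int] = None  # learned size threshold for going straight to CPU

    def __call__(self, size: int, run: Callable[[int, str], str]) -> str:
        if self.cpu_thresh is not None and size > self.cpu_thresh:
            return run(size, self.strategies[-1])  # skip the attempts known to fail
        last_error: Optional[MemoryError] = None
        for i, strategy in enumerate(self.strategies):
            try:
                result = run(size, strategy)
            except MemoryError as err:  # stand-in for a GPU out-of-memory error
                last_error = err
                continue
            if i == len(self.strategies) - 1:
                # only the last (CPU) strategy succeeded: remember this image size
                self.cpu_thresh = size - 1
            return result
        assert last_error is not None
        raise last_error


def fake_run(size: int, strategy: str) -> str:
    """Toy workload: each strategy can handle images up to a size limit."""
    limits = {"gpu": 500, "gpu-buffered": 1000, "cpu": 10**9}
    if size > limits[strategy]:
        raise MemoryError(strategy)
    return strategy


stitcher = ToyAdaptiveStitcher()
print(stitcher(400, fake_run), stitcher(800, fake_run), stitcher(2000, fake_run))  # gpu gpu-buffered cpu
print(stitcher.cpu_thresh)  # 1999
```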

__call__(inputs, network, *args, **kwargs)
Parameters:
  • inputs – model input data for inference.

  • network – target model to execute inference. Supports callables such as lambda x: my_torch_model(x, additional_config).

  • args – optional args to be passed to network.

  • kwargs – optional keyword args to be passed to network.

SaliencyInferer

class monai.inferers.SaliencyInferer(cam_name, target_layers, class_idx=None, *args, **kwargs)

SaliencyInferer runs inference that produces class activation maps (CAMs) for the inputs instead of the raw network outputs.

Parameters:
  • cam_name – expected CAM method name, should be: “CAM”, “GradCAM” or “GradCAMpp”.

  • target_layers – name of the model layer to generate the feature map.

  • class_idx – index of the class to be visualized. If None, defaults to argmax(logits).

  • args – other optional args to be passed to the __init__ of cam.

  • kwargs – other optional keyword args to be passed to __init__ of cam.

__call__(inputs, network, *args, **kwargs)

Unified callable function API of Inferers.

Parameters:
  • inputs (Tensor) – model input data for inference.

  • network (Module) – target model to execute inference. Supports callables such as lambda x: my_torch_model(x, additional_config).

  • args (Any) – other optional args to be passed to the __call__ of cam.

  • kwargs (Any) – other optional keyword args to be passed to __call__ of cam.

SliceInferer

class monai.inferers.SliceInferer(spatial_dim=0, *args, **kwargs)

SliceInferer extends SlidingWindowInferer to provide slice-by-slice (2D) inference when given a 3D volume. A typical use case is a 2D model (like a 2D segmentation UNet) operating on the slices of a 3D volume, where the output is a 3D volume that aggregates the 2D slice outputs. Example:

from monai.inferers import SliceInferer

# sliding over the `spatial_dim`
inferer = SliceInferer(roi_size=(64, 256), sw_batch_size=1, spatial_dim=1)
output = inferer(input_volume, net)

Parameters:
  • spatial_dim (int) – spatial dimension over which the slice-by-slice inference runs on the 3D volume. For example, 0 could slide over axial slices, 1 over coronal slices, and 2 over sagittal slices.

  • args (Any) – other optional args to be passed to the __init__ of base class SlidingWindowInferer.

  • kwargs (Any) – other optional keyword args to be passed to __init__ of base class SlidingWindowInferer.

Note

roi_size in SliceInferer is expected to be a 2D tuple when a 3D volume is provided. This allows sliding across slices along the 3D volume using a selected spatial_dim.
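Conceptually, slice-by-slice inference applies a 2D function to every slice along spatial_dim. Each result is written back into the same slice of the output volume. This can be pictured as sliding a window that is one voxel thick along spatial_dim. The plain-Python sketch below illustrates only that idea. It assumes a 3D volume stored as nested lists, with no batch or channel dimensions. roi_3d, slice_by_slice, and the toy 2D callable are hypothetical and do not reproduce how SliceInferer is actually implemented.

```python
from typing import Callable, List, Sequence, Tuple

Grid2D = List[List[float]]
Volume = List[List[List[float]]]


def roi_3d(roi_2d: Sequence[int], spatial_dim: int) -> Tuple[int, ...]:
    """View a 2D roi as a 3D window that is one voxel thick along spatial_dim."""
    roi = tuple(roi_2d)
    return roi[:spatial_dim] + (1,) + roi[spatial_dim:]


def slice_by_slice(volume: Volume, fn2d: Callable[[Grid2D], Grid2D], spatial_dim: int = 0) -> Volume:
    """Apply a 2D function to every slice along spatial_dim and reassemble a 3D output."""
    shape = (len(volume), len(volume[0]), len(volume[0][0]))
    a_dim, b_dim = [d for d in range(3) if d != spatial_dim]  # the two in-plane dimensions
    out: Volume = [[[0.0] * shape[2] for _ in range(shape[1])] for _ in range(shape[0])]

    def index(s: int, a: int, b: int) -> Tuple[int, int, int]:
        idx = [0, 0, 0]
        idx[spatial_dim], idx[a_dim], idx[b_dim] = s, a, b
        return idx[0], idx[1], idx[2]

    for s in range(shape[spatial_dim]):
        # extract the 2D slice at position s along spatial_dim
        plane = []
        for a in range(shape[a_dim]):
            row = []
            for b in range(shape[b_dim]):
                i, j, k = index(s, a, b)
                row.append(volume[i][j][k])
            plane.append(row)
        result = fn2d(plane)  # the "2D model"
        # write the 2D output back into the same slice of the 3D output
        for a in range(shape[a_dim]):
            for b in range(shape[b_dim]):
                i, j, k = index(s, a, b)
                out[i][j][k] = result[a][b]
    return out


def slice_sum(plane: Grid2D) -> Grid2D:
    """Toy '2D model': replaces every pixel with the sum of its slice."""
    total = sum(map(sum, plane))
    return [[total] * len(plane[0]) for _ in plane]


vol = [[[4 * i + 2 * j + k for k in range(2)] for j in range(2)] for i in range(2)]
print(slice_by_slice(vol, slice_sum, spatial_dim=0))  # [[[6, 6], [6, 6]], [[22, 22], [22, 22]]]
print(roi_3d((64, 256), spatial_dim=1))  # (64, 1, 256)
```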

__call__(inputs, network, *args, **kwargs)
Parameters:
  • inputs – 3D input for inference

  • network – 2D model to execute inference on slices in the 3D input

  • args – optional args to be passed to network.

  • kwargs – optional keyword args to be passed to network.

network_wrapper(network, x, *args, **kwargs)

Wrapper handles inference for 2D models over 3D volume inputs.

Splitters

class monai.inferers.Splitter(patch_size, device=None)

A base class for splitting the inputs into an iterable of (patch, location) tuples. Extend this class to support operations for PatchInferer, e.g. SlidingWindowSplitter.

Parameters:
  • patch_size – the size of patches to be generated.

  • device – the device where the patches are generated.

abstract __call__(inputs)

Split the input image (or batch of images) into patches and return pairs of (patch, location), where location is the coordinate of the top-left [front] corner of a patch.

Parameters:

inputs (Any) – either a tensor of shape BCHW[D], representing a batch of images, or a filename (str) or list of filenames to the image(s).

Raises:

NotImplementedError – When the subclass does not override this method.

Return type:

Iterable[tuple[Tensor, Sequence[int]]]

abstract get_input_shape(inputs)

Return the input spatial shape.

Parameters:

inputs (Any) – either a tensor of shape BCHW[D], representing a batch of images, or a filename (str) or list of filenames to the image(s).

Raises:

NotImplementedError – When the subclass does not override this method.

Return type:

tuple

abstract get_padded_shape(inputs)

Return the actual spatial shape covered by the output split patches. For instance, if the input image is padded, the actual spatial shape will be enlarged and not the same as input spatial shape.

Parameters:

inputs (Any) – either a tensor of shape BCHW[D], representing a batch of images, or a filename (str) or list of filenames to the image(s).

Raises:

NotImplementedError – When the subclass does not override this method.

Return type:

tuple

SlidingWindowSplitter

class monai.inferers.SlidingWindowSplitter(patch_size, overlap=0.0, offset=0, filter_fn=None, pad_mode='constant', pad_value=0, device=None)

Splits the input into patches with sliding window strategy and a possible overlap. It also allows offsetting the starting position and filtering the patches.

Parameters:
  • patch_size – the size of the patches to be generated.

  • offset – the amount of offset for the patches with respect to the original input. Defaults to 0.

  • overlap – the amount of overlap between patches in each dimension. It can be either a float in the range of [0.0, 1.0) that defines relative overlap to the patch size, or it can be a non-negative int that defines number of pixels for overlap. Defaults to 0.0.

  • filter_fn – a callable to filter patches. It should accept exactly two parameters (patch, location) and return True for a patch to keep. Defaults to no filtering.

  • pad_mode – a string that defines the mode for torch.nn.functional.pad. The acceptable values are “constant”, “reflect”, “replicate”, “circular” or None. Defaults to “constant”. If None, no padding is applied, so patches crossing the border of the image are dropped (either when the offset is negative or when the image size is not divisible by the patch_size).

  • pad_value – the value for “constant” padding. Defaults to 0.

  • device – the device where the patches are generated. Defaults to the device of inputs.

Note

When a scalar value is provided for patch_size, offset, or overlap, it is broadcast to all the spatial dimensions.
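Several of the behaviors described above can be sketched in plain Python for a single 2D image stored as nested lists:

  • relative or absolute overlap
  • offset
  • filtering
  • padding or dropping border patches
  • scalar broadcasting

This is a hypothetical illustration, not SlidingWindowSplitter's code. It has no batch or channel dimensions and supports only constant padding. The step rule patch minus overlap is an assumption of the sketch.

```python
from itertools import product
from typing import Callable, Iterator, List, Optional, Sequence, Tuple, Union

Grid = List[List[float]]


def _pair(v: Union[int, Sequence[int]]) -> Tuple[int, int]:
    # broadcast a scalar to both spatial dimensions
    return (v, v) if isinstance(v, int) else (v[0], v[1])


def _step(patch: int, overlap: Union[float, int]) -> int:
    # a float overlap is a fraction of the patch size; an int overlap is in pixels
    overlap_px = round(patch * overlap) if isinstance(overlap, float) else overlap
    return max(patch - overlap_px, 1)


def split_2d(
    image: Grid,
    patch_size: Union[int, Sequence[int]],
    overlap: Union[float, int] = 0.0,
    offset: Union[int, Sequence[int]] = 0,
    filter_fn: Optional[Callable[[Grid, Tuple[int, int]], bool]] = None,
    pad: bool = True,
    pad_value: float = 0,
) -> Iterator[Tuple[Grid, Tuple[int, int]]]:
    h, w = len(image), len(image[0])
    (ph, pw), (oy, ox) = _pair(patch_size), _pair(offset)
    # with padding, a window may start anywhere inside the image;
    # without padding, it must end at or before the border
    rows = range(oy, h if pad else h - ph + 1, _step(ph, overlap))
    cols = range(ox, w if pad else w - pw + 1, _step(pw, overlap))
    for r, c in product(rows, cols):
        if not pad and (r < 0 or c < 0):
            continue  # the patch would cross the border: drop it
        patch = [
            [image[i][j] if 0 <= i < h and 0 <= j < w else pad_value for j in range(c, c + pw)]
            for i in range(r, r + ph)
        ]
        if filter_fn is None or filter_fn(patch, (r, c)):
            yield patch, (r, c)


image = [[4 * i + j for j in range(4)] for i in range(4)]
print([loc for _, loc in split_2d(image, 2)])  # [(0, 0), (0, 2), (2, 0), (2, 2)]
print(len(list(split_2d(image, 2, overlap=0.5, pad=False))))  # 9
```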

__call__(inputs)

Split the input tensor into patches and return patches and locations.

Parameters:

inputs (Any) – either a torch.Tensor with BCHW[D] dimensions, representing an image or a batch of images

Yields:

tuple[torch.Tensor, Sequence[int]] – yields tuple of patch and location

Return type:

Iterable[tuple[Tensor, Sequence[int]]]

get_input_shape(inputs)

Return the input spatial shape.

Parameters:

inputs (Any) – either a tensor of shape BCHW[D], representing a batch of images, or a filename (str) or list of filenames to the image(s).

Return type:

tuple

Returns:

spatial_shape

get_padded_shape(inputs)

Return the actual spatial shape covered by the output split patches. For instance, if the input image is padded, the actual spatial shape will be enlarged and not the same as input spatial shape.

Parameters:

inputs (Any) – either a tensor of shape BCHW[D], representing a batch of images, or a filename (str) or list of filenames to the image(s).

Return type:

tuple

Returns:

padded_spatial_shape

WSISlidingWindowSplitter

class monai.inferers.WSISlidingWindowSplitter(patch_size, overlap=0.0, offset=0, filter_fn=None, pad_mode='constant', device=None, reader='OpenSlide', **reader_kwargs)

Splits the whole slide image input into patches with sliding window strategy and a possible overlap. This extracts patches from file without loading the entire slide into memory. It also allows offsetting the starting position and filtering the patches.

Parameters:
  • patch_size – the size of the patches to be generated.

  • offset – the amount of offset for the patches with respect to the original input. Defaults to 0.

  • overlap – the amount of overlap between patches in each dimension. It can be either a float in the range of [0.0, 1.0) that defines relative overlap to the patch size, or it can be a non-negative int that defines number of pixels for overlap. Defaults to 0.0.

  • filter_fn – a callable to filter patches. It should accept exactly two parameters (patch, location) and return True for a patch to keep. Defaults to no filtering.

  • pad_mode – defines the mode for padding. Either “constant” or None. Defaults to “constant”. Padding is only supported with the “OpenSlide” or “cuCIM” backend, and the fill value is 256.

  • device – the device where the patches are generated. Defaults to the device of inputs.

  • reader

    the module to be used for loading whole slide imaging. If reader is

    • a string, it defines the backend of monai.data.WSIReader. Defaults to “OpenSlide”.

    • a class (inherited from BaseWSIReader), it is initialized and set as wsi_reader.

    • an instance of a class inherited from BaseWSIReader, it is set as the wsi_reader.

    To obtain an optimized performance please use either “cuCIM” or “OpenSlide” backend.

  • reader_kwargs – the arguments to pass to WSIReader or the provided whole slide reader class. For instance, level=2, dtype=torch.float32, etc. Note that if level is not provided, level=0 is assumed.

Note

When a scalar value is provided for patch_size, offset, or overlap, it is broadcast to all the spatial dimensions.

__call__(inputs)

Split the whole slide image into patches and return patches and locations.

Parameters:

inputs – the file path to a whole slide image.

Yields:

tuple[torch.Tensor, Sequence[int]] – yields tuple of patch and location

get_input_shape(inputs)

Return the input spatial shape.

Parameters:

inputs (Any) – either a tensor of shape BCHW[D], representing a batch of images, or a filename (str) or list of filenames to the image(s).

Return type:

tuple

Returns:

spatial_shape

Mergers

class monai.inferers.Merger(merged_shape, cropped_shape=None, device=None)

A base class for merging patches. Extend this class to support operations for PatchInferer. There are two methods that must be implemented in the concrete classes:

  • aggregate: aggregate the values at their corresponding locations

  • finalize: perform any final process and return the merged output

Parameters:
  • merged_shape – the shape of the tensor required to merge the patches.

  • cropped_shape – the shape of the final merged output tensor. If not provided, it will be the same as merged_shape.

  • device – the device where Merger tensors should reside.

abstract aggregate(values, location)

Aggregate values for merging. This method is called in a loop and should add values to their corresponding locations in the merged output.

Parameters:
  • values (Tensor) – a tensor of shape BCHW[D], representing the values of inference output.

  • location (Sequence[int]) – a tuple/list giving the top left location of the patch in the output.

Raises:

NotImplementedError – When the subclass does not override this method.

Return type:

None

abstract finalize()

Perform final operations for merging patches and return the final merged output.

Return type:

Any

Returns:

The results of merged patches, which is commonly a torch.Tensor representing the merged result, or a string representing the filepath to the merged results on disk.

Raises:

NotImplementedError – When the subclass does not override this method.
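As a sketch of the two-method contract above, here is a hypothetical 1D merger that keeps the maximum value wherever patches overlap. ToyMerger and ToyMaxMerger are illustrative names, not MONAI classes. Plain lists stand in for tensors, and an int stands in for a shape.

```python
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence


class ToyMerger(ABC):
    """Hypothetical 1D mirror of the Merger contract (not MONAI code)."""

    def __init__(self, merged_shape: int, cropped_shape: Optional[int] = None):
        self.merged_shape = merged_shape
        self.cropped_shape = merged_shape if cropped_shape is None else cropped_shape

    @abstractmethod
    def aggregate(self, values: Sequence[float], location: int) -> None:
        raise NotImplementedError

    @abstractmethod
    def finalize(self) -> List[float]:
        raise NotImplementedError


class ToyMaxMerger(ToyMerger):
    """Merges overlapping patches by keeping the maximum at each position."""

    def __init__(self, merged_shape: int, cropped_shape: Optional[int] = None):
        super().__init__(merged_shape, cropped_shape)
        self.values = [float("-inf")] * merged_shape

    def aggregate(self, values: Sequence[float], location: int) -> None:
        for k, v in enumerate(values):
            self.values[location + k] = max(self.values[location + k], v)

    def finalize(self) -> List[float]:
        return self.values[: self.cropped_shape]  # crop to the final output shape


merger = ToyMaxMerger(merged_shape=6, cropped_shape=5)
merger.aggregate([1, 5, 2, 0], location=0)
merger.aggregate([3, 1, 4, 7], location=2)
print(merger.finalize())  # [1, 5, 3, 1, 4]
```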

AvgMerger

class monai.inferers.AvgMerger(merged_shape, cropped_shape=None, value_dtype=torch.float32, count_dtype=torch.uint8, device='cpu')

Merge patches by taking the average of the overlapping areas.

Parameters:
  • merged_shape – the shape of the tensor required to merge the patches.

  • cropped_shape – the shape of the final merged output tensor. If not provided, it will be the same as merged_shape.

  • device – the device for aggregator tensors and final results.

  • value_dtype – the dtype for value aggregating tensor and the final result.

  • count_dtype – the dtype for sample counting tensor.

aggregate(values, location)

Aggregate values for merging.

Parameters:
  • values (Tensor) – a tensor of shape BCHW[D], representing the values of inference output.

  • location (Sequence[int]) – a tuple/list giving the top left location of the patch in the original image.

Raises:

ValueError – When the merger has already been finalized.

Return type:

None

finalize()

Finalize merging by dividing values by counts and return the merged tensor.

Notes

To avoid creating a new tensor for the final results (to save memory), after this method is called, the get_values() method returns the “final” averaged values instead of the accumulating values. Also, calling finalize() multiple times has no further effect.

Returns:

a tensor of merged patches

Return type:

torch.Tensor

get_counts()

Get the aggregator tensor for number of samples.

Returns:

number of accumulated samples at each location.

Return type:

torch.Tensor

get_output()

Get the final merged output.

Returns:

merged output.

Return type:

torch.Tensor

get_values()

Get the accumulated values during aggregation or final averaged values after it is finalized.

Returns:

aggregated values.

Return type:

torch.Tensor

Notes

  • If called before calling finalize(), this method returns the accumulating values.

  • If called after calling finalize(), this method returns the final merged [and averaged] values.
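The averaging and finalize() semantics described above can be sketched in 1D with plain lists. Those semantics are in-place division, cropping to cropped_shape, and repeated calls having no further effect. ToyAvgMerger is hypothetical, not MONAI's AvgMerger. Two behaviors are design choices of this sketch only. First, positions that no patch covered are left at 0 instead of being divided by a zero count. Second, aggregating after finalization raises ValueError.

```python
from typing import List, Optional, Sequence


class ToyAvgMerger:
    """Hypothetical 1D sketch of average merging with a count buffer (not MONAI code)."""

    def __init__(self, merged_shape: int, cropped_shape: Optional[int] = None):
        self.cropped_shape = merged_shape if cropped_shape is None else cropped_shape
        self.values = [0.0] * merged_shape  # running sums of patch values
        self.counts = [0] * merged_shape  # number of patches covering each position
        self.is_finalized = False

    def aggregate(self, values: Sequence[float], location: int) -> None:
        if self.is_finalized:
            raise ValueError("already finalized; create a new merger to aggregate again")
        for k, v in enumerate(values):
            self.values[location + k] += v
            self.counts[location + k] += 1

    def finalize(self) -> List[float]:
        if not self.is_finalized:
            # divide in place instead of allocating a new result buffer
            for i, c in enumerate(self.counts):
                if c:
                    self.values[i] /= c
            del self.values[self.cropped_shape :]  # crop in place
            self.is_finalized = True
        return self.values

    def get_values(self) -> List[float]:
        # running sums before finalize(), averaged values afterwards
        return self.values


merger = ToyAvgMerger(merged_shape=6, cropped_shape=5)
merger.aggregate([2, 2, 2, 2], location=0)
merger.aggregate([4, 4, 4, 4], location=2)
print(merger.get_values())  # [2.0, 2.0, 6.0, 6.0, 4.0, 4.0]
print(merger.finalize())  # [2.0, 2.0, 3.0, 3.0, 4.0]
```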

ZarrAvgMerger

class monai.inferers.ZarrAvgMerger(merged_shape, cropped_shape=None, dtype='float32', value_dtype='float32', count_dtype='uint8', store='merged.zarr', value_store=None, count_store=None, compressor='default', value_compressor=None, count_compressor=None, chunks=True, thread_locking=True)

Merge patches by taking the average of the overlapping areas and store the results in a zarr array.

Zarr is a format for the storage of chunked, compressed, N-dimensional arrays. Zarr data can be stored in any storage system that can be represented as a key-value store, like POSIX file systems, cloud object storage, zip files, and relational and document databases. See https://zarr.readthedocs.io/en/stable/ for more details. It is particularly useful for storing N-dimensional arrays too large to fit into memory. One specific use case of this class is to merge patches extracted from whole slide images (WSI), where the merged results do not fit into memory and need to be stored on a file system.

Parameters:
  • merged_shape – the shape of the tensor required to merge the patches.

  • cropped_shape – the shape of the final merged output tensor. If not provided, it will be the same as merged_shape.

  • dtype – the dtype for the final merged result. Default is float32.

  • value_dtype – the dtype for value aggregating tensor and the final result. Default is float32.

  • count_dtype – the dtype for sample counting tensor. Default is uint8.

  • store – the zarr store to save the final results. Default is “merged.zarr”.

  • value_store – the zarr store to save the value aggregating tensor. Default is a temporary store.

  • count_store – the zarr store to save the sample counting tensor. Default is a temporary store.

  • compressor – the compressor for final merged zarr array. Default is “default”.

  • value_compressor – the compressor for value aggregating zarr array. Default is None.

  • count_compressor – the compressor for sample counting zarr array. Default is None.

  • chunks – int or tuple of ints that defines the chunk shape, or boolean. Default is True. If True, chunk shape will be guessed from shape and dtype. If False, it will be set to shape, i.e., single chunk for the whole array. If an int, the chunk size in each dimension will be given by the value of chunks.

aggregate(values, location)

Aggregate values for merging.

Parameters:
  • values (Tensor) – a tensor of shape BCHW[D], representing the values of inference output.

  • location (Sequence[int]) – a tuple/list giving the top left location of the patch in the original image.

Return type:

None

finalize()

Finalize merging by dividing values by counts and return the merged tensor.

Notes

To avoid creating a new array for the final results (to save memory), after this method is called, the get_values() method returns the “final” averaged values instead of the accumulating values. Also, calling finalize() multiple times has no further effect.

Returns:

a zarr array of merged patches

Return type:

zarr.Array

get_counts()

Get the aggregator tensor for number of samples.

Returns:

Number of accumulated samples at each location.

Return type:

zarr.Array

get_output()

Get the final merged output.

Returns:

Merged (averaged) output tensor.

Return type:

zarr.Array

get_values()

Get the accumulated values during aggregation.

Returns:

aggregated values.

Return type:

zarr.Array

Sliding Window Inference Function

monai.inferers.sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, overlap=0.25, mode=BlendMode.CONSTANT, sigma_scale=0.125, padding_mode=PytorchPadMode.CONSTANT, cval=0.0, sw_device=None, device=None, progress=False, roi_weight_map=None, process_fn=None, buffer_steps=None, buffer_dim=-1, with_coord=False, *args, **kwargs)

Sliding window inference on inputs with predictor.

The outputs of predictor can be a tensor, a tuple, or a dictionary of tensors. Each output in the tuple or dict value is allowed to have a different resolution with respect to the input. For example, if the input patch spatial size is [128,128,128], the output patch sizes (a tuple of two patches) could be ([128,64,256], [64,32,128]). In this case, the parameters overlap and roi_size need to be chosen carefully to ensure that the output ROI sizes are still integers. If the predictor’s input and output spatial sizes are not equal, we recommend choosing the parameters so that overlap*roi_size*output_size/input_size is an integer (for each spatial dimension).

When roi_size is larger than the inputs’ spatial size, the input image is padded during inference. To maintain the same spatial sizes, the output image is then cropped to the original input size.
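The recommendation above can be checked mechanically before running inference. Here is a minimal sketch. scaled_overlap_ok is a hypothetical helper, not part of MONAI. It uses exact fractions so that decimal overlaps such as 0.1 are not distorted by floating-point round-off.

```python
from fractions import Fraction
from typing import Sequence


def scaled_overlap_ok(
    roi_size: Sequence[int], overlap: float, in_size: Sequence[int], out_size: Sequence[int]
) -> bool:
    """True if overlap * roi_size * out_size / in_size is an integer in every spatial dimension."""
    ov = Fraction(str(overlap))  # e.g. 0.25 -> 1/4 exactly
    return all((ov * r * o / i).denominator == 1 for r, i, o in zip(roi_size, in_size, out_size))


# input patch [128, 128, 128] with the two output patch sizes from the example above
print(scaled_overlap_ok([128] * 3, 0.25, [128] * 3, [128, 64, 256]))  # True
print(scaled_overlap_ok([128] * 3, 0.25, [128] * 3, [64, 32, 128]))  # True
print(scaled_overlap_ok([128] * 3, 0.3, [128] * 3, [128, 64, 256]))  # False
```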

Parameters:
  • inputs – input image to be processed (assuming NCHW[D])

  • roi_size – the spatial window size for inferences. When any of its components is None or non-positive, the corresponding spatial size of the inputs is used instead. For example, roi_size=(32, -1) will be adapted to (32, 64) if the second spatial dimension size of the input is 64.

  • sw_batch_size – the batch size to run window slices.

  • predictor – given an input tensor patch_data of shape NCHW[D], the outputs of the function call predictor(patch_data) should be a tensor, a tuple, or a dictionary with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM’H’W’[D’], where H’W’[D’] represents the output patch’s spatial size, M is the number of output channels, and N is sw_batch_size. For example, if the input shape is (7, 1, 128,128,128), the output could be a tuple of two tensors with shapes ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)). In this case, the parameters overlap and roi_size need to be chosen carefully to ensure that the scaled output ROI sizes are still integers. If the predictor’s input and output spatial sizes are different, we recommend choosing the parameters so that overlap*roi_size*zoom_scale is an integer for each dimension.

  • overlap – Amount of overlap between scans along each spatial dimension, defaults to 0.25.

  • mode

    {"constant", "gaussian"} How to blend output of overlapping windows. Defaults to "constant".

    • "constant": gives equal weight to all predictions.

    • "gaussian": gives less weight to predictions on edges of windows.

  • sigma_scale – the standard deviation coefficient of the Gaussian window when mode is "gaussian". Default: 0.125. Actual window sigma is sigma_scale * dim_size. When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding spatial dimensions.

  • padding_mode – {"constant", "reflect", "replicate", "circular"} Padding mode for inputs, when roi_size is larger than inputs. Defaults to "constant". See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html

  • cval – fill value for ‘constant’ padding mode. Default: 0

  • sw_device – device for the window data. By default the device (and accordingly the memory) of the inputs is used. Normally sw_device should be consistent with the device where predictor is defined.

  • device – device for the stitched output prediction. By default the device (and accordingly the memory) of the inputs is used. For example, if set to device=torch.device(‘cpu’), GPU memory consumption is lower and independent of the inputs and roi_size. The output is placed on this device.

  • progress – whether to print a tqdm progress bar.

  • roi_weight_map – pre-computed (non-negative) weight map for each ROI. If not given, and mode is not constant, this map will be computed on the fly.

  • process_fn – process inference output and adjust the importance map per window

  • buffer_steps – the number of sliding window iterations along buffer_dim to be buffered on sw_device before writing to device. (Typically, sw_device is cuda and device is cpu.) Defaults to None (no buffering). Along buffer_dim, when the spatial size is divisible by buffer_steps*roi_size (i.e. the buffers do not overlap), non_blocking copy may be enabled automatically for efficiency.

  • buffer_dim – the spatial dimension along which the buffers are created. 0 indicates the first spatial dimension. Default is -1, the last spatial dimension.

  • with_coord – whether to pass the window coordinates to predictor. Default is False. If True, the signature of predictor should be predictor(patch_data, patch_coord, ...).

  • args – optional args to be passed to predictor.

  • kwargs – optional keyword args to be passed to predictor.

Note

  • The input must be channel-first and have a batch dimension. N-D sliding windows are supported.
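The difference between the "constant" and "gaussian" blending modes can be sketched in 1D. Each window's prediction is multiplied by a weight map and accumulated. The sum is then divided by the accumulated weights. This is a hypothetical plain-Python illustration; gaussian_weights and blend are not MONAI functions. It uses sigma = sigma_scale * window size, as described for sigma_scale. The library's own importance map may differ in details such as normalization.

```python
import math
from typing import Dict, List, Sequence


def gaussian_weights(size: int, sigma_scale: float = 0.125) -> List[float]:
    """1D Gaussian weight map with sigma = sigma_scale * size, peaking at the window centre."""
    sigma = sigma_scale * size
    centre = (size - 1) / 2
    return [math.exp(-((x - centre) ** 2) / (2 * sigma**2)) for x in range(size)]


def blend(length: int, windows: Dict[int, Sequence[float]], weights: Sequence[float]) -> List[float]:
    """Weighted average of overlapping window predictions, keyed by window start index."""
    num = [0.0] * length
    den = [0.0] * length
    for start, pred in windows.items():
        for k, (p, w) in enumerate(zip(pred, weights)):
            num[start + k] += w * p
            den[start + k] += w
    return [n / d if d > 0 else 0.0 for n, d in zip(num, den)]


# two windows of size 4 overlapping on positions 2 and 3
preds = {0: [1.0, 1.0, 1.0, 1.0], 2: [3.0, 3.0, 3.0, 3.0]}
print(blend(6, preds, [1.0] * 4))  # constant mode: [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
gaussian = blend(6, preds, gaussian_weights(4, sigma_scale=0.125))
# gaussian mode: each overlapping position leans toward the window whose centre is closer
print([round(v, 2) for v in gaussian])  # [1.0, 1.0, 1.04, 2.96, 3.0, 3.0]
```

With constant weights the overlap is a plain average. With Gaussian weights, predictions near a window's edge contribute little.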