Source code for monai.transforms.croppad.batch

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A collection of "vanilla" transforms for crop and pad operations acting on batches of data.
"""

from __future__ import annotations

from collections.abc import Hashable, Mapping
from typing import Any

import numpy as np
import torch

from monai.data.meta_tensor import MetaTensor
from monai.data.utils import list_data_collate
from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad
from monai.transforms.inverse import InvertibleTransform
from monai.utils.enums import Method, PytorchPadMode, TraceKeys

__all__ = ["PadListDataCollate"]


def replace_element(to_replace, batch, idx, key_or_idx):
    # since tuple is immutable we'll have to recreate
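    # e.g. if batch[idx] is the tuple (img, seg), replacing key_or_idx=1 rebuilds it as (img, to_replace)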
    if isinstance(batch[idx], tuple):
        batch_idx_list = list(batch[idx])
        batch_idx_list[key_or_idx] = to_replace
        batch[idx] = tuple(batch_idx_list)
    # else, replace
    else:
        batch[idx][key_or_idx] = to_replace
    return batch


class PadListDataCollate(InvertibleTransform):
    """
    Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest
    tensor in each dimension. This transform is useful if some of the applied transforms generate batch data of
    different sizes.

    This can be used on both list and dictionary data.
    Note that in the case of dictionary data, it may add the transform information to the list of invertible
    transforms if the items in the input batch have different spatial shapes, so the static method `inverse`
    needs to be called before inverting other transforms.

    Note that normally, a user won't explicitly use the `__call__` method. Rather, this would be passed to
    the `DataLoader`. This means that `__call__` handles data as it comes out of a `DataLoader`, containing the
    batch dimension. However, the `inverse` operates on dictionaries containing images of shape `C,H,W,[D]`.
    This asymmetry is necessary so that we can pass the inverse through multiprocessing.

    Args:
        method: padding method (see :py:class:`monai.transforms.SpatialPad`)
        mode: padding mode (see :py:class:`monai.transforms.SpatialPad`)
        kwargs: other arguments for the `np.pad` or `torch.pad` function.
            Note that `np.pad` treats the channel dimension as the first dimension.

    """

    def __init__(self, method: str = Method.SYMMETRIC, mode: str = PytorchPadMode.CONSTANT, **kwargs) -> None:
        self.method = method
        self.mode = mode
        self.kwargs = kwargs
    def __call__(self, batch: Any):
        """
        Args:
            batch: batch of data to pad-collate
        """
        # data is either list of dicts or list of lists
        is_list_of_dicts = isinstance(batch[0], dict)
        # loop over items inside of each element in a batch
        batch_item = tuple(batch[0].keys()) if is_list_of_dicts else range(len(batch[0]))
        for key_or_idx in batch_item:
            # calculate max size of each dimension
            max_shapes = []
            for elem in batch:
                if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)):
                    break
                max_shapes.append(elem[key_or_idx].shape[1:])
            # len > 0 if objects were arrays, else skip as no padding to be done
            if not max_shapes:
                continue
            max_shape = np.array(max_shapes).max(axis=0)
            # If all same size, skip
            if np.all(np.array(max_shapes).min(axis=0) == max_shape):
                continue

            # Use `SpatialPad` to match sizes
            # Default params are central padding, padding with 0's
            padder = SpatialPad(spatial_size=max_shape, method=self.method, mode=self.mode, **self.kwargs)
            for idx, batch_i in enumerate(batch):
                orig_size = batch_i[key_or_idx].shape[1:]
                padded = padder(batch_i[key_or_idx])
                batch = replace_element(padded, batch, idx, key_or_idx)

                # If we have a dictionary of data, append to list
                # padder transform info is re-added with self.push_transform to ensure one info dict per transform.
                if is_list_of_dicts:
                    self.push_transform(
                        batch[idx],
                        key_or_idx,
                        orig_size=orig_size,
                        extra_info=self.pop_transform(batch[idx], key_or_idx, check=False),
                    )

        # After padding, use default list collator
        return list_data_collate(batch)
    @staticmethod
    def inverse(data: dict) -> dict[Hashable, np.ndarray]:
        if not isinstance(data, Mapping):
            raise RuntimeError(f"Inverse can only currently be applied on dictionaries, got type {type(data)}.")
        d = dict(data)
        for key in d:
            transforms = None
            if isinstance(d[key], MetaTensor):
                transforms = d[key].applied_operations
            else:
                transform_key = InvertibleTransform.trace_key(key)
                if transform_key in d:
                    transforms = d[transform_key]
            if not transforms or not isinstance(transforms[-1], dict):
                continue
            if transforms[-1].get(TraceKeys.CLASS_NAME) == PadListDataCollate.__name__:
                xform = transforms.pop()
                cropping = CenterSpatialCrop(xform.get(TraceKeys.ORIG_SIZE, -1))
                with cropping.trace_transform(False):
                    d[key] = cropping(d[key])  # fallback to image size
        return d
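
A minimal usage sketch (not part of the module above), showing how this collate transform is typically passed to a ``DataLoader`` and how its static ``inverse`` is applied per item after ``decollate_batch``. The sample data (an "image" key, the tensor sizes, and the ``RandSpatialCropd`` settings that produce differently sized items) are hypothetical; the MONAI classes and functions used are the public API.

import torch

from monai.data import DataLoader, Dataset, decollate_batch
from monai.transforms import PadListDataCollate, RandSpatialCropd

# hypothetical data: two single-channel images of different spatial sizes
data = [{"image": torch.rand(1, 64, 64)}, {"image": torch.rand(1, 48, 48)}]

# random-size crops yield items of varying spatial shape, which the default collate cannot stack
ds = Dataset(data, transform=RandSpatialCropd(keys="image", roi_size=(32, 32), random_size=True))
loader = DataLoader(ds, batch_size=2, collate_fn=PadListDataCollate())

batch = next(iter(loader))  # each item is centrally padded to the largest shape, then collated
# undo the padding per item after decollating, before inverting any other transforms
inverted = [PadListDataCollate.inverse(item) for item in decollate_batch(batch)]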