# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
from collections.abc import Callable, Sequence
from typing import Any
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader as TorchDataLoader
from monai.config import KeysCollection
from monai.data.dataloader import DataLoader
from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate
from monai.transforms.croppad.batch import PadListDataCollate
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform, Transform
from monai.utils import first
# Public names exported by a star-import of this module.
__all__ = [
    "BatchInverseTransform",
    "Decollated",
    "DecollateD",
    "DecollateDict",
]
class _BatchInverseDataset(Dataset):
    """
    Map-style dataset that applies ``transform.inverse`` to each item of an already-decollated batch.

    Args:
        data: sequence of per-item mappings. Each item is copied with ``dict(...)`` (a shallow copy)
            before any processing, so the stored items are not replaced.
        transform: the transform to invert. If it is not an ``InvertibleTransform`` instance,
            a warning is emitted and the (possibly un-padded) item is returned as-is.
        pad_collation_used: when True, ``PadListDataCollate.inverse`` is applied to each item
            before ``transform`` is inverted.
    """

    def __init__(self, data: Sequence[Any], transform: InvertibleTransform, pad_collation_used: bool) -> None:
        self.data = data
        self.invertible_transform = transform
        self.pad_collation_used = pad_collation_used

    def __getitem__(self, index: int):
        item = dict(self.data[index])
        # the padding added by pad collation has to be undone before the user transform is inverted
        if self.pad_collation_used:
            item = PadListDataCollate.inverse(item)

        inverter = self.invertible_transform
        if isinstance(inverter, InvertibleTransform):
            return inverter.inverse(item)
        warnings.warn("transform is not invertible, can't invert transform for the input data.")
        return item

    def __len__(self) -> int:
        return len(self.data)
class Decollated(MapTransform):
    """
    Decollate a batch of data. If the input is a dictionary, decollation can be restricted to the specified keys.
    Note that unlike most MapTransforms, it drops every key that is not specified.
    If `keys=None`, the whole input is decollated.
    Scalar values are replicated to every item of the decollated list.

    Args:
        keys: keys of the corresponding items to decollate; any key not listed is dropped from the output.
            If None, all of the input is decollated. See also: :py:class:`monai.transforms.transform.MapTransform`.
        detach: whether to detach the tensors. Scalar tensors will be detached into number types
            instead of torch tensors.
        pad_batch: when the items in a batch indicate different batch sizes,
            whether to pad all the sequences to the longest.
            If False, the batch size will be the length of the shortest sequence.
        fill_value: the value used to fill the padded sequences when `pad_batch=True`.
        allow_missing_keys: don't raise an exception if a key is missing.

    Raises:
        TypeError: from ``__call__`` when specific keys are configured but the input is not a ``dict``.
    """

    def __init__(
        self,
        keys: KeysCollection | None = None,
        detach: bool = True,
        pad_batch: bool = True,
        fill_value=None,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.fill_value = fill_value
        self.pad_batch = pad_batch
        self.detach = detach

    def __call__(self, data: dict | list):
        selected: dict | list = data
        # keys normalised to a single `None` means "decollate everything": `None` is not usable as a real key,
        # so the input is forwarded untouched; otherwise only the requested keys are kept.
        if len(self.keys) != 1 or self.keys[0] is not None:
            if not isinstance(data, dict):
                raise TypeError("input data is not a dictionary, but specified keys to decollate.")
            selected = {key: data[key] for key in self.key_iterator(data)}
        return decollate_batch(selected, detach=self.detach, pad=self.pad_batch, fill_value=self.fill_value)
# Alternative public names; both are bound to the very same ``Decollated`` class object.
DecollateD = Decollated
DecollateDict = Decollated