# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader as TorchDataLoader
from monai.config import KeysCollection
from monai.data.dataloader import DataLoader
from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate, rep_scalar_to_batch
from monai.transforms.croppad.batch import PadListDataCollate
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform, Transform
from monai.utils import first
# Public names exported by ``from <this module> import *``.
__all__ = ["BatchInverseTransform", "Decollated"]
class _BatchInverseDataset(Dataset):
    """
    Map-style dataset that applies an inverse transform to each stored item on access.

    Args:
        data: sequence of items; each item is copied into a new ``dict`` when indexed.
        transform: transform whose ``inverse`` is applied to every item. If it is not an
            ``InvertibleTransform``, a warning is emitted and the item is returned without
            that inversion.
        pad_collation_used: when True, ``PadListDataCollate.inverse`` is applied to each item
            before ``transform.inverse``.
    """

    def __init__(
        self,
        data: Sequence[Any],
        transform: InvertibleTransform,
        pad_collation_used: bool,
    ) -> None:
        self.data = data
        self.invertible_transform = transform
        self.pad_collation_used = pad_collation_used

    def __getitem__(self, index: int):
        item = dict(self.data[index])
        # Padding added at collation time has to be removed before the
        # user transform's inverse can run on the item.
        if self.pad_collation_used:
            item = PadListDataCollate.inverse(item)
        if isinstance(self.invertible_transform, InvertibleTransform):
            return self.invertible_transform.inverse(item)
        warnings.warn("transform is not invertible, can't invert transform for the input data.")
        return item

    def __len__(self) -> int:
        return len(self.data)
class Decollated(MapTransform):
    """
    Decollate a batch of data. If the input is a dictionary, it can also decollate only the specified keys.
    Note that, unlike most MapTransforms, it deletes the keys that are not specified, and if keys=None it
    decollates all the data in the input.
    It also replicates the scalar values to every item of the decollated list.

    Args:
        keys: keys of the corresponding items to decollate. Note that it will delete other keys not specified.
            If None, it will decollate all the keys. See also: :py:class:`monai.transforms.transform.MapTransform`.
        detach: whether to detach the tensors. Scalar tensors will be detached into number types
            instead of torch tensors.
        allow_missing_keys: don't raise exception if key is missing.

    Raises:
        TypeError: when keys are specified but the input data is not a dictionary.
    """

    def __init__(
        self,
        keys: Optional[KeysCollection] = None,
        detach: bool = True,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.detach = detach

    def __call__(self, data: Union[Dict, List]):
        """
        Args:
            data: the batch to decollate, either a dictionary or a list.

        Returns:
            the output of ``decollate_batch`` applied to the (possibly key-filtered) input,
            after its scalar values have been replicated by ``rep_scalar_to_batch``.

        Raises:
            TypeError: when keys are specified but ``data`` is not a dictionary.
        """
        d: Union[Dict, List]
        if len(self.keys) == 1 and self.keys[0] is None:
            # A lone ``None`` key means keys=None was given: decollate the whole
            # input rather than looking up a literal ``None`` key.
            d = data
        else:
            if not isinstance(data, dict):
                raise TypeError("input data is not a dictionary, but specified keys to decollate.")
            # Only the selected keys are kept; every other key is dropped from the output.
            d = {key: data[key] for key in self.key_iterator(data)}
        return decollate_batch(rep_scalar_to_batch(d), detach=self.detach)