# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable, Dict, Optional, Sequence, Union
import torch
from torch.utils.data import IterableDataset
from monai.data.dataset import Dataset
from monai.data.utils import iter_patch
from monai.transforms import apply_transform
from monai.utils import NumpyPadMode, ensure_tuple
# Names exported by `from <this module> import *`: the two dataset classes defined below.
__all__ = ["PatchDataset", "GridPatchDataset"]
class GridPatchDataset(IterableDataset):
    """
    Iterable dataset that produces patches cut out of the arrays read from a wrapped dataset.
    Patches are taken following a contiguous grid sampling scheme.
    """

    def __init__(
        self,
        dataset: Sequence,
        patch_size: Sequence[int],
        start_pos: Sequence[int] = (),
        mode: Union[NumpyPadMode, str] = NumpyPadMode.WRAP,
        **pad_opts: Dict,
    ) -> None:
        """
        Set up the dataset from a source dataset and a patch size.

        `patch_size` describes the patch to sample from the input arrays. The first dimension of each array is
        assumed to be the channel dimension; it is always yielded in full and therefore must not be listed in
        `patch_size`. For example, a 3D input array with one channel and shape (1, 20, 20, 20) is sampled as a
        regular grid of eight (1, 10, 10, 10) patches by giving a `patch_size` of (10, 10, 10).

        Args:
            dataset: the dataset to read array data from
            patch_size: size of patches to generate slices for, 0/None selects whole dimension
            start_pos: starting position in the array, default is 0 for each dimension
            mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
                ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
                One of the listed string values or a user supplied function. Defaults to ``"wrap"``.
                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
            pad_opts: padding options, see numpy.pad
        """
        self.dataset = dataset
        # A leading None is prepended for the channel dimension, which is never split.
        self.patch_size = (None, *patch_size)
        self.start_pos = ensure_tuple(start_pos)
        self.mode: NumpyPadMode = NumpyPadMode(mode)
        self.pad_opts = pad_opts

    def __iter__(self):
        """
        Yield tuples of patches, one tuple per grid position, for every item of the wrapped dataset.

        Each dataset item is treated as an iterable of arrays. One `iter_patch` iterator is created per array
        and the iterators are advanced in lockstep, so every yielded tuple holds one entry per array.

        When called inside a data-loading worker (``get_worker_info()`` is not None), the items are divided
        into contiguous chunks of ``ceil(len(dataset) / num_workers)`` and this worker only visits its own chunk.
        """
        n_items = len(self.dataset)
        first, last = 0, n_items

        info = torch.utils.data.get_worker_info()
        if info is not None:
            # split workload: each worker takes one contiguous chunk, the trailing chunk may be short or empty
            chunk = math.ceil(n_items / info.num_workers)
            first = info.id * chunk
            last = min(first + chunk, n_items)

        def patches_of(array):
            # The positional `False` is forwarded unchanged to `iter_patch`
            # (presumably its copy-back flag -- confirm against monai.data.utils.iter_patch).
            return iter_patch(array, self.patch_size, self.start_pos, False, self.mode, **self.pad_opts)

        for item_index in range(first, last):
            item = self.dataset[item_index]
            patch_iterators = [patches_of(array) for array in item]
            yield from zip(*patch_iterators)
class PatchDataset(Dataset):
    """
    Returns a patch from an image dataset.
    The patches are generated by a user-specified callable `patch_func`,
    and are optionally post-processed by `transform`.

    The dataset is indexed patch-wise: item ``i`` is patch ``i % samples_per_image`` of image
    ``i // samples_per_image``. Negative indices count from the end, as for a built-in sequence.

    For example, to generate random patch samples from an image dataset:

    .. code-block:: python

        import numpy as np

        from monai.data import PatchDataset, DataLoader
        from monai.transforms import RandSpatialCropSamples, RandShiftIntensity

        # image dataset
        images = [np.arange(16, dtype=float).reshape(1, 4, 4),
                  np.arange(16, dtype=float).reshape(1, 4, 4)]
        # image patch sampler
        n_samples = 5
        sampler = RandSpatialCropSamples(roi_size=(3, 3), num_samples=n_samples,
                                         random_center=True, random_size=False)
        # patch-level intensity shifts
        patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0)
        # construct the patch dataset
        ds = PatchDataset(dataset=images,
                          patch_func=sampler,
                          samples_per_image=n_samples,
                          transform=patch_intensity)

        # use the patch dataset, length: len(images) x samples_per_image
        print(len(ds))

        >>> 10

        for item in DataLoader(ds, batch_size=2, shuffle=True, num_workers=2):
            print(item.shape)

        >>> torch.Size([2, 1, 3, 3])

    """

    def __init__(
        self, dataset: Sequence, patch_func: Callable, samples_per_image: int = 1, transform: Optional[Callable] = None
    ) -> None:
        """
        Args:
            dataset: an image dataset to extract patches from.
            patch_func: converts an input image (item from dataset) into a sequence of image patches.
                patch_func(dataset[idx]) must return a sequence of patches (length `samples_per_image`).
            samples_per_image: `patch_func` should return a sequence of `samples_per_image` elements.
            transform: transform applied to each patch.

        Raises:
            ValueError: when `samples_per_image` is not a positive integer.
        """
        # Convert before validating: a fractional value such as 0.5 would otherwise pass the
        # positivity check, be truncated to 0 and only fail later with a division by zero.
        samples_per_image = int(samples_per_image)
        if samples_per_image <= 0:
            raise ValueError("samples_per_image must be a positive integer.")

        # The base class receives the images and the transform; the methods below read them back
        # as `self.data` and `self.transform`.
        super().__init__(data=dataset, transform=transform)
        self.patch_func = patch_func
        self.samples_per_image = samples_per_image

    def __len__(self) -> int:
        """Total number of patches: number of images times `samples_per_image`."""
        return len(self.data) * self.samples_per_image

    def __getitem__(self, index: int):
        """
        Return the patch at the flat position `index`.

        Args:
            index: patch index in ``[-len(self), len(self))``; negative values count from the end.

        Raises:
            IndexError: when `index` is outside the valid range.
            RuntimeWarning: when `patch_func` does not return exactly `samples_per_image` patches.
        """
        n_patches = len(self)
        position = int(index)
        if position < 0:
            # Normalise first so that e.g. -1 addresses the last patch of the last image.
            position += n_patches
        if not 0 <= position < n_patches:
            raise IndexError(f"index {index} is out of range for a dataset of {n_patches} patches.")

        # Integer arithmetic only: no float rounding for large indices.
        image_id, patch_id = divmod(position, self.samples_per_image)

        image = self.data[image_id]
        # `patch_func` is re-evaluated on every access.
        patches = self.patch_func(image)
        if len(patches) != self.samples_per_image:
            raise RuntimeWarning(
                f"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}."
            )

        patch = patches[patch_id]
        if self.transform is not None:
            patch = apply_transform(self.transform, patch, map_items=False)
        return patch