Source code for monai.data.dataset

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import torch

from monai.transforms.compose import Compose, Randomizable
from monai.transforms.utils import apply_transform
from monai.utils import process_bar


[docs]class Dataset(torch.utils.data.Dataset): """ Generic dataset to handle dictionary format data, it can operate transforms for specific fields. For example, typical input data can be a list of dictionaries:: [{ { { 'img': 'image1.nii.gz', 'img': 'image2.nii.gz', 'img': 'image3.nii.gz', 'seg': 'label1.nii.gz', 'seg': 'label2.nii.gz', 'seg': 'label3.nii.gz', 'extra': 123 'extra': 456 'extra': 789 }, }, }] """ def __init__(self, data, transform=None): """ Args: data (Iterable): input data to load and transform to generate dataset for model. transform (Callable, optional): transforms to execute operations on input data. """ self.data = data self.transform = transform def __len__(self): return len(self.data) def __getitem__(self, index): data = self.data[index] if self.transform is not None: data = self.transform(data) return data
[docs]class CacheDataset(Dataset): """ Dataset with cache mechanism that can load data and cache deterministic transforms' result during training. By caching the results of non-random preprocessing transforms, it accelerates the training data pipeline. If the requested data is not in the cache, all transforms will run normally (see also :py:class:`monai.data.dataset.Dataset`). Users can set the cache rate or number of items to cache. It is recommended to experiment with different `cache_num` or `cache_rate` to identify the best training speed. To improve the caching efficiency, please always put as many as possible non-random transforms before the randomised ones when composing the chain of transforms. For example, if the transform is a `Compose` of:: transforms = Compose([ LoadNiftid(), AddChanneld(), Spacingd(), Orientationd(), ScaleIntensityRanged(), RandCropByPosNegLabeld(), ToTensord() ]) when `transforms` is used in a multi-epoch training pipeline, before the first training epoch, this dataset will cache the results up to ``ScaleIntensityRanged``, as all non-random transforms `LoadNiftid`, `AddChanneld`, `Spacingd`, `Orientationd`, `ScaleIntensityRanged` can be cached. During training, the dataset will load the cached results and run ``RandCropByPosNegLabeld`` and ``ToTensord``, as ``RandCropByPosNegLabeld`` is a randomised transform and the outcome not cached. """ def __init__(self, data, transform, cache_num=sys.maxsize, cache_rate=1.0): """ Args: data (Iterable): input data to load and transform to generate dataset for model. transform (Callable): transforms to execute operations on input data. cache_num (int): number of items to be cached. Default is `sys.maxsize`. will take the minimum of (cache_num, data_length x cache_rate, data_length). cache_rate (float): percentage of cached data in total, default is 1.0 (cache all). will take the minimum of (cache_num, data_length x cache_rate, data_length). """ if not isinstance(transform, Compose): transform = Compose(transform) super().__init__(data, transform) self.cache_num = min(cache_num, int(len(self) * cache_rate), len(self)) self._cache = list() print('Load and cache transformed data...') for i in range(self.cache_num): process_bar(i + 1, self.cache_num) item = data[i] for _transform in transform.transforms: # execute all the deterministic transforms before the first random transform if isinstance(_transform, Randomizable): break item = apply_transform(_transform, item) self._cache.append(item) def __getitem__(self, index): if index < self.cache_num: # load data from cache and execute from the first random transform start_run = False data = self._cache[index] for _transform in self.transform.transforms: if not start_run and not isinstance(_transform, Randomizable): continue else: start_run = True data = apply_transform(_transform, data) else: # no cache for this data, execute all the transforms directly data = super(CacheDataset, self).__getitem__(index) return data