# Source code for monai.data.dataloader

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings

import torch
from torch.utils.data import DataLoader as _TorchDataLoader
from torch.utils.data import Dataset

from monai.data.meta_obj import get_track_meta
from monai.data.utils import list_data_collate, set_rnd, worker_init_fn

# Public API of this module: only the `DataLoader` class is exported by `import *`.
__all__ = ["DataLoader"]


def _is_spawn_context(mp_context) -> bool:
    """
    Tell whether `mp_context` selects the ``"spawn"`` start method.

    Args:
        mp_context: the ``multiprocessing_context`` value passed to the data loader. It may be
            ``None``, a start method name, or an object exposing ``get_start_method()``
            (such as the contexts returned by ``multiprocessing.get_context``).

    Returns:
        ``True`` if the value is (or reports) the ``"spawn"`` start method, ``False`` otherwise.
    """
    get_start_method = getattr(mp_context, "get_start_method", None)
    if callable(get_start_method):
        # context objects do not compare equal to the plain string, so ask them for their name
        return bool(get_start_method() == "spawn")
    return bool(mp_context == "spawn")


class DataLoader(_TorchDataLoader):
    """
    Provides an iterable over the given `dataset`.  It inherits the PyTorch
    DataLoader and adds enhanced `collate_fn` and `worker_init_fn` by default.

    Although this class could be configured to be the same as
    `torch.utils.data.DataLoader`, its default configuration is
    recommended, mainly for the following extra features:

        - It handles MONAI randomizable objects with appropriate random state
          managements for deterministic behaviour.
        - It is aware of the patch-based transform (such as
          :py:class:`monai.transforms.RandSpatialCropSamplesDict`) samples for
          preprocessing with enhanced data collating behaviour.
          See: :py:class:`monai.transforms.Compose`.

    For more details about :py:class:`torch.utils.data.DataLoader`, please see:
    https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader.

    For example, to construct a randomized dataset and iterate with the data loader:

    .. code-block:: python

        import torch

        from monai.data import DataLoader
        from monai.transforms import Randomizable


        class RandomDataset(torch.utils.data.Dataset, Randomizable):
            def __getitem__(self, index):
                return self.R.randint(0, 1000, (1,))

            def __len__(self):
                return 16


        dataset = RandomDataset()
        dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
        for epoch in range(2):
            for i, batch in enumerate(dataloader):
                print(epoch, i, batch.data.numpy().flatten().tolist())

    Args:
        dataset: dataset from which to load the data.
        num_workers: how many subprocesses to use for data loading.
            ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn: default to :py:func:`monai.data.utils.list_data_collate`.
        worker_init_fn: default to :py:func:`monai.data.utils.worker_init_fn`.
        kwargs: other parameters for PyTorch DataLoader.

    A warning is issued when ``multiprocessing_context`` selects the ``"spawn"`` start method
    (given either as a string or as a context object) while metadata tracking is disabled.
    """

    def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None:
        if num_workers == 0:
            # when num_workers > 0, random states are determined by worker_init_fn
            # this is to make the behavior consistent when num_workers == 0
            # torch.int64 doesn't work well on some versions of windows
            generator = kwargs.get("generator")
            _g = torch.random.default_generator if generator is None else generator
            # remember the seed the generator started from so it can be re-applied below
            init_seed = _g.initial_seed()
            _seed = torch.empty((), dtype=torch.int64).random_(generator=_g).item()
            # presumably seeds the randomizable parts of `dataset`; see monai.data.utils.set_rnd
            set_rnd(dataset, int(_seed))
            # re-seed the generator with its initial seed after drawing `_seed` from it
            _g.manual_seed(init_seed)

        # only fill in the MONAI defaults when the caller did not choose their own
        if "collate_fn" not in kwargs:
            kwargs["collate_fn"] = list_data_collate
        if "worker_init_fn" not in kwargs:
            kwargs["worker_init_fn"] = worker_init_fn

        if _is_spawn_context(kwargs.get("multiprocessing_context")) and not get_track_meta():
            warnings.warn(
                "Please be aware: Return type of the dataloader will not be a Tensor as expected but"
                " a MetaTensor instead! This is because 'spawn' creates a new process where _TRACK_META"
                " is initialized to True again. Context:_TRACK_META is set to False and"
                " multiprocessing_context to spawn",
                stacklevel=2,  # attribute the warning to the code constructing the loader
            )

        super().__init__(dataset=dataset, num_workers=num_workers, **kwargs)