# Source code for ``monai.data.dataloader`` (extracted from rendered documentation).

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings

import torch
from torch.utils.data import DataLoader as _TorchDataLoader
from torch.utils.data import Dataset

from monai.data.meta_obj import get_track_meta
from monai.data.utils import list_data_collate, set_rnd, worker_init_fn

__all__ = ["DataLoader"]

class DataLoader(_TorchDataLoader):
    """
    Provides an iterable over the given `dataset`. It inherits the PyTorch
    DataLoader and adds enhanced `collate_fn` and `worker_fn` by default.

    Although this class could be configured to be the same as
    ``torch.utils.data.DataLoader``, its default configuration is
    recommended, mainly for the following extra features:

    - It handles MONAI randomizable objects with appropriate random state
      managements for deterministic behaviour.
    - It is aware of the patch-based transform (such as
      :py:class:`monai.transforms.RandSpatialCropSamplesDict`) samples for
      preprocessing with enhanced data collating behaviour.
      See: :py:class:`monai.transforms.Compose`.

    For more details about :py:class:`torch.utils.data.DataLoader`, please see:
    https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader.

    For example, to construct a randomized dataset and iterate with the data loader:

    .. code-block:: python

        import torch

        from monai.data import DataLoader
        from monai.transforms import Randomizable


        class RandomDataset(torch.utils.data.Dataset, Randomizable):
            def __getitem__(self, index):
                return self.R.randint(0, 1000, (1,))

            def __len__(self):
                return 16


        dataset = RandomDataset()
        dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
        for epoch in range(2):
            for i, batch in enumerate(dataloader):
                print(epoch, i, batch)

    Args:
        dataset: dataset from which to load the data.
        num_workers: how many subprocesses to use for data loading.
            ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn: default to :py:func:`monai.data.utils.list_data_collate`.
        worker_init_fn: default to :py:func:`monai.data.utils.worker_init_fn`.
        kwargs: other parameters for PyTorch DataLoader.
    """

    def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None:
        if num_workers == 0:
            # when num_workers > 0, random states are determined by worker_init_fn
            # this is to make the behavior consistent when num_workers == 0
            # torch.int64 doesn't work well on some versions of windows
            _g = torch.random.default_generator if kwargs.get("generator") is None else kwargs["generator"]
            init_seed = _g.initial_seed()
            _seed = torch.empty((), dtype=torch.int64).random_(generator=_g).item()
            set_rnd(dataset, int(_seed))
            # restore the generator's seed so drawing the dataset seed above
            # does not perturb the caller-visible global random state
            _g.manual_seed(init_seed)
        if "collate_fn" not in kwargs:
            kwargs["collate_fn"] = list_data_collate
        if "worker_init_fn" not in kwargs:
            kwargs["worker_init_fn"] = worker_init_fn
        if (
            "multiprocessing_context" in kwargs
            and kwargs["multiprocessing_context"] == "spawn"
            and not get_track_meta()
        ):
            warnings.warn(
                "Please be aware: Return type of the dataloader will not be a Tensor as expected but"
                " a MetaTensor instead! This is because 'spawn' creates a new process where _TRACK_META"
                " is initialized to True again. Context:_TRACK_META is set to False and"
                " multiprocessing_context to spawn"
            )
        super().__init__(dataset=dataset, num_workers=num_workers, **kwargs)