# Source code for monai.data.dataloader

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional

from torch.utils.data import DataLoader as _TorchDataLoader
from torch.utils.data import Dataset, Sampler

from monai.data.utils import list_data_collate, worker_init_fn

__all__ = ["DataLoader"]


class DataLoader(_TorchDataLoader):
    """Generates images/labels for train/validation/testing from dataset.

    It inherits from PyTorch DataLoader and adds callbacks for `collate` and `worker_fn`.

    Args:
        dataset: dataset from which to load the data.
        batch_size: how many samples per batch to load (default: ``1``).
        shuffle: set to ``True`` to have the data reshuffled at every epoch (default: ``False``).
        sampler: defines the strategy to draw samples from the dataset.
            If specified, :attr:`shuffle` must be ``False``.
        batch_sampler: like :attr:`sampler`, but returns a batch of indices at a time.
            Mutually exclusive with :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
            and :attr:`drop_last`.
        num_workers: how many subprocesses to use for data loading. ``0`` means that the
            data will be loaded in the main process. (default: ``0``)
        pin_memory: If ``True``, the data loader will copy Tensors into CUDA pinned memory
            before returning them. If your data elements are a custom type, or your
            :attr:`collate_fn` returns a batch that is a custom type, see the example below.
        drop_last: set to ``True`` to drop the last incomplete batch, if the dataset size
            is not divisible by the batch size. If ``False`` and the size of dataset is not
            divisible by the batch size, then the last batch will be smaller. (default: ``False``)
        timeout: if positive, the timeout value for collecting a batch from workers.
            Should always be non-negative. (default: ``0``)
        multiprocessing_context: specify a valid start method for multi-processing.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 1,
        shuffle: bool = False,
        sampler: Optional[Sampler] = None,
        batch_sampler: Optional[Sampler] = None,
        num_workers: int = 0,
        pin_memory: bool = False,
        drop_last: bool = False,
        timeout: float = 0.0,
        multiprocessing_context: Optional[Callable] = None,
    ) -> None:
        # Delegate to torch's DataLoader, pinning MONAI's collate and
        # worker-init callbacks; all other options pass through unchanged.
        super().__init__(  # type: ignore[call-overload]
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            # MONAI collate: handles lists of dicts emitted by MONAI transforms
            # (see monai.data.utils.list_data_collate).
            collate_fn=list_data_collate,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            # NOTE(review): presumably re-seeds per-worker RNG so augmentations
            # differ across workers — confirm against monai.data.utils.worker_init_fn.
            worker_init_fn=worker_init_fn,
            multiprocessing_context=multiprocessing_context,
        )