Source code for

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Callable

from import DataLoader as _TorchDataLoader
from import list_data_collate, worker_init_fn

__all__ = ["DataLoader"]

class DataLoader(_TorchDataLoader):
    """Batch loader producing images/labels for train/validation/testing.

    A thin subclass of the PyTorch ``DataLoader``. Every constructor argument
    is handed to the parent class unchanged; in addition, the parent is always
    given ``list_data_collate`` as its ``collate_fn`` and ``worker_init_fn``
    as its ``worker_init_fn``, so neither hook is configurable here.

    Args:
        dataset (Dataset): dataset from which to load the data.
        batch_size: how many samples per batch to load (default: ``1``).
        shuffle: set to ``True`` to have the data reshuffled at every epoch
            (default: ``False``).
        sampler (Sampler, optional): defines the strategy to draw samples
            from the dataset. If specified, :attr:`shuffle` must be ``False``.
        batch_sampler (Sampler, optional): like :attr:`sampler`, but returns
            a batch of indices at a time. Mutually exclusive with
            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, and
            :attr:`drop_last`.
        num_workers: how many subprocesses to use for data loading. ``0``
            means that the data will be loaded in the main process.
            (default: ``0``)
        pin_memory: if ``True``, the data loader will copy Tensors into CUDA
            pinned memory before returning them.
        drop_last: set to ``True`` to drop the last incomplete batch, if the
            dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the
            last batch will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for
            collecting a batch from workers. Should always be non-negative.
            (default: ``0``)
        multiprocessing_context: specify a valid start method for
            multi-processing.

    """

    def __init__(
        self,
        dataset,
        batch_size: Optional[int] = 1,
        shuffle: bool = False,
        sampler=None,
        batch_sampler=None,
        num_workers: Optional[int] = 0,
        pin_memory=False,
        drop_last=False,
        timeout=0,
        multiprocessing_context: Optional[Callable] = None,
    ):
        # Options the caller controls; passed through to the parent untouched.
        passthrough = {
            "dataset": dataset,
            "batch_size": batch_size,
            "shuffle": shuffle,
            "sampler": sampler,
            "batch_sampler": batch_sampler,
            "num_workers": num_workers,
            "pin_memory": pin_memory,
            "drop_last": drop_last,
            "timeout": timeout,
            "multiprocessing_context": multiprocessing_context,
        }
        # The collate and worker-init hooks are fixed by this class rather
        # than exposed as constructor parameters.
        super().__init__(
            collate_fn=list_data_collate,
            worker_init_fn=worker_init_fn,
            **passthrough,
        )