# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional
from torch.utils.data import DataLoader as _TorchDataLoader
from torch.utils.data import Dataset, Sampler
from monai.data.utils import list_data_collate, worker_init_fn
__all__ = ["DataLoader"]
class DataLoader(_TorchDataLoader):
    """Generates images/labels for train/validation/testing from dataset.

    It inherits from PyTorch DataLoader and adds callbacks for `collate` and `worker_fn`:
    `collate_fn` is fixed to :func:`monai.data.utils.list_data_collate` and
    `worker_init_fn` to :func:`monai.data.utils.worker_init_fn`.

    Args:
        dataset: dataset from which to load the data.
        batch_size: how many samples per batch to load (default: ``1``).
        shuffle: set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler: defines the strategy to draw samples from
            the dataset. If specified, :attr:`shuffle` must be ``False``.
        batch_sampler: like :attr:`sampler`, but returns a batch of
            indices at a time. Mutually exclusive with :attr:`batch_size`,
            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
        num_workers: how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        pin_memory: If ``True``, the data loader will copy Tensors
            into CUDA pinned memory before returning them. If your data elements
            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
            see the example below.
        drop_last: set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout: if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        multiprocessing_context: specify a valid start method for multi-processing.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 1,
        shuffle: bool = False,
        sampler: Optional[Sampler] = None,
        batch_sampler: Optional[Sampler] = None,
        num_workers: int = 0,
        pin_memory: bool = False,
        drop_last: bool = False,
        timeout: float = 0.0,
        multiprocessing_context: Optional[Callable] = None,
    ) -> None:
        # collate_fn and worker_init_fn are intentionally not exposed:
        # MONAI pins them to its own list-collation and per-worker seeding helpers.
        super().__init__(  # type: ignore[call-overload]
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=list_data_collate,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            multiprocessing_context=multiprocessing_context,
        )