# Source code for monai.data.samplers

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence

import torch
from torch.utils.data import Dataset
from torch.utils.data import DistributedSampler as _TorchDistributedSampler

__all__ = ["DistributedSampler", "DistributedWeightedRandomSampler"]


class DistributedSampler(_TorchDistributedSampler):
    """
    Enhance PyTorch DistributedSampler to support non-evenly divisible sampling.

    Args:
        dataset: Dataset used for sampling.
        even_divisible: if False, different ranks can have different data length.
            for example, input data: [1, 2, 3, 4, 5], rank 0: [1, 3, 5], rank 1: [2, 4].
            In this mode every item is used exactly once (no padding, no dropping),
            so ``drop_last`` has no effect on which items are sampled.
        num_replicas: number of processes participating in distributed training.
            by default, `world_size` is retrieved from the current distributed group.
        rank: rank of the current process within `num_replicas`. by default,
            `rank` is retrieved from the current distributed group.
        shuffle: if `True`, sampler will shuffle the indices, default to True.
        kwargs: additional arguments for `DistributedSampler` super class, can be `seed` and `drop_last`.

    Raises:
        ValueError: when ``even_divisible`` is False and the dataset has fewer items
            than ``num_replicas``.

    More information about DistributedSampler, please check:
    https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler.

    """

    def __init__(
        self,
        dataset: Dataset,
        even_divisible: bool = True,
        num_replicas: int | None = None,
        rank: int | None = None,
        shuffle: bool = True,
        **kwargs,
    ):
        super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, **kwargs)
        if not even_divisible:
            data_len = len(dataset)  # type: ignore
            if data_len < self.num_replicas:
                raise ValueError("the dataset length is less than the number of participating ranks.")
            # With ``total_size == data_len`` the parent iterator neither pads nor truncates,
            # and then keeps every ``num_replicas``-th index starting at ``rank``.
            # Count that strided slice directly so ``num_samples`` agrees with what is yielded
            # for both ``drop_last=False`` and ``drop_last=True`` (torch's ``__iter__`` checks
            # that the slice length equals ``num_samples``).
            self.total_size = data_len
            self.num_samples = len(range(self.rank, data_len, self.num_replicas))
class DistributedWeightedRandomSampler(DistributedSampler):
    """
    Extend the `DistributedSampler` to support weighted sampling.
    Refer to `torch.utils.data.WeightedRandomSampler`, for more details please check:
    https://pytorch.org/docs/stable/data.html#torch.utils.data.WeightedRandomSampler.

    Args:
        dataset: Dataset used for sampling.
        weights: a sequence of weights, not necessary summing up to one, length should exactly
            match the full dataset.
        num_samples_per_rank: number of samples to draw for every rank, sample from
            the distributed subset of dataset.
            if None, default to the length of dataset split by DistributedSampler.
        generator: PyTorch Generator used in sampling.
        even_divisible: if False, different ranks can have different data length.
            for example, input data: [1, 2, 3, 4, 5], rank 0: [1, 3, 5], rank 1: [2, 4].
        num_replicas: number of processes participating in distributed training.
            by default, `world_size` is retrieved from the current distributed group.
        rank: rank of the current process within `num_replicas`. by default,
            `rank` is retrieved from the current distributed group.
        kwargs: additional arguments for `DistributedSampler` super class, can be `seed` and `drop_last`.

    """

    def __init__(
        self,
        dataset: Dataset,
        weights: Sequence[float],
        num_samples_per_rank: int | None = None,
        generator: torch.Generator | None = None,
        even_divisible: bool = True,
        num_replicas: int | None = None,
        rank: int | None = None,
        **kwargs,
    ):
        # shuffle the per-rank subset unless the caller explicitly passes ``shuffle``
        kwargs.setdefault("shuffle", True)
        super().__init__(dataset=dataset, even_divisible=even_divisible, num_replicas=num_replicas, rank=rank, **kwargs)
        self.weights = weights
        self.num_samples_per_rank = num_samples_per_rank if num_samples_per_rank is not None else self.num_samples
        self.generator = generator

    def __iter__(self):
        # dataset indices assigned to this rank by the parent sampler
        indices = list(super().__iter__())
        # weights of this rank's subset only, in the same order as ``indices``
        weights = torch.as_tensor([self.weights[i] for i in indices], dtype=torch.double)
        # sample positions into ``indices`` with replacement, proportional to the weights
        rand_tensor = torch.multinomial(weights, self.num_samples_per_rank, True, generator=self.generator)
        # ``tolist`` gives plain Python ints rather than 0-d tensors for the list lookup
        for pos in rand_tensor.tolist():
            yield indices[pos]

    def __len__(self):
        return self.num_samples_per_rank