# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
import torch
from torch.utils.data import Dataset
from torch.utils.data import DistributedSampler as _TorchDistributedSampler
__all__ = ["DistributedSampler", "DistributedWeightedRandomSampler"]
class DistributedSampler(_TorchDistributedSampler):
"""
Enhance PyTorch DistributedSampler to support non-evenly divisible sampling.
Args:
dataset: Dataset used for sampling.
even_divisible: if False, different ranks can have different data length.
for example, input data: [1, 2, 3, 4, 5], rank 0: [1, 3, 5], rank 1: [2, 4].
num_replicas: number of processes participating in distributed training.
by default, `world_size` is retrieved from the current distributed group.
rank: rank of the current process within `num_replicas`. by default,
`rank` is retrieved from the current distributed group.
        shuffle: if `True`, sampler will shuffle the indices, defaults to `True`.
        kwargs: additional arguments for the `DistributedSampler` super class, such as `seed` and `drop_last`.

    For more information about `DistributedSampler`, please check:
    https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler.
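
    Example (a minimal usage sketch; ``my_dataset`` stands for a hypothetical map-style dataset and
    a ``torch.distributed`` process group is assumed to be initialized already, e.g. via
    ``torch.distributed.init_process_group``)::

        sampler = DistributedSampler(dataset=my_dataset, even_divisible=False, shuffle=True)
        loader = torch.utils.data.DataLoader(my_dataset, batch_size=2, sampler=sampler)
        for epoch in range(10):
            sampler.set_epoch(epoch)  # vary the shuffling per epoch, consistently across ranks
            for batch in loader:
                ...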
"""
def __init__(
self,
dataset: Dataset,
even_divisible: bool = True,
num_replicas: int | None = None,
rank: int | None = None,
shuffle: bool = True,
**kwargs,
):
super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, **kwargs)
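        # by default the parent sampler pads the indices so that every rank gets the same number
        # of samples; when `even_divisible=False`, drop that padding so no sample is duplicated.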
if not even_divisible:
data_len = len(dataset) # type: ignore
if data_len < self.num_replicas:
raise ValueError("the dataset length is less than the number of participating ranks.")
extra_size = self.total_size - data_len
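            # the last `extra_size` ranks would have received a padded (duplicated) index,
            # so they draw one sample fewer instead.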
if self.rank + extra_size >= self.num_replicas:
self.num_samples -= 1
self.total_size = data_len
class DistributedWeightedRandomSampler(DistributedSampler):
"""
Extend the `DistributedSampler` to support weighted sampling.
    Refer to `torch.utils.data.WeightedRandomSampler`; for more details, please check:
    https://pytorch.org/docs/stable/data.html#torch.utils.data.WeightedRandomSampler.
Args:
dataset: Dataset used for sampling.
        weights: a sequence of weights, not necessarily summing up to one, whose length should
            exactly match the length of the full dataset.
        num_samples_per_rank: number of samples to draw for every rank, sampled from this rank's
            distributed subset of the dataset.
            if None, defaults to the length of the dataset split produced by `DistributedSampler`.
generator: PyTorch Generator used in sampling.
even_divisible: if False, different ranks can have different data length.
            for example, input data: [1, 2, 3, 4, 5], rank 0: [1, 3, 5], rank 1: [2, 4].
num_replicas: number of processes participating in distributed training.
by default, `world_size` is retrieved from the current distributed group.
rank: rank of the current process within `num_replicas`. by default,
`rank` is retrieved from the current distributed group.
        kwargs: additional arguments for the `DistributedSampler` super class, such as `seed` and `drop_last`.
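
    Example (a minimal usage sketch; ``my_dataset`` stands for a hypothetical map-style dataset,
    the weights are made-up values with one entry per sample, and a ``torch.distributed`` process
    group is assumed to be initialized already)::

        weights = [0.5, 1.0, 1.0, 2.0, 0.5]  # length must equal len(my_dataset)
        sampler = DistributedWeightedRandomSampler(dataset=my_dataset, weights=weights, even_divisible=False)
        loader = torch.utils.data.DataLoader(my_dataset, batch_size=2, sampler=sampler)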
"""
def __init__(
self,
dataset: Dataset,
weights: Sequence[float],
num_samples_per_rank: int | None = None,
generator: torch.Generator | None = None,
even_divisible: bool = True,
num_replicas: int | None = None,
rank: int | None = None,
**kwargs,
):
kwargs.setdefault("shuffle", True)
super().__init__(dataset=dataset, even_divisible=even_divisible, num_replicas=num_replicas, rank=rank, **kwargs)
self.weights = weights
self.num_samples_per_rank = num_samples_per_rank if num_samples_per_rank is not None else self.num_samples
self.generator = generator
def __iter__(self):
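        # the parent iterator yields this rank's shard of (optionally shuffled) indices;
        # resample from that shard with replacement, weighted by the matching entries of `self.weights`.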
indices = list(super().__iter__())
weights = torch.as_tensor([self.weights[i] for i in indices], dtype=torch.double)
# sample based on the provided weights
rand_tensor = torch.multinomial(weights, self.num_samples_per_rank, True, generator=self.generator)
for i in rand_tensor:
yield indices[i]
def __len__(self):
return self.num_samples_per_rank