# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Hashable, Mapping, Sequence
from typing import Any
import numpy as np
import torch
from monai.config import KeysCollection, SequenceStr
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.transforms.smooth_field.array import (
RandSmoothDeform,
RandSmoothFieldAdjustContrast,
RandSmoothFieldAdjustIntensity,
)
from monai.transforms.transform import MapTransform, RandomizableTransform
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, convert_to_tensor, ensure_tuple_rep
# Public API of this module: every dictionary transform is exported under its
# canonical `...d` name and under the `...D` and `...Dict` alias spellings.
__all__ = [
    # canonical names
    "RandSmoothFieldAdjustContrastd",
    "RandSmoothFieldAdjustIntensityd",
    "RandSmoothDeformd",
    # `D` aliases
    "RandSmoothFieldAdjustContrastD",
    "RandSmoothFieldAdjustIntensityD",
    "RandSmoothDeformD",
    # `Dict` aliases
    "RandSmoothFieldAdjustContrastDict",
    "RandSmoothFieldAdjustIntensityDict",
    "RandSmoothDeformDict",
]
class RandSmoothFieldAdjustContrastd(RandomizableTransform, MapTransform):
    """
    Dictionary version of RandSmoothFieldAdjustContrast.

    A single wrapped RandSmoothFieldAdjustContrast instance is shared by all keys. Its field is randomized once per
    invocation, so the same field is applied to every selected key. The `mode` parameter specifying the
    interpolation mode for the field can be a single value or a sequence of values with one for each key in `keys`.

    Args:
        keys: key names to apply the augment to
        spatial_size: size of input arrays, all arrays stated in `keys` must have same dimensions
        rand_size: size of the randomized field to start from
        pad: number of pixels/voxels along the edges of the field to pad with 0
        mode: interpolation mode to use when upsampling
        align_corners: if True align the corners when upsampling field
        prob: probability transform is applied
        gamma: (min, max) range for exponential field
        device: Pytorch device to define field on
    """

    backend = RandSmoothFieldAdjustContrast.backend

    def __init__(
        self,
        keys: KeysCollection,
        spatial_size: Sequence[int],
        rand_size: Sequence[int],
        pad: int = 0,
        mode: SequenceStr = InterpolateMode.AREA,
        align_corners: bool | None = None,
        prob: float = 0.1,
        gamma: Sequence[float] | float = (0.5, 4.5),
        device: torch.device | None = None,
    ):
        RandomizableTransform.__init__(self, prob)
        MapTransform.__init__(self, keys)

        # one interpolation mode per key; a single value is repeated for every key
        self.mode = ensure_tuple_rep(mode, len(self.keys))

        # The wrapped transform is built with prob=1.0: whether to apply anything is decided by this
        # wrapper's own `prob`. The first mode is only a starting value, `__call__` sets the mode per key.
        self.trans = RandSmoothFieldAdjustContrast(
            spatial_size=spatial_size,
            rand_size=rand_size,
            pad=pad,
            mode=self.mode[0],
            align_corners=align_corners,
            prob=1.0,
            gamma=gamma,
            device=device,
        )

    def set_random_state(
        self, seed: int | None = None, state: np.random.RandomState | None = None
    ) -> RandSmoothFieldAdjustContrastd:
        """
        Set the random state of this transform and of the wrapped array transform.

        Args:
            seed: seed forwarded to both this object and the wrapped transform
            state: random state forwarded to both this object and the wrapped transform

        Returns:
            this instance
        """
        super().set_random_state(seed, state)
        self.trans.set_random_state(seed, state)
        return self

    def randomize(self, data: Any | None = None) -> None:
        """
        Decide whether to apply the transform and, only if so, randomize the wrapped transform.

        Args:
            data: unused, `None` is always passed on to the parent implementation
        """
        super().randomize(None)

        if not self._do_transform:
            return

        self.trans.randomize()

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
        """
        Apply the wrapped transform to every selected key of `data`.

        Args:
            data: mapping holding the arrays/tensors stored under `keys`

        Returns:
            a shallow copy of `data` as a dict; selected entries are either transformed or, when the transform
            is not applied for this call, only converted with `convert_to_tensor`
        """
        self.randomize()
        result = dict(data)

        if self._do_transform:
            for index, key in enumerate(self.key_iterator(result)):
                self.trans.set_mode(self.mode[index % len(self.mode)])
                # False is passed as the second positional argument; presumably the wrapped transform's
                # `randomize` flag, so the field drawn in `randomize()` is reused for every key
                result[key] = self.trans(result[key], False)
        else:
            for key in self.key_iterator(result):
                result[key] = convert_to_tensor(result[key], track_meta=get_track_meta())

        return result
class RandSmoothFieldAdjustIntensityd(RandomizableTransform, MapTransform):
    """
    Dictionary version of RandSmoothFieldAdjustIntensity.

    The field is randomized once per invocation by default so the same field is applied to every selected key. The
    `mode` parameter specifying interpolation mode for the field can be a single value or a sequence of values with
    one for each key in `keys`.

    Args:
        keys: key names to apply the augment to
        spatial_size: size of input arrays, all arrays stated in `keys` must have same dimensions
        rand_size: size of the randomized field to start from
        pad: number of pixels/voxels along the edges of the field to pad with 0
        mode: interpolation mode to use when upsampling
        align_corners: if True align the corners when upsampling field
        prob: probability transform is applied
        gamma: (min, max) range of intensity multipliers
        device: Pytorch device to define field on
    """

    backend = RandSmoothFieldAdjustIntensity.backend

    def __init__(
        self,
        keys: KeysCollection,
        spatial_size: Sequence[int],
        rand_size: Sequence[int],
        pad: int = 0,
        mode: SequenceStr = InterpolateMode.AREA,
        align_corners: bool | None = None,
        prob: float = 0.1,
        gamma: Sequence[float] | float = (0.1, 1.0),
        device: torch.device | None = None,
    ):
        RandomizableTransform.__init__(self, prob)
        MapTransform.__init__(self, keys)

        # one interpolation mode per key; a single value is repeated for every key
        self.mode = ensure_tuple_rep(mode, len(self.keys))

        # The wrapped transform is built with prob=1.0: whether to apply anything is decided by this
        # wrapper's own `prob`. The first mode is only a starting value, `__call__` sets the mode per key.
        self.trans = RandSmoothFieldAdjustIntensity(
            spatial_size=spatial_size,
            rand_size=rand_size,
            pad=pad,
            mode=self.mode[0],
            align_corners=align_corners,
            prob=1.0,
            gamma=gamma,
            device=device,
        )

    def set_random_state(
        self, seed: int | None = None, state: np.random.RandomState | None = None
    ) -> RandSmoothFieldAdjustIntensityd:
        """
        Set the random state of this transform and of the wrapped array transform.

        Args:
            seed: seed forwarded to both this object and the wrapped transform
            state: random state forwarded to both this object and the wrapped transform

        Returns:
            this instance
        """
        super().set_random_state(seed, state)
        self.trans.set_random_state(seed, state)
        return self

    def randomize(self, data: Any | None = None) -> None:
        """
        Decide whether to apply the transform and, only if so, randomize the wrapped transform.

        Args:
            data: unused, `None` is always passed on to the parent implementation
        """
        super().randomize(None)

        # Only draw a new field when it will actually be used. `__call__` never touches the wrapped
        # transform when `_do_transform` is False, so randomizing it then is wasted work. This matches
        # the behaviour of RandSmoothFieldAdjustContrastd.randomize.
        if self._do_transform:
            self.trans.randomize()

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
        """
        Apply the wrapped transform to every selected key of `data`.

        Args:
            data: mapping holding the arrays/tensors stored under `keys`

        Returns:
            a shallow copy of `data` as a dict; selected entries are either transformed or, when the transform
            is not applied for this call, only converted with `convert_to_tensor`
        """
        self.randomize()
        d = dict(data)

        if not self._do_transform:
            # not applied this time: only normalise the selected entries to tensors
            for key in self.key_iterator(d):
                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
            return d

        for idx, key in enumerate(self.key_iterator(d)):
            self.trans.set_mode(self.mode[idx % len(self.mode)])
            # False is passed as the second positional argument; presumably the wrapped transform's
            # `randomize` flag, so the field drawn in `randomize()` is reused for every key
            d[key] = self.trans(d[key], False)

        return d
# Alias spellings for each dictionary transform: the `D` and `Dict` suffixes
# refer to the very same class objects as the canonical `...d` names.
RandSmoothDeformD = RandSmoothDeformd
RandSmoothDeformDict = RandSmoothDeformd

RandSmoothFieldAdjustIntensityD = RandSmoothFieldAdjustIntensityd
RandSmoothFieldAdjustIntensityDict = RandSmoothFieldAdjustIntensityd

RandSmoothFieldAdjustContrastD = RandSmoothFieldAdjustContrastd
RandSmoothFieldAdjustContrastDict = RandSmoothFieldAdjustContrastd