Source code for monai.transforms.post.dictionary

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A collection of dictionary-based wrappers around the "vanilla" transforms for model output tensors
defined in :py:class:`monai.transforms.post.array`.

Class names end with 'd' to denote dictionary-based transforms.
"""

from typing import Optional, Union, Sequence, Callable

from monai.config import KeysCollection
from monai.utils import ensure_tuple_rep
from monai.transforms.compose import MapTransform
from monai.transforms.post.array import (
    SplitChannel,
    Activations,
    AsDiscrete,
    KeepLargestConnectedComponent,
    LabelToContour,
)


class SplitChanneld(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.SplitChannel`.
    All the input specified by `keys` should be split into the same count of data.
    """

    def __init__(
        self,
        keys: KeysCollection,
        output_postfixes: Sequence[str],
        to_onehot: Union[Sequence[bool], bool] = False,
        num_classes: Optional[Union[Sequence[int], int]] = None,
    ):
        """
        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            output_postfixes: the postfixes to construct keys to store split data.
                for example: if the key of input data is `pred` and split 2 classes, the output
                data keys will be: pred_(output_postfixes[0]), pred_(output_postfixes[1])
            to_onehot: whether to convert the data to One-Hot format, default is False.
                it also can be a sequence of bool, each element corresponds to a key in ``keys``.
            num_classes: the class number used to convert to One-Hot format if `to_onehot` is True.
                it also can be a sequence of int, each element corresponds to a key in ``keys``.

        Raises:
            ValueError: when ``output_postfixes`` is ``None`` or empty, i.e. there are no
                key postfixes to store the split data under.

        """
        super().__init__(keys)
        # Fail at construction time instead of on the first call: without postfixes
        # there is no key under which any split result could be stored.
        if output_postfixes is None or len(output_postfixes) == 0:
            raise ValueError("must specify key postfixes to store splitted data.")
        self.output_postfixes = output_postfixes
        # Broadcast scalar options so that every key has its own setting.
        self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys))
        self.num_classes = ensure_tuple_rep(num_classes, len(self.keys))
        self.splitter = SplitChannel()

    def __call__(self, data):
        """
        Split every item selected by ``keys`` and store the pieces under new keys.

        Args:
            data: mapping that contains an entry for each key in ``keys``.

        Returns:
            A shallow copy of ``data`` (as a ``dict``) that keeps the original entries and
            additionally holds one entry ``f"{key}_{postfix}"`` per split result.

        Raises:
            AssertionError: when the number of split results of an item differs from
                the number of ``output_postfixes``.

        """
        d = dict(data)
        for key, to_onehot, num_classes in zip(self.keys, self.to_onehot, self.num_classes):
            rets = self.splitter(d[key], to_onehot, num_classes)
            # Raised explicitly (not via `assert`) so the check is not stripped under `python -O`.
            if len(self.output_postfixes) != len(rets):
                raise AssertionError("count of split results must match output_postfixes.")
            for postfix, r in zip(self.output_postfixes, rets):
                d[f"{key}_{postfix}"] = r
        return d
class Activationsd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.Activations`.
    Add activation layers to the input data specified by `keys`.
    """

    def __init__(
        self,
        keys: KeysCollection,
        sigmoid: Union[Sequence[bool], bool] = False,
        softmax: Union[Sequence[bool], bool] = False,
        other: Optional[Union[Sequence[Callable], Callable]] = None,
    ) -> None:
        """
        Args:
            keys: keys of the corresponding items to model output and label.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            sigmoid: whether to execute sigmoid function on model output before transform.
                A single bool, or a sequence of bool with one element per key in ``keys``.
            softmax: whether to execute softmax function on model output before transform.
                A single bool, or a sequence of bool with one element per key in ``keys``.
            other: callable function to execute other activation layers,
                for example: `other = lambda x: torch.tanh(x)`.
                A single Callable, or a sequence of Callable with one element per key in ``keys``.

        """
        super().__init__(keys)
        key_count = len(self.keys)
        # Every option is expanded to one value per key.
        self.sigmoid = ensure_tuple_rep(sigmoid, key_count)
        self.softmax = ensure_tuple_rep(softmax, key_count)
        self.other = ensure_tuple_rep(other, key_count)
        self.converter = Activations()

    def __call__(self, data):
        """
        Run the wrapped converter on each item selected by ``keys``, using that key's
        own ``sigmoid`` / ``softmax`` / ``other`` settings, and return the results in
        a shallow ``dict`` copy of ``data``.
        """
        result = dict(data)
        for position, name in enumerate(self.keys):
            item = result[name]
            use_sigmoid = self.sigmoid[position]
            use_softmax = self.softmax[position]
            extra_fn = self.other[position]
            result[name] = self.converter(item, use_sigmoid, use_softmax, extra_fn)
        return result
class AsDiscreted(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.AsDiscrete`.
    """

    def __init__(
        self,
        keys: KeysCollection,
        argmax: Union[Sequence[bool], bool] = False,
        to_onehot: Union[Sequence[bool], bool] = False,
        n_classes: Optional[Union[Sequence[int], int]] = None,
        threshold_values: Union[Sequence[bool], bool] = False,
        logit_thresh: Union[Sequence[float], float] = 0.5,
    ) -> None:
        """
        Args:
            keys: keys of the corresponding items to model output and label.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            argmax: whether to execute argmax function on input data before transform.
                A single bool, or a sequence of bool with one element per key in ``keys``.
            to_onehot: whether to convert input data into the one-hot format. Defaults to False.
                A single bool, or a sequence of bool with one element per key in ``keys``.
            n_classes: the number of classes to convert to One-Hot format.
                A single int, or a sequence of int with one element per key in ``keys``.
            threshold_values: whether threshold the float value to int number 0 or 1, default is False.
                A single bool, or a sequence of bool with one element per key in ``keys``.
            logit_thresh: the threshold value for thresholding operation, default is 0.5.
                A single float, or a sequence of float with one element per key in ``keys``.

        """
        super().__init__(keys)
        key_count = len(self.keys)
        # Every option is expanded to one value per key.
        self.argmax = ensure_tuple_rep(argmax, key_count)
        self.to_onehot = ensure_tuple_rep(to_onehot, key_count)
        self.n_classes = ensure_tuple_rep(n_classes, key_count)
        self.threshold_values = ensure_tuple_rep(threshold_values, key_count)
        self.logit_thresh = ensure_tuple_rep(logit_thresh, key_count)
        self.converter = AsDiscrete()

    def __call__(self, data):
        """
        Run the wrapped converter on each item selected by ``keys`` with that key's
        own settings and return the results in a shallow ``dict`` copy of ``data``.
        """
        result = dict(data)
        for position, name in enumerate(self.keys):
            item = result[name]
            # Per-key options, in the positional order the converter is called with.
            options = (
                self.argmax[position],
                self.to_onehot[position],
                self.n_classes[position],
                self.threshold_values[position],
                self.logit_thresh[position],
            )
            result[name] = self.converter(item, *options)
        return result
class KeepLargestConnectedComponentd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.KeepLargestConnectedComponent`.
    """

    def __init__(
        self,
        keys: KeysCollection,
        applied_labels: Union[Sequence[int], int],
        independent: bool = True,
        connectivity: Optional[int] = None,
    ) -> None:
        """
        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            applied_labels: Labels for applying the connected component on.
                If only one channel. The pixel whose value is not in this list will remain unchanged.
                If the data is in one-hot format, this is the channel indexes to apply transform.
            independent: consider several labels as a whole or independent, default is `True`.
                Example use case would be segment label 1 is liver and label 2 is liver tumor, in that case
                you want this "independent" to be specified as False.
            connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
                Accepted values are ranging from 1 to input.ndim. If ``None``, a full
                connectivity of ``input.ndim`` is used.

        """
        super().__init__(keys)
        # A single converter instance is shared by all keys.
        component_filter = KeepLargestConnectedComponent(applied_labels, independent, connectivity)
        self.converter = component_filter

    def __call__(self, data):
        """
        Run the wrapped converter on each item selected by ``keys`` and return the
        results in a shallow ``dict`` copy of ``data``.
        """
        result = dict(data)
        for name in self.keys:
            item = result[name]
            result[name] = self.converter(item)
        return result
class LabelToContourd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.LabelToContour`.
    """

    def __init__(self, keys: KeysCollection, kernel_type: str = "Laplace") -> None:
        """
        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            kernel_type: the method applied to do edge detection, default is "Laplace".

        """
        super().__init__(keys)
        # A single converter instance is shared by all keys.
        contour_filter = LabelToContour(kernel_type=kernel_type)
        self.converter = contour_filter

    def __call__(self, data):
        """
        Run the wrapped converter on each item selected by ``keys`` and return the
        results in a shallow ``dict`` copy of ``data``.
        """
        result = dict(data)
        for name in self.keys:
            item = result[name]
            result[name] = self.converter(item)
        return result
# Alternative spellings of each dictionary transform: ``<Name>D`` and ``<Name>Dict``
# are bound to the very same class object as ``<Name>d``.
SplitChannelD = SplitChanneld
SplitChannelDict = SplitChanneld

ActivationsD = Activationsd
ActivationsDict = Activationsd

AsDiscreteD = AsDiscreted
AsDiscreteDict = AsDiscreted

KeepLargestConnectedComponentD = KeepLargestConnectedComponentd
KeepLargestConnectedComponentDict = KeepLargestConnectedComponentd

LabelToContourD = LabelToContourd
LabelToContourDict = LabelToContourd