# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import random
from enum import Enum
from monai.utils import deprecated
# Public names exported by a star-import of this module.
# NOTE(review): "TransformBackends" and "MetaKeys" are listed below, but no class with either name is
# defined in the visible part of this file — confirm they are defined elsewhere in the module,
# otherwise `from <this module> import *` raises AttributeError.
# JITMetadataKeys, ProbMapKeys, PatchKeys and WSIPatchKeys are defined in this file but are not listed,
# so a star-import does not export them.
__all__ = [
    "StrEnum",
    "NumpyPadMode",
    "GridSampleMode",
    "SplineMode",
    "InterpolateMode",
    "UpsampleMode",
    "BlendMode",
    "PytorchPadMode",
    "NdimageMode",
    "GridSamplePadMode",
    "Average",
    "MetricReduction",
    "LossReduction",
    "DiceCEReduction",
    "Weight",
    "ChannelMatching",
    "SkipMode",
    "Method",
    "TraceKeys",
    "TraceStatusKeys",
    "CommonKeys",
    "GanKeys",
    "PostFix",
    "ForwardMode",
    "TransformBackends",
    "CompInitMode",
    "BoxModeName",
    "GridPatchSort",
    "FastMRIKeys",
    "SpaceKeys",
    "MetaKeys",
    "ColorOrder",
    "EngineStatsKeys",
    "DataStatsKeys",
    "ImageStatsKeys",
    "LabelStatsKeys",
    "AlgoEnsembleKeys",
    "HoVerNetMode",
    "HoVerNetBranch",
    "LazyAttr",
    "BundleProperty",
    "BundlePropertyConfig",
    "AlgoKeys",
]
[docs]
class StrEnum(str, Enum):
    """
    Base class for string-valued enums whose members print as their raw value.

    Members are ``str`` instances, so they compare equal to their plain string value,
    and both ``str()`` and ``repr()`` of a member return that value.

    .. code-block:: python

        from monai.utils import StrEnum, look_up_option

        class Example(StrEnum):
            MODE_A = "A"
            MODE_B = "B"

        assert (list(Example) == ["A", "B"])
        assert Example.MODE_A == "A"
        assert str(Example.MODE_A) == "A"
        assert look_up_option("A", Example) == "A"
    """

    def __str__(self):
        # report the bare value instead of the default "ClassName.MEMBER" form
        return self.value

    # repr deliberately reuses the very same function, so it also yields the bare value
    __repr__ = __str__
[docs]
class NumpyPadMode(StrEnum):
    """
    Padding modes understood by ``numpy.pad``.

    See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
    """

    CONSTANT = "constant"
    EDGE = "edge"
    LINEAR_RAMP = "linear_ramp"
    MAXIMUM = "maximum"
    MEAN = "mean"
    MEDIAN = "median"
    MINIMUM = "minimum"
    REFLECT = "reflect"
    SYMMETRIC = "symmetric"
    WRAP = "wrap"
    EMPTY = "empty"
[docs]
class NdimageMode(StrEnum):
    """
    Boundary-extension modes used by ``scipy.ndimage`` when interpolating.

    Each option determines how the input array is extended beyond its boundaries.
    See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
    """

    REFLECT = "reflect"
    GRID_MIRROR = "grid-mirror"
    CONSTANT = "constant"
    GRID_CONSTANT = "grid-constant"
    NEAREST = "nearest"
    MIRROR = "mirror"
    GRID_WRAP = "grid-wrap"
    WRAP = "wrap"
[docs]
class GridSampleMode(StrEnum):
    """
    Interpolation modes of ``torch.nn.functional.grid_sample``.

    See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

    Note:
        (paraphrased from the ``torch.nn.functional.grid_sample`` documentation)
        ``mode='bicubic'`` supports only 4-D input.
        With ``mode='bilinear'`` and a 5-D input, the interpolation actually performed is trilinear;
        with a 4-D input it is genuinely bilinear.
    """

    NEAREST = "nearest"
    BILINEAR = "bilinear"
    BICUBIC = "bicubic"
[docs]
class SplineMode(StrEnum):
    """
    Spline interpolation orders (0 through 5).

    The members are declared with integer orders, but because this is a ``StrEnum`` each member
    is a string instance of its digit, e.g. ``SplineMode.THREE == "3"``.

    See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
    """

    # integer literals are kept on purpose: they are the raw values the enum machinery registers
    ZERO = 0
    ONE = 1
    TWO = 2
    THREE = 3
    FOUR = 4
    FIVE = 5
[docs]
class InterpolateMode(StrEnum):
    """
    Interpolation modes of ``torch.nn.functional.interpolate``.

    See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
    """

    NEAREST = "nearest"
    NEAREST_EXACT = "nearest-exact"
    LINEAR = "linear"
    BILINEAR = "bilinear"
    BICUBIC = "bicubic"
    TRILINEAR = "trilinear"
    AREA = "area"
[docs]
class UpsampleMode(StrEnum):
    """
    Upsampling strategies.

    See also: :py:class:`monai.networks.blocks.UpSample`
    """

    DECONV = "deconv"
    DECONVGROUP = "deconvgroup"
    NONTRAINABLE = "nontrainable"  # e.g. using torch.nn.Upsample
    PIXELSHUFFLE = "pixelshuffle"
[docs]
class BlendMode(StrEnum):
    """
    Blending modes for importance maps.

    See also: :py:class:`monai.data.utils.compute_importance_map`
    """

    CONSTANT = "constant"
    GAUSSIAN = "gaussian"
[docs]
class PytorchPadMode(StrEnum):
    """
    Padding modes of ``torch.nn.functional.pad``.

    See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
    """

    CONSTANT = "constant"
    REFLECT = "reflect"
    REPLICATE = "replicate"
    CIRCULAR = "circular"
[docs]
class GridSamplePadMode(StrEnum):
    """
    Padding modes of ``torch.nn.functional.grid_sample``.

    See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
    """

    ZEROS = "zeros"
    BORDER = "border"
    REFLECTION = "reflection"
[docs]
class Average(StrEnum):
    """
    Averaging options for multi-class ROC AUC computation.

    See also: :py:func:`monai.metrics.rocauc.compute_roc_auc`
    """

    MACRO = "macro"
    WEIGHTED = "weighted"
    MICRO = "micro"
    NONE = "none"
[docs]
class MetricReduction(StrEnum):
    """
    Reduction options for metric values.

    See also: :py:func:`monai.metrics.utils.do_metric_reduction`
    """

    NONE = "none"
    MEAN = "mean"
    SUM = "sum"
    MEAN_BATCH = "mean_batch"
    SUM_BATCH = "sum_batch"
    MEAN_CHANNEL = "mean_channel"
    SUM_CHANNEL = "sum_channel"
[docs]
class LossReduction(StrEnum):
    """
    Reduction options for loss values.

    See also:
        - :py:class:`monai.losses.dice.DiceLoss`
        - :py:class:`monai.losses.dice.GeneralizedDiceLoss`
        - :py:class:`monai.losses.focal_loss.FocalLoss`
        - :py:class:`monai.losses.tversky.TverskyLoss`
    """

    NONE = "none"
    MEAN = "mean"
    SUM = "sum"
[docs]
class DiceCEReduction(StrEnum):
    """
    Reduction options for the combined Dice + cross-entropy loss.

    See also:
        - :py:class:`monai.losses.dice.DiceCELoss`
    """

    MEAN = "mean"
    SUM = "sum"
[docs]
class Weight(StrEnum):
    """
    Weighting schemes for the generalized Dice loss.

    See also: :py:class:`monai.losses.dice.GeneralizedDiceLoss`
    """

    SQUARE = "square"
    SIMPLE = "simple"
    UNIFORM = "uniform"
[docs]
class ChannelMatching(StrEnum):
    """
    Channel-matching options.

    See also: :py:class:`monai.networks.nets.HighResBlock`
    """

    PAD = "pad"
    PROJECT = "project"
[docs]
class SkipMode(StrEnum):
    """
    Ways of combining a skip connection with the main path.

    See also: :py:class:`monai.networks.layers.SkipConnection`
    """

    CAT = "cat"
    ADD = "add"
    MUL = "mul"
[docs]
class Method(StrEnum):
    """
    Where padding is placed.

    See also: :py:class:`monai.transforms.croppad.array.SpatialPad`
    """

    SYMMETRIC = "symmetric"
    END = "end"
[docs]
class ForwardMode(StrEnum):
    """
    Modes for the model forward pass: ``"train"`` or ``"eval"``.

    See also: :py:class:`monai.engines.evaluator.Evaluator`
    """

    TRAIN = "train"
    EVAL = "eval"
[docs]
class TraceKeys(StrEnum):
    """
    Extra metadata keys used by traceable transforms.
    """

    CLASS_NAME: str = "class"
    ID: str = "id"
    ORIG_SIZE: str = "orig_size"
    EXTRA_INFO: str = "extra_info"
    DO_TRANSFORM: str = "do_transforms"
    KEY_SUFFIX: str = "_transforms"  # note: includes a leading underscore
    NONE: str = "none"
    TRACING: str = "tracing"
    STATUSES: str = "statuses"
    LAZY: str = "lazy"
[docs]
class TraceStatusKeys(StrEnum):
    """Enumerable status keys used with the ``TraceKeys.STATUSES`` entry."""

    PENDING_DURING_APPLY = "pending_during_apply"
[docs]
class CommonKeys(StrEnum):
    """
    A set of common keys for dictionary based supervised training process.

    `IMAGE` is the input image data.
    `LABEL` is the training or evaluation label of segmentation or classification task.
    `PRED` is the prediction data of model output.
    `LOSS` is the loss value of current iteration.
    `METADATA` is some useful information during training or evaluation.
    """

    IMAGE = "image"
    LABEL = "label"
    PRED = "pred"
    LOSS = "loss"
    METADATA = "metadata"
[docs]
class GanKeys(StrEnum):
    """
    Common dictionary keys for generative adversarial network workflows.
    """

    REALS = "reals"
    FAKES = "fakes"
    LATENTS = "latents"
    GLOSS = "g_loss"
    DLOSS = "d_loss"
[docs]
class PostFix(StrEnum):
    """
    Helpers that build post-fixed key names.

    Each helper returns ``"<key>_<suffix>"`` when a key is given, or just ``"<suffix>"`` when the key is None.
    """

    @staticmethod
    def _get_str(prefix: str | None, suffix: str) -> str:
        if prefix is not None:
            return f"{prefix}_{suffix}"
        return suffix

    @staticmethod
    def meta(key: str | None = None) -> str:
        """Return ``"<key>_meta_dict"``, or ``"meta_dict"`` when ``key`` is None."""
        return PostFix._get_str(key, "meta_dict")

    @staticmethod
    def orig_meta(key: str | None = None) -> str:
        """Return ``"<key>_orig_meta_dict"``, or ``"orig_meta_dict"`` when ``key`` is None."""
        return PostFix._get_str(key, "orig_meta_dict")

    @staticmethod
    def transforms(key: str | None = None) -> str:
        """Return ``"<key>_transforms"``, or ``"transforms"`` when ``key`` is None."""
        # TraceKeys.KEY_SUFFIX starts with "_"; drop that first character to get the bare suffix
        return PostFix._get_str(key, TraceKeys.KEY_SUFFIX[1:])
[docs]
class CompInitMode(StrEnum):
    """
    Modes for instantiating a class or calling a callable.

    See also: :py:func:`monai.utils.module.instantiate`
    """

    DEFAULT = "default"
    PARTIAL = "partial"
    DEBUG = "debug"
class JITMetadataKeys(StrEnum):
    """
    Keys stored in the metadata file of saved TorchScript models.

    Some of these are generated by the saving routines; others are optionally provided by users.
    """

    NAME = "name"
    TIMESTAMP = "timestamp"
    VERSION = "version"
    DESCRIPTION = "description"
[docs]
class BoxModeName(StrEnum):
    """
    Bounding-box mode names; the comment on each member gives its coordinate layout.
    """

    XYXY = "xyxy"  # [xmin, ymin, xmax, ymax]
    XYZXYZ = "xyzxyz"  # [xmin, ymin, zmin, xmax, ymax, zmax]
    XXYY = "xxyy"  # [xmin, xmax, ymin, ymax]
    XXYYZZ = "xxyyzz"  # [xmin, xmax, ymin, ymax, zmin, zmax]
    XYXYZZ = "xyxyzz"  # [xmin, ymin, xmax, ymax, zmin, zmax]
    XYWH = "xywh"  # [xmin, ymin, xsize, ysize]
    XYZWHD = "xyzwhd"  # [xmin, ymin, zmin, xsize, ysize, zsize]
    CCWH = "ccwh"  # [xcenter, ycenter, xsize, ysize]
    CCCWHD = "cccwhd"  # [xcenter, ycenter, zcenter, xsize, ysize, zsize]
class ProbMapKeys(StrEnum):
    """
    Keys used when generating probability maps from patches.
    """

    LOCATION = "mask_location"
    SIZE = "mask_size"
    COUNT = "num_patches"
    NAME = "name"
[docs]
class GridPatchSort(StrEnum):
    """
    The sorting method for the generated patches in `GridPatch`.

    Use :py:meth:`get_sort_fn` to map a member (or its string value) to the matching key function.
    """

    RANDOM = "random"
    MIN = "min"
    MAX = "max"

    @staticmethod
    def random_fn(*_args):
        """Return a random float from ``random.random()``; any positional arguments are ignored."""
        return random.random()

    @staticmethod
    def min_fn(x):
        """Return ``x[0].sum()`` (items with smaller sums come first when used as a sort key)."""
        # NOTE(review): x is presumably a (patch, location) pair — confirm against callers
        return x[0].sum()

    @staticmethod
    def max_fn(x):
        """Return ``-x[0].sum()`` (items with larger sums come first when used as a sort key)."""
        return -x[0].sum()

    @staticmethod
    def get_sort_fn(sort_fn):
        """
        Return the key function corresponding to ``sort_fn``.

        Args:
            sort_fn: a ``GridPatchSort`` member or its string value ("random", "min" or "max").

        Raises:
            ValueError: if ``sort_fn`` is not one of the supported values.
        """
        if sort_fn == GridPatchSort.RANDOM:
            # random.random takes no arguments; the wrapper also accepts the item, like min_fn/max_fn
            return GridPatchSort.random_fn
        if sort_fn == GridPatchSort.MIN:
            return GridPatchSort.min_fn
        if sort_fn == GridPatchSort.MAX:
            return GridPatchSort.max_fn
        supported = [e.value for e in GridPatchSort]
        # build one message string: passing two args to ValueError made str(err) render as a tuple
        raise ValueError(f'sort_fn should be one of the following values {supported}, "{sort_fn}" was given.')
class PatchKeys(StrEnum):
    """
    Metadata keys for patches extracted from any kind of image.
    """

    LOCATION = "location"
    SIZE = "size"
    COUNT = "count"
class WSIPatchKeys(StrEnum):
    """
    Metadata keys for patches extracted from whole slide images.

    The first three members reuse the generic ``PatchKeys`` values; ``LEVEL`` and ``PATH`` are WSI-specific.
    """

    # shared with PatchKeys
    LOCATION = PatchKeys.LOCATION
    SIZE = PatchKeys.SIZE
    COUNT = PatchKeys.COUNT
    # whole-slide-image specific
    LEVEL = "level"
    PATH = "path"
[docs]
class FastMRIKeys(StrEnum):
    """
    Keys for extracting data from the fastMRI dataset.
    """

    KSPACE = "kspace"
    MASK = "mask"
    FILENAME = "filename"
    RECON = "reconstruction_rss"
    ACQUISITION = "acquisition"
    MAX = "max"
    NORM = "norm"
    PID = "patient_id"
[docs]
class SpaceKeys(StrEnum):
    """
    Coordinate-system keys.

    For example, NIfTI-1 uses Right-Anterior-Superior ("RAS") while DICOM (0020,0032) uses
    Left-Posterior-Superior ("LPS"). These keys do not distinguish between 1D/2D/3D spaces.
    """

    RAS = "RAS"
    LPS = "LPS"
[docs]
class ColorOrder(StrEnum):
    """
    Color channel orders; extend as necessary.
    """

    RGB = "RGB"
    BGR = "BGR"
[docs]
class EngineStatsKeys(StrEnum):
    """
    Default keys for the statistics reported by trainer and evaluator engines.
    """

    RANK = "rank"
    CURRENT_ITERATION = "current_iteration"
    CURRENT_EPOCH = "current_epoch"
    TOTAL_EPOCHS = "total_epochs"
    TOTAL_ITERATIONS = "total_iterations"
    BEST_VALIDATION_EPOCH = "best_validation_epoch"
    BEST_VALIDATION_METRIC = "best_validation_metric"
[docs]
class DataStatsKeys(StrEnum):
    """
    Default keys for the dataset statistical-analysis modules.
    """

    SUMMARY = "stats_summary"
    BY_CASE = "stats_by_cases"
    BY_CASE_IMAGE_PATH = "image_filepath"
    BY_CASE_LABEL_PATH = "label_filepath"
    IMAGE_STATS = "image_stats"
    FG_IMAGE_STATS = "image_foreground_stats"
    LABEL_STATS = "label_stats"
    IMAGE_HISTOGRAM = "image_histogram"
[docs]
class ImageStatsKeys(StrEnum):
    """
    Default keys for the image part of dataset statistical analysis.
    """

    SHAPE = "shape"
    CHANNELS = "channels"
    CROPPED_SHAPE = "cropped_shape"
    SPACING = "spacing"
    SIZEMM = "sizemm"
    INTENSITY = "intensity"
    HISTOGRAM = "histogram"
[docs]
class LabelStatsKeys(StrEnum):
    """
    Default keys for the label part of dataset statistical analysis.
    """

    LABEL_UID = "labels"
    PIXEL_PCT = "foreground_percentage"
    IMAGE_INTST = "image_intensity"
    LABEL = "label"
    LABEL_SHAPE = "shape"
    LABEL_NCOMP = "ncomponents"
[docs]
@deprecated(since="1.2", removed="1.4", msg_suffix="please use `AlgoKeys` instead.")
class AlgoEnsembleKeys(StrEnum):
    """
    Default keys for Mixed Ensemble (deprecated; superseded by ``AlgoKeys``).
    """

    ID = "identifier"
    ALGO = "infer_algo"
    SCORE = "best_metric"
[docs]
class HoVerNetMode(StrEnum):
    """
    Implementation modes of the HoVerNet model.

    `FAST`: a faster implementation than the original.
    `ORIGINAL`: the original implementation.
    """

    FAST = "FAST"
    ORIGINAL = "ORIGINAL"
[docs]
class HoVerNetBranch(StrEnum):
    """
    The three HoVerNet branches, one per model output.

    `HV`: horizontal and vertical gradient map of each nucleus (regression).
    `NP`: pixel prediction of all nuclei (segmentation).
    `NC`: type of each nucleus (classification).
    """

    HV = "horizontal_vertical"
    NP = "nucleus_prediction"
    NC = "type_prediction"
[docs]
class LazyAttr(StrEnum):
    """
    Key attributes tracked for a MetaTensor that has pending operations.

    Under lazy evaluation the primary array may not be up to date, so these attributes are
    tracked separately for each MetaTensor.
    See also: :py:func:`monai.transforms.lazy.utils.resample` for more details.
    """

    SHAPE = "lazy_shape"  # spatial shape
    AFFINE = "lazy_affine"
    PADDING_MODE = "lazy_padding_mode"
    INTERP_MODE = "lazy_interpolation_mode"
    DTYPE = "lazy_dtype"
    ALIGN_CORNERS = "lazy_align_corners"
    RESAMPLE_MODE = "lazy_resample_mode"
[docs]
class BundleProperty(StrEnum):
    """
    Fields describing a bundle property.

    `DESC`: the description of the property.
    `REQUIRED`: flag indicating whether the property is required or optional.
    """

    DESC = "description"
    REQUIRED = "required"
[docs]
class BundlePropertyConfig(StrEnum):
    """
    Additional bundle property fields for config-based bundle workflows.

    `ID`: the config item ID of the property.
    `REF_ID`: the ID of the config item that is supposed to refer to this property.
    For properties without a `REF_ID`, `None` should be set.
    This field is only useful for checking the optional property ID.
    """

    ID = "id"
    REF_ID = "refer_id"
[docs]
class AlgoKeys(StrEnum):
    """
    Default keys for a templated Auto3DSeg Algo.

    `ID`: identifier of the algorithm, formatted as <name>_<idx>_<other>.
    `ALGO`: the Auto3DSeg Algo instance.
    `IS_TRAINED`: status showing whether the Algo has been trained.
    `SCORE`: the score the Algo achieved after training.
    """

    ID = "identifier"
    ALGO = "algo_instance"
    IS_TRAINED = "is_trained"
    SCORE = "best_metric"