Metrics

FROC

monai.metrics.compute_fp_tp_probs(probs, y_coord, x_coord, evaluation_mask, labels_to_exclude=None, resolution_level=0)[source]

This function is modified from the official evaluation code of the CAMELYON 16 Challenge and is used to distinguish true positive from false positive predictions. A prediction is counted as a true positive when its detection point lies within an annotated ground truth region.

Parameters
  • probs (Union[ndarray, Tensor]) – an array with shape (n,) that represents the probabilities of the detections, where n is the number of predicted detections.

  • y_coord (Union[ndarray, Tensor]) – an array with shape (n,) that represents the Y-coordinates of the detections.

  • x_coord (Union[ndarray, Tensor]) – an array with shape (n,) that represents the X-coordinates of the detections.

  • evaluation_mask (Union[ndarray, Tensor]) – the ground truth mask for evaluation.

  • labels_to_exclude (Optional[List]) – labels in this list will not be counted for metric calculation.

  • resolution_level (int) – the level at which the evaluation mask is made.

Returns
  • fp_probs – an array that contains the probabilities of the false positive detections.

  • tp_probs – an array that contains the probabilities of the true positive detections.

  • num_targets – the total number of targets (excluding labels_to_exclude) for all images under evaluation.

monai.metrics.compute_froc_curve_data(fp_probs, tp_probs, num_targets, num_images)[source]

This function is modified from the official evaluation code of the CAMELYON 16 Challenge and is used to compute the required data for plotting the Free Response Operating Characteristic (FROC) curve.

Parameters
  • fp_probs (Union[ndarray, Tensor]) – an array that contains the probabilities of the false positive detections for all images under evaluation.

  • tp_probs (Union[ndarray, Tensor]) – an array that contains the probabilities of the true positive detections for all images under evaluation.

  • num_targets (int) – the total number of targets (excluding labels_to_exclude) for all images under evaluation.

  • num_images (int) – the number of images under evaluation.

monai.metrics.compute_froc_score(fps_per_image, total_sensitivity, eval_thresholds=(0.25, 0.5, 1, 2, 4, 8))[source]

This function is modified from the official evaluation code of the CAMELYON 16 Challenge and is used to compute the challenge's second evaluation metric, which is defined as the average sensitivity at the predefined false positive rates per whole slide image.

Parameters
  • fps_per_image (ndarray) – the average number of false positives per image for different thresholds.

  • total_sensitivity (ndarray) – sensitivities (true positive rates) for different thresholds.

  • eval_thresholds (Tuple) – the false positive rates for calculating the average sensitivity. Defaults to (0.25, 0.5, 1, 2, 4, 8) which is the same as the CAMELYON 16 Challenge.
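
The following is a minimal sketch of how the three FROC functions can be chained; the detections and the evaluation mask are synthetic, purely for illustration:

import numpy as np

from monai.metrics import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score

# synthetic detections for a single image: probabilities and (y, x) coordinates
probs = np.array([0.9, 0.8, 0.6, 0.3])
y_coord = np.array([10, 42, 70, 90])
x_coord = np.array([12, 47, 72, 88])

# hypothetical evaluation mask with two annotated regions, labelled 1 and 2
evaluation_mask = np.zeros((128, 128), dtype=np.int32)
evaluation_mask[5:15, 5:20] = 1
evaluation_mask[35:50, 40:55] = 2

fp_probs, tp_probs, num_targets = compute_fp_tp_probs(probs, y_coord, x_coord, evaluation_mask)
fps_per_image, total_sensitivity = compute_froc_curve_data(fp_probs, tp_probs, num_targets, num_images=1)
score = compute_froc_score(fps_per_image, total_sensitivity)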

Metric

class monai.metrics.Metric[source]

Base class for metric computation for evaluating the performance of a model. __call__ is designed to execute the computation.

IterationMetric

class monai.metrics.IterationMetric[source]

Base class for metric computation at the iteration level, that is, on a mini-batch of samples, usually using the model outputs of one iteration.

__call__ is designed to handle y_pred and y (optional) in torch tensors or a list/tuple of tensors.

Subclasses typically implement the _compute_tensor function for the actual tensor computation logic.
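
The following is a minimal sketch of that subclassing pattern; the metric itself is hypothetical, not part of MONAI:

import torch

from monai.metrics import IterationMetric

class MeanAbsErrorMetric(IterationMetric):
    # hypothetical metric: per-sample mean absolute error
    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor = None):
        dims = tuple(range(1, y_pred.ndim))  # reduce over all non-batch dimensions
        return (y_pred - y).abs().mean(dim=dims)

metric = MeanAbsErrorMetric()
scores = metric(y_pred=torch.rand(4, 3, 8, 8), y=torch.rand(4, 3, 8, 8))
print(scores.shape)  # torch.Size([4])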

Cumulative

class monai.metrics.Cumulative[source]

Utility class for the typical cumulative computation process based on PyTorch Tensors. It provides interfaces to accumulate values in the local buffers, synchronize buffers across distributed nodes, and aggregate the buffered values.

In multi-processing, PyTorch programs usually distribute data to multiple nodes. Each node runs with a subset of the data and adds values to its local buffers. Calling get_buffer gathers all the results, and aggregate can further process them to generate the final outcomes.

Users can implement their own aggregate method to handle the results, using get_buffer to get the buffered contents.

Note: the data list should have the same length each time values are added in a round; buffers are created automatically according to the length of the data list.

Typically, this class is expected to execute the following steps:

from monai.metrics import Cumulative

c = Cumulative()
c.append(1)  # adds a value
c.extend([2, 3])  # adds a batch of values
c.extend([4, 5, 6])  # adds a batch of values
print(c.get_buffer())  # tensor([1, 2, 3, 4, 5, 6])
print(len(c))  # 6
c.reset()
print(len(c))  # 0

The following is an example of maintaining two internal buffers:

from monai.metrics import Cumulative

c = Cumulative()
c.append(1, 2)  # adds a value to two buffers respectively
c.extend([3, 4], [5, 6])  # adds batches of values
print(c.get_buffer())  # [tensor([1, 3, 4]), tensor([2, 5, 6])]
print(len(c))

The following is an example of extending with variable length data:

import torch
from monai.metrics import Cumulative

c = Cumulative()
c.extend(torch.zeros((8, 2)), torch.zeros((6, 2)))  # adds batches
c.append(torch.zeros((2, )))  # adds a value
print(c.get_buffer())  # [torch.zeros((9, 2)), torch.zeros((6, 2))]
print(len(c))

__init__()[source]

Initialize the internal buffers. self._buffers are local buffers; they are not usually used directly. self._sync_buffers are the buffers with all the results across all the nodes.

abstract aggregate(*args, **kwargs)[source]

Aggregate final results based on the gathered buffers. This method is expected to use get_buffer to gather the local buffer contents.
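
For example, a minimal (hypothetical) subclass could aggregate the buffered values into their mean:

from monai.metrics import Cumulative

class CumulativeMean(Cumulative):
    def aggregate(self):
        # get_buffer returns the synchronized buffer contents as a tensor
        return self.get_buffer().float().mean()

c = CumulativeMean()
c.extend([1.0, 2.0, 3.0])
c.append(4.0)
print(c.aggregate())  # tensor(2.5000)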

append(*data)[source]

Add samples to the local cumulative buffers. A buffer will be allocated for each data item. Compared with self.extend, this method adds a single sample (instead of a “batch”) to the local buffers.

Parameters

data – each item will be converted into a torch tensor. They will be stacked at the 0-th dimension with a new dimension when get_buffer() is called.

Return type

None

extend(*data)[source]

Extend the local buffers with new (“batch-first”) data. A buffer will be allocated for each data item. Compared with self.append, this method adds a “batch” of data to the local buffers.

Parameters

data – each item can be a "batch-first" tensor or a list of "channel-first" tensors. They will be concatenated at the 0-th dimension when get_buffer() is called.

Return type

None

get_buffer()[source]

Get the synchronized list of buffers. A typical usage is to generate the metrics report based on the raw metric details. Each buffer is a PyTorch Tensor.

reset()[source]

Reset the buffers for cumulative tensors and the synced results.

CumulativeIterationMetric

class monai.metrics.CumulativeIterationMetric[source]

Base class for cumulative metrics that collect metric values on each mini-batch at the iteration level.

Typically, it computes some intermediate results for each iteration and adds them to the buffers; the buffer contents can then be gathered and aggregated for the final result when the epoch completes.

For example, MeanDice inherits this class and the usage is as follows:

from monai.data import decollate_batch
from monai.metrics import DiceMetric

# model, val_loader and postprocessing_transform are assumed to be defined
dice_metric = DiceMetric(include_background=True, reduction="mean")

for val_data in val_loader:
    val_outputs = model(val_data["img"])
    val_outputs = [postprocessing_transform(i) for i in decollate_batch(val_outputs)]
    # compute metric for current iteration
    dice_metric(y_pred=val_outputs, y=val_data["seg"])  # callable to add metric to the buffer

# aggregate the final mean dice result
metric = dice_metric.aggregate().item()

# reset the status for next computation round
dice_metric.reset()

To load predictions and labels from files and then compute metrics with multi-processing, please refer to: https://github.com/Project-MONAI/tutorials/blob/master/modules/compute_metric.py.

Initialize the internal buffers. self._buffers are local buffers; they are not usually used directly. self._sync_buffers are the buffers with all the results across all the nodes.

Mean Dice

monai.metrics.compute_meandice(y_pred, y, include_background=True)[source]

Computes the Dice score metric from full-size tensors and collects the average.

Parameters
  • y_pred (Tensor) – input data to compute, typically a segmentation model output. It must be in one-hot format and the first dim is batch, example shape: [16, 3, 32, 32]. The values should be binarized.

  • y (Tensor) – ground truth to compute the mean Dice metric. It must be in one-hot format and the first dim is batch. The values should be binarized.

  • include_background (bool) – whether to include Dice computation on the first channel of the predicted output. Defaults to True.

Return type

Tensor

Returns

Dice scores per batch and per class, (shape [batch_size, num_classes]).

Raises

ValueError – when y_pred and y have different shapes.
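
The following is a minimal sketch with synthetic one-hot, binarized tensors:

import torch

from monai.metrics import compute_meandice

# batch=2, classes=3, 4x4 images
y_pred = torch.zeros(2, 3, 4, 4)
y = torch.zeros(2, 3, 4, 4)
y_pred[:, 1, :2, :2] = 1  # predicted foreground of class 1
y[:, 1, :2, :3] = 1       # ground truth foreground of class 1
y_pred[:, 0] = 1 - y_pred[:, 1:].sum(dim=1)  # fill background channel to keep inputs one-hot
y[:, 0] = 1 - y[:, 1:].sum(dim=1)

scores = compute_meandice(y_pred, y)  # shape [2, 3]; class 2 is NaN (empty in both inputs)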

class monai.metrics.DiceMetric(include_background=True, reduction=MetricReduction.MEAN, get_not_nans=False)[source]

Compute the average Dice score between two tensors. It can support both multi-class and multi-label tasks. Input y_pred is compared with ground truth y. y_pred is expected to have binarized predictions and y should be in one-hot format. You can use suitable transforms in monai.transforms.post first to achieve binarized values. The include_background parameter can be set to False for an instance of DiceMetric to exclude the first category (channel index 0), which is by convention assumed to be background. If the non-background segmentations are small compared to the total image size they can get overwhelmed by the signal from the background, so excluding it in such cases helps. y_pred and y can be a list of channel-first Tensors (CHW[D]) or a batch-first Tensor (BCHW[D]).

Parameters
  • include_background (bool) – whether to include Dice computation on the first channel of the predicted output. Defaults to True.

  • reduction (Union[MetricReduction, str]) – define the mode to reduce metrics, will only execute reduction on not-nan values, available reduction modes: {"none", "mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel"}, default to "mean". if “none”, will not do reduction.

  • get_not_nans (bool) – whether to return the not_nans count, if True, aggregate() returns (metric, not_nans). Here not_nans count the number of not nans for the metric, thus its shape equals to the shape of the metric.

Initialize the internal buffers. self._buffers are local buffers; they are not usually used directly. self._sync_buffers are the buffers with all the results across all the nodes.

aggregate()[source]

Execute reduction logic for the output of compute_meandice.
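
The following is a minimal iterative sketch with random one-hot data; monai.networks.utils.one_hot is used here only to build valid inputs:

import torch

from monai.metrics import DiceMetric
from monai.networks.utils import one_hot

dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=True)

for _ in range(2):  # e.g. two validation iterations
    y_pred = one_hot(torch.randint(0, 3, (4, 1, 8, 8)), num_classes=3)
    y = one_hot(torch.randint(0, 3, (4, 1, 8, 8)), num_classes=3)
    dice_metric(y_pred=y_pred, y=y)  # buffers the per-iteration scores

metric, not_nans = dice_metric.aggregate()  # tuple because get_not_nans=True
dice_metric.reset()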

Area under the ROC curve

monai.metrics.compute_roc_auc(y_pred, y, average=Average.MACRO)[source]

Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to: sklearn.metrics.roc_auc_score.

Parameters
  • y_pred (Tensor) – input data to compute, typically a classification model output. The first dim must be batch; if multi-class, it must be in one-hot format. For example: shape [16] or [16, 1] for binary data, shape [16, 2] for 2-class data.

  • y (Tensor) – ground truth to compute the ROC AUC metric. The first dim must be batch; if multi-class, it must be in one-hot format. For example: shape [16] or [16, 1] for binary data, shape [16, 2] for 2-class data.

  • average (Union[Average, str]) –

    {"macro", "weighted", "micro", "none"} Type of averaging performed if not binary classification. Defaults to "macro".

    • "macro": calculate metrics for each label, and find their unweighted mean.

      This does not take label imbalance into account.

    • "weighted": calculate metrics for each label, and find their average,

      weighted by support (the number of true instances for each label).

    • "micro": calculate metrics globally by considering each element of the label

      indicator matrix as a label.

    • "none": the scores for each class are returned.

Raises
  • ValueError – When y_pred dimension is not one of [1, 2].

  • ValueError – When y dimension is not one of [1, 2].

  • ValueError – When average is not one of [“macro”, “weighted”, “micro”, “none”].

Note

ROCAUC expects y to be comprised of 0's and 1's. y_pred must be either probability estimates or confidence values.
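
The following is a minimal binary sketch with toy values (compute_roc_auc mirrors sklearn.metrics.roc_auc_score):

import torch

from monai.metrics import compute_roc_auc

y_pred = torch.tensor([0.1, 0.4, 0.35, 0.8])  # probability estimates
y = torch.tensor([0, 0, 1, 1])                # binary ground truth
auc = compute_roc_auc(y_pred, y)              # 0.75 for this toy data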

class monai.metrics.ROCAUCMetric(average=Average.MACRO)[source]

Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to: sklearn.metrics.roc_auc_score. The input y_pred and y can be a list of channel-first Tensor or a batch-first Tensor.

Parameters

average (Union[Average, str]) –

{"macro", "weighted", "micro", "none"} Type of averaging performed if not binary classification. Defaults to "macro".

  • "macro": calculate metrics for each label, and find their unweighted mean.

    This does not take label imbalance into account.

  • "weighted": calculate metrics for each label, and find their average,

    weighted by support (the number of true instances for each label).

  • "micro": calculate metrics globally by considering each element of the label

    indicator matrix as a label.

  • "none": the scores for each class are returned.

Initialize the internal buffers. self._buffers are local buffers; they are not usually used directly. self._sync_buffers are the buffers with all the results across all the nodes.

aggregate()[source]

As the AUC metric needs to be computed on the overall data, users typically accumulate y_pred and y from every iteration, then execute the real computation and reduction on the accumulated data.
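
The following is a minimal sketch of that accumulate-then-aggregate pattern, with toy binary data:

import torch

from monai.metrics import ROCAUCMetric

auc_metric = ROCAUCMetric(average="macro")
auc_metric(y_pred=torch.tensor([0.1, 0.4]), y=torch.tensor([0, 0]))   # iteration 1
auc_metric(y_pred=torch.tensor([0.35, 0.8]), y=torch.tensor([1, 1]))  # iteration 2
print(auc_metric.aggregate())  # AUC computed over all the buffered data
auc_metric.reset()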

Confusion matrix

monai.metrics.get_confusion_matrix(y_pred, y, include_background=True)[source]

Compute the confusion matrix. A tensor with the shape [BC4] will be returned, where B equals the batch size and C equals the number of classes to compute. The last dimension stores the numbers of true positive, false positive, true negative and false negative values for each channel of each sample within the input batch.

Parameters
  • y_pred (Tensor) – input data to compute. It must be in one-hot format and the first dim is batch. The values should be binarized.

  • y (Tensor) – ground truth to compute the metric. It must be in one-hot format and the first dim is batch. The values should be binarized.

  • include_background (bool) – whether to include metric computation on the first channel of the predicted output. Defaults to True.

Raises

ValueError – when y_pred and y have different shapes.

monai.metrics.compute_confusion_matrix_metric(metric_name, confusion_matrix)[source]

This function is used to compute confusion-matrix-related metrics.

Parameters
  • metric_name (str) – ["sensitivity", "specificity", "precision", "negative predictive value", "miss rate", "fall out", "false discovery rate", "false omission rate", "prevalence threshold", "threat score", "accuracy", "balanced accuracy", "f1 score", "matthews correlation coefficient", "fowlkes mallows index", "informedness", "markedness"] Some of the metrics have multiple aliases (as shown on the aforementioned Wikipedia page), and you can also input those names instead.

  • confusion_matrix (Tensor) – Please see the doc string of the function get_confusion_matrix for more details.

Raises
  • ValueError – when the size of the last dimension of confusion_matrix is not 4.

  • NotImplementedError – when an unimplemented metric_name is specified.
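
For illustration, the two functions can be chained as follows; the tensors are a synthetic binarized, one-hot example:

import torch

from monai.metrics import compute_confusion_matrix_metric, get_confusion_matrix

# batch=2, classes=2, 4 elements per channel
y_pred = torch.tensor([[[1, 0, 0, 1], [0, 1, 1, 0]],
                       [[1, 1, 0, 0], [0, 0, 1, 1]]]).float()
y = torch.tensor([[[1, 0, 1, 1], [0, 1, 0, 0]],
                  [[1, 0, 0, 1], [0, 1, 1, 0]]]).float()

cm = get_confusion_matrix(y_pred, y)  # shape [2, 2, 4]: tp, fp, tn, fn per sample and class
sensitivity = compute_confusion_matrix_metric("sensitivity", cm)  # shape [2, 2]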

class monai.metrics.ConfusionMatrixMetric(include_background=True, metric_name='hit_rate', compute_sample=False, reduction=MetricReduction.MEAN, get_not_nans=False)[source]

Compute confusion-matrix-related metrics. This class supports calculating all the metrics mentioned in: Confusion matrix. It can support both multi-class and multi-label classification and segmentation tasks. y_pred is expected to have binarized predictions and y should be in one-hot format. You can use suitable transforms in monai.transforms.post first to achieve binarized values. The include_background parameter can be set to False for an instance to exclude the first category (channel index 0), which is by convention assumed to be background. If the non-background segmentations are small compared to the total image size they can get overwhelmed by the signal from the background, so excluding it in such cases helps.

Parameters
  • include_background (bool) – whether to include metric computation on the first channel of the predicted output. Defaults to True.

  • metric_name (Union[Sequence[str], str]) – ["sensitivity", "specificity", "precision", "negative predictive value", "miss rate", "fall out", "false discovery rate", "false omission rate", "prevalence threshold", "threat score", "accuracy", "balanced accuracy", "f1 score", "matthews correlation coefficient", "fowlkes mallows index", "informedness", "markedness"] Some of the metrics have multiple aliases (as shown on the aforementioned Wikipedia page), and you can also input those names instead. Besides a single metric, multiple metrics are also supported by passing a sequence of metric names, such as ("sensitivity", "precision", "recall"); the results (and not_nans counts, if requested) are returned in the same order as the input names when calling the class.

  • compute_sample (bool) – when reducing, if True, each sample's metric will be computed based on its confusion matrix first. If False, the reduction is computed on the confusion matrices first. Defaults to False.

  • reduction (Union[MetricReduction, str]) – define the mode to reduce metrics, will only execute reduction on not-nan values, available reduction modes: {"none", "mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel"}, default to "mean". if “none”, will not do reduction.

  • get_not_nans (bool) – whether to return the not_nans count, if True, aggregate() returns [(metric, not_nans), …]. If False, aggregate() returns [metric, …]. Here not_nans count the number of not nans for True Positive, False Positive, True Negative and False Negative. Its shape depends on the shape of the metric, and it has one more dimension with size 4. For example, if the shape of the metric is [3, 3], not_nans has the shape [3, 3, 4].

Initialize the internal buffers. self._buffers are local buffers; they are not usually used directly. self._sync_buffers are the buffers with all the results across all the nodes.

aggregate()[source]

Execute reduction for the confusion matrix values.
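
The following is a minimal class-based sketch with random one-hot inputs; passing a sequence of metric names yields one aggregated result per name:

import torch

from monai.metrics import ConfusionMatrixMetric
from monai.networks.utils import one_hot

cm_metric = ConfusionMatrixMetric(include_background=False, metric_name=("sensitivity", "precision"))
y_pred = one_hot(torch.randint(0, 3, (4, 1, 8, 8)), num_classes=3)
y = one_hot(torch.randint(0, 3, (4, 1, 8, 8)), num_classes=3)
cm_metric(y_pred=y_pred, y=y)  # buffers the confusion matrix values
sensitivity, precision = cm_metric.aggregate()  # one entry per metric name
cm_metric.reset()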

Hausdorff distance

monai.metrics.compute_hausdorff_distance(y_pred, y, include_background=False, distance_metric='euclidean', percentile=None, directed=False)[source]

Compute the Hausdorff distance.

Parameters
  • y_pred (Union[ndarray, Tensor]) – input data to compute, typically a segmentation model output. It must be in one-hot format and the first dim is batch, example shape: [16, 3, 32, 32]. The values should be binarized.

  • y (Union[ndarray, Tensor]) – ground truth used to compute the distance. It must be in one-hot format and the first dim is batch. The values should be binarized.

  • include_background (bool) – whether to include distance computation on the first channel of the predicted output. Defaults to False.

  • distance_metric (str) – ["euclidean", "chessboard", "taxicab"] the metric used to compute surface distance. Defaults to "euclidean".

  • percentile (Optional[float]) – an optional float number between 0 and 100. If specified, the corresponding percentile of the Hausdorff distance rather than the maximum will be returned. Defaults to None.

  • directed (bool) – whether to calculate directed Hausdorff distance. Defaults to False.
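
The following is a minimal sketch with two overlapping synthetic squares (one-hot, binarized, with the background channel filled in):

import torch

from monai.metrics import compute_hausdorff_distance

y_pred = torch.zeros(1, 2, 8, 8)
y = torch.zeros(1, 2, 8, 8)
y_pred[0, 1, 2:5, 2:5] = 1  # predicted foreground
y[0, 1, 3:7, 3:7] = 1       # ground truth foreground
y_pred[:, 0] = 1 - y_pred[:, 1]
y[:, 0] = 1 - y[:, 1]

hd = compute_hausdorff_distance(y_pred, y)                  # shape [1, 1]
hd95 = compute_hausdorff_distance(y_pred, y, percentile=95)  # 95th percentile variant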

monai.metrics.compute_percent_hausdorff_distance(edges_pred, edges_gt, distance_metric='euclidean', percentile=None)[source]

This function is used to compute the directed Hausdorff distance between the prediction edges and the ground truth edges, optionally at a given percentile.

class monai.metrics.HausdorffDistanceMetric(include_background=False, distance_metric='euclidean', percentile=None, directed=False, reduction=MetricReduction.MEAN, get_not_nans=False)[source]

Compute the Hausdorff distance between two tensors. It can support both multi-class and multi-label tasks. It supports both directed and non-directed Hausdorff distance calculation. In addition, specifying the percentile parameter yields the corresponding percentile of the distance. Input y_pred is compared with ground truth y. y_pred is expected to have binarized predictions and y should be in one-hot format. You can use suitable transforms in monai.transforms.post first to achieve binarized values. y_pred and y can be a list of channel-first Tensors (CHW[D]) or a batch-first Tensor (BCHW[D]). The implementation refers to DeepMind’s implementation.

Parameters
  • include_background (bool) – whether to include distance computation on the first channel of the predicted output. Defaults to False.

  • distance_metric (str) – ["euclidean", "chessboard", "taxicab"] the metric used to compute surface distance. Defaults to "euclidean".

  • percentile (Optional[float]) – an optional float number between 0 and 100. If specified, the corresponding percentile of the Hausdorff distance rather than the maximum will be returned. Defaults to None.

  • directed (bool) – whether to calculate directed Hausdorff distance. Defaults to False.

  • reduction (Union[MetricReduction, str]) – define the mode to reduce metrics, will only execute reduction on not-nan values, available reduction modes: {"none", "mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel"}, default to "mean". if “none”, will not do reduction.

  • get_not_nans (bool) – whether to return the not_nans count, if True, aggregate() returns (metric, not_nans). Here not_nans count the number of not nans for the metric, thus its shape equals to the shape of the metric.

Initialize the internal buffers. self._buffers are local buffers; they are not usually used directly. self._sync_buffers are the buffers with all the results across all the nodes.

aggregate()[source]

Execute reduction logic for the output of compute_hausdorff_distance.

Average surface distance

monai.metrics.compute_average_surface_distance(y_pred, y, include_background=False, symmetric=False, distance_metric='euclidean')[source]

This function is used to compute the Average Surface Distance from y_pred to y under the default setting. In addition, if symmetric=True is set, the average symmetric surface distance between the two inputs will be returned. The implementation refers to DeepMind’s implementation.

Parameters
  • y_pred (Union[ndarray, Tensor]) – input data to compute, typically a segmentation model output. It must be in one-hot format and the first dim is batch, example shape: [16, 3, 32, 32]. The values should be binarized.

  • y (Union[ndarray, Tensor]) – ground truth used to compute the distance. It must be in one-hot format and the first dim is batch. The values should be binarized.

  • include_background (bool) – whether to include distance computation on the first channel of the predicted output. Defaults to False.

  • symmetric (bool) – whether to calculate the symmetric average surface distance between seg_pred and seg_gt. Defaults to False.

  • distance_metric (str) – ["euclidean", "chessboard", "taxicab"] the metric used to compute surface distance. Defaults to "euclidean".
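
The following is a minimal sketch with synthetic one-hot masks; symmetric=True averages the distances computed in both directions:

import torch

from monai.metrics import compute_average_surface_distance

y_pred = torch.zeros(1, 2, 8, 8)
y = torch.zeros(1, 2, 8, 8)
y_pred[0, 1, 2:5, 2:5] = 1  # predicted foreground
y[0, 1, 3:7, 3:7] = 1       # ground truth foreground
y_pred[:, 0] = 1 - y_pred[:, 1]  # fill background channel
y[:, 0] = 1 - y[:, 1]

asd = compute_average_surface_distance(y_pred, y, symmetric=True)  # shape [1, 1]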

class monai.metrics.SurfaceDistanceMetric(include_background=False, symmetric=False, distance_metric='euclidean', reduction=MetricReduction.MEAN, get_not_nans=False)[source]

Compute the surface distance between two tensors. It can support both multi-class and multi-label tasks. It supports both symmetric and asymmetric surface distance calculation. Input y_pred is compared with ground truth y. y_pred is expected to have binarized predictions and y should be in one-hot format. You can use suitable transforms in monai.transforms.post first to achieve binarized values. y_pred and y can be a list of channel-first Tensors (CHW[D]) or a batch-first Tensor (BCHW[D]).

Parameters
  • include_background (bool) – whether to include distance computation on the first channel of the predicted output. Defaults to False.

  • symmetric (bool) – whether to calculate the symmetric average surface distance between seg_pred and seg_gt. Defaults to False.

  • distance_metric (str) – ["euclidean", "chessboard", "taxicab"] the metric used to compute surface distance. Defaults to "euclidean".

  • reduction (Union[MetricReduction, str]) – define the mode to reduce metrics, will only execute reduction on not-nan values, available reduction modes: {"none", "mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel"}, default to "mean". if “none”, will not do reduction.

  • get_not_nans (bool) – whether to return the not_nans count, if True, aggregate() returns (metric, not_nans). Here not_nans count the number of not nans for the metric, thus its shape equals to the shape of the metric.

Initialize the internal buffers. self._buffers are local buffers; they are not usually used directly. self._sync_buffers are the buffers with all the results across all the nodes.

aggregate()[source]

Execute reduction logic for the output of compute_average_surface_distance.

Mean squared error

class monai.metrics.MSEMetric(reduction=MetricReduction.MEAN, get_not_nans=False)[source]

Compute the Mean Squared Error between two tensors using the function:

\[\operatorname {MSE}\left(Y, \hat{Y}\right) =\frac {1}{n}\sum _{i=1}^{n}\left(y_i-\hat{y_i} \right)^{2}.\]

More info: https://en.wikipedia.org/wiki/Mean_squared_error

Input y_pred is compared with ground truth y. Both y_pred and y are expected to be real-valued, where y_pred is output from a regression model.

Parameters
  • reduction (Union[MetricReduction, str]) – define the mode to reduce metrics, will only execute reduction on not-nan values, available reduction modes: {"none", "mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel"}, default to "mean". if “none”, will not do reduction.

  • get_not_nans (bool) – whether to return the not_nans count, if True, aggregate() returns (metric, not_nans).

Initialize the internal buffers. self._buffers are local buffers; they are not usually used directly. self._sync_buffers are the buffers with all the results across all the nodes.
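
The following is a minimal sketch for the regression metrics; MAEMetric and RMSEMetric below follow the same calling pattern:

import torch

from monai.metrics import MSEMetric

mse_metric = MSEMetric(reduction="mean")
for _ in range(2):  # e.g. two validation iterations
    y_pred = torch.rand(8, 1, 16, 16)  # regression model output
    y = torch.rand(8, 1, 16, 16)       # real-valued ground truth
    mse_metric(y_pred=y_pred, y=y)     # buffers the per-sample values
print(mse_metric.aggregate())
mse_metric.reset()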

Mean absolute error

class monai.metrics.MAEMetric(reduction=MetricReduction.MEAN, get_not_nans=False)[source]

Compute the Mean Absolute Error between two tensors using the function:

\[\operatorname {MAE}\left(Y, \hat{Y}\right) =\frac {1}{n}\sum _{i=1}^{n}\left|y_i-\hat{y_i}\right|.\]

More info: https://en.wikipedia.org/wiki/Mean_absolute_error

Input y_pred is compared with ground truth y. Both y_pred and y are expected to be real-valued, where y_pred is output from a regression model.

Parameters
  • reduction (Union[MetricReduction, str]) – define the mode to reduce metrics, will only execute reduction on not-nan values, available reduction modes: {"none", "mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel"}, default to "mean". if “none”, will not do reduction.

  • get_not_nans (bool) – whether to return the not_nans count, if True, aggregate() returns (metric, not_nans).

Initialize the internal buffers. self._buffers are local buffers; they are not usually used directly. self._sync_buffers are the buffers with all the results across all the nodes.

Root mean squared error

class monai.metrics.RMSEMetric(reduction=MetricReduction.MEAN, get_not_nans=False)[source]

Compute the Root Mean Squared Error between two tensors using the function:

\[\operatorname{RMSE}\left(Y, \hat{Y}\right) = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(y_i-\hat{y}_i\right)^{2}} = \sqrt{\operatorname{MSE}\left(Y, \hat{Y}\right)}.\]

More info: https://en.wikipedia.org/wiki/Root-mean-square_deviation

Input y_pred is compared with ground truth y. Both y_pred and y are expected to be real-valued, where y_pred is output from a regression model.

Parameters
  • reduction (Union[MetricReduction, str]) – define the mode to reduce metrics, will only execute reduction on not-nan values, available reduction modes: {"none", "mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel"}, default to "mean". if “none”, will not do reduction.

  • get_not_nans (bool) – whether to return the not_nans count, if True, aggregate() returns (metric, not_nans).

Initialize the internal buffers. self._buffers are local buffers; they are not usually used directly. self._sync_buffers are the buffers with all the results across all the nodes.

Peak signal to noise ratio

class monai.metrics.PSNRMetric(max_val, reduction=MetricReduction.MEAN, get_not_nans=False)[source]

Compute the Peak Signal-to-Noise Ratio between two tensors using the function:

\[\operatorname{PSNR}\left(Y, \hat{Y}\right) = 20 \cdot \log_{10}\left({\mathit{MAX}}_Y\right) - 10 \cdot \log_{10}\left(\operatorname{MSE}\left(Y, \hat{Y}\right)\right)\]

More info: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

Help taken from: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/image_ops_impl.py line 4139

Input y_pred is compared with ground truth y. Both y_pred and y are expected to be real-valued, where y_pred is output from a regression model.

Parameters
  • max_val (Union[int, float]) – The dynamic range of the images/volumes (i.e., the difference between the maximum and the minimum allowed values e.g. 255 for a uint8 image).

  • reduction (Union[MetricReduction, str]) – define the mode to reduce metrics, will only execute reduction on not-nan values, available reduction modes: {"none", "mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel"}, default to "mean". if “none”, will not do reduction.

  • get_not_nans (bool) – whether to return the not_nans count, if True, aggregate() returns (metric, not_nans).

Initialize the internal buffers. self._buffers are local buffers; they are not usually used directly. self._sync_buffers are the buffers with all the results across all the nodes.
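
The following is a minimal sketch with uint8-style intensity data, so max_val=255:

import torch

from monai.metrics import PSNRMetric

psnr_metric = PSNRMetric(max_val=255)  # dynamic range of the data
y_pred = torch.randint(0, 256, (4, 1, 16, 16)).float()
y = torch.randint(0, 256, (4, 1, 16, 16)).float()
psnr_metric(y_pred=y_pred, y=y)
print(psnr_metric.aggregate())
psnr_metric.reset()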

Cumulative average

class monai.metrics.CumulativeAverage[source]

Cumulatively record data values and aggregate the average value. It supports single-class or multi-class data; for example, a value can be 0.44 (a loss value) or [0.3, 0.4] (metrics of two classes). It also supports distributed data parallel, synchronizing the data across ranks when aggregating. For example, to record loss values and compute the overall average value every 5 iterations:

average = CumulativeAverage()
for i, d in enumerate(dataloader):
    loss = ...
    average.append(loss)
    if i % 5 == 0:
        print(f"cumulative average of loss: {average.aggregate()}")
average.reset()

Initialize the internal buffers. self._buffers are local buffers; they are not usually used directly. self._sync_buffers are the buffers with all the results across all the nodes.

aggregate()[source]

Sync the data from all ranks and compute the average value, taking the previously accumulated sum into account.

reset()[source]

Reset all the running status, including buffers, sum, not nans count, etc.

Utilities

monai.metrics.utils.do_metric_reduction(f, reduction=MetricReduction.MEAN)[source]

This function performs metric reduction over the calculated not-NaN metric values of each sample and each class. It also returns not_nans, which counts the number of not-NaN values for the metric.

Parameters
  • f (Tensor) – a tensor that contains the calculated metric scores per batch and per class. The first two dims should be batch and class.

  • reduction (Union[MetricReduction, str]) – define the mode to reduce metrics, will only execute reduction on not-nan values, available reduction modes: {"none", "mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel"}, default to "mean". if “none”, return the input f tensor and not_nans.

  • "mean". (Define the mode to reduce computation result of 1 batch data. Defaults to) –

Raises

ValueError – When reduction is not one of [“mean”, “sum”, “mean_batch”, “sum_batch”, “mean_channel”, “sum_channel”, “none”].
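
For example, with a per-batch, per-class score tensor containing one NaN entry (e.g. an empty class):

import torch

from monai.metrics.utils import do_metric_reduction

f = torch.tensor([[0.8, float("nan")],
                  [0.6, 0.9]])  # [batch, class] metric scores
reduced, not_nans = do_metric_reduction(f, reduction="mean_batch")
# reduced: per-class means over the not-NaN entries -> tensor([0.7000, 0.9000])
# not_nans: per-class counts of valid entries -> tensor([2., 1.])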

monai.metrics.utils.get_mask_edges(seg_pred, seg_gt, label_idx=1, crop=True)[source]

Do binary erosion and use XOR to get the edges of the input. This function is helpful to further calculate metrics such as Average Surface Distance and Hausdorff Distance. The input images can be binary or labelfield images. If labelfield images are supplied, they are converted to binary images using label_idx.

scipy’s binary erosion is used to calculate the edges of the binary labelfield.

To improve computational efficiency, the images are cropped to keep only the foreground before the edges are computed, unless crop=False is specified.

We require that images are the same size, and assume that they occupy the same space (spacing, orientation, etc.).

Parameters
  • seg_pred (Union[ndarray, Tensor]) – the predicted binary or labelfield image.

  • seg_gt (Union[ndarray, Tensor]) – the actual binary or labelfield image.

  • label_idx (int) – for labelfield images, convert to binary with seg_pred = seg_pred == label_idx.

  • crop (bool) – crop input images and only keep the foregrounds. In order to maintain two inputs’ shapes, here the bounding box is achieved by (seg_pred | seg_gt) which represents the union set of two images. Defaults to True.

Return type

Tuple[ndarray, ndarray]

monai.metrics.utils.get_surface_distance(seg_pred, seg_gt, distance_metric='euclidean')[source]

This function is used to compute the surface distances from seg_pred to seg_gt.

Parameters
  • seg_pred (ndarray) – the edge of the predictions.

  • seg_gt (ndarray) – the edge of the ground truth.

  • distance_metric (str) –

    : ["euclidean", "chessboard", "taxicab"] the metric used to compute surface distance. Defaults to "euclidean".

    • "euclidean", uses Exact Euclidean distance transform.

    • "chessboard", uses chessboard metric in chamfer type of transform.

    • "taxicab", uses taxicab metric in chamfer type of transform.

Note

If seg_pred or seg_gt is all 0, the result may contain nan/inf distances.

Return type

ndarray
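
For illustration, get_mask_edges and get_surface_distance can be chained to obtain a directed average surface distance from two synthetic binary masks:

import numpy as np

from monai.metrics.utils import get_mask_edges, get_surface_distance

seg_pred = np.zeros((16, 16), dtype=bool)
seg_gt = np.zeros((16, 16), dtype=bool)
seg_pred[4:9, 4:9] = True  # predicted foreground
seg_gt[5:11, 5:11] = True  # ground truth foreground

edges_pred, edges_gt = get_mask_edges(seg_pred, seg_gt)
distances = get_surface_distance(edges_pred, edges_gt)  # pred edge points -> gt surface
avg_surface_distance = distances.mean()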

monai.metrics.utils.ignore_background(y_pred, y)[source]

This function is used to remove background (the first channel) for y_pred and y.

Parameters
  • y_pred (Union[ndarray, Tensor]) – predictions. For classification tasks, y_pred should have the shape [BN] where N is larger than 1. For segmentation tasks, the shape should be [BNHW] or [BNHWD].

  • y (Union[ndarray, Tensor]) – ground truth, the first dim is batch.
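
The following is a minimal sketch; the first (background) channel is dropped from both inputs:

import torch

from monai.metrics.utils import ignore_background

y_pred = torch.rand(4, 3, 8, 8)  # [B, N, H, W]
y = torch.rand(4, 3, 8, 8)
y_pred, y = ignore_background(y_pred, y)
print(y_pred.shape, y.shape)  # both become torch.Size([4, 2, 8, 8])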