# Source code for monai.metrics.meandice

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Callable, Optional, Union

import torch

from monai.networks import one_hot
from monai.utils import MetricReduction


class DiceMetric:
    """
    Compute the average Dice score between two tensors. It can support both
    multi-classes and multi-labels tasks.

    Input logits `y_pred` (BNHW[D] where N is number of classes) is compared with
    ground truth `y` (BNHW[D]). Axis N of `y_pred` is expected to have logit
    predictions for each class rather than being image channels, while the same
    axis of `y` can be 1 or N (one-hot format).

    `include_background` can be set to False for an instance of DiceMetric to
    exclude the first category (channel index 0), which is by convention assumed
    to be background. If the non-background segmentations are small compared to
    the total image size they can get overwhelmed by the signal from the
    background, so excluding it in such cases can be useful.

    Args:
        include_background: whether to include the first channel of the predicted
            output in the Dice computation; set to False to skip it.
            Defaults to True.
        to_onehot_y: whether to convert `y` into the one-hot format.
            Defaults to False.
        mutually_exclusive: if True, `y_pred` will be converted into a binary
            matrix using a combination of argmax and to_onehot. Defaults to False.
        sigmoid: whether to add sigmoid function to y_pred before computation.
            Defaults to False.
        other_act: callable function to replace `sigmoid` as activation layer if
            needed, Defaults to ``None``. for example: `other_act = torch.tanh`.
        logit_thresh: the threshold value used to convert (for example, after
            sigmoid if `sigmoid=True`) `y_pred` into a binary matrix.
            Defaults to 0.5.
        reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``,
            ``"sum_batch"``, ``"mean_channel"``, ``"sum_channel"``}
            Define the mode to reduce computation result of 1 batch data.
            Defaults to ``"mean"``.

    Raises:
        ValueError: When ``sigmoid=True`` and ``other_act is not None``.
            Incompatible values.

    """

    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        mutually_exclusive: bool = False,
        sigmoid: bool = False,
        other_act: Optional[Callable] = None,
        logit_thresh: float = 0.5,
        reduction: Union[MetricReduction, str] = MetricReduction.MEAN,
    ) -> None:
        super().__init__()
        if sigmoid and other_act is not None:
            raise ValueError("Incompatible values: ``sigmoid=True`` and ``other_act is not None``.")
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.mutually_exclusive = mutually_exclusive
        self.sigmoid = sigmoid
        self.other_act = other_act
        self.logit_thresh = logit_thresh
        # normalise strings such as "mean" into the enum member
        self.reduction: MetricReduction = MetricReduction(reduction)

        # count of valid (non-NaN) elements from the most recent call; its shape
        # follows the configured reduction
        self.not_nans: Optional[torch.Tensor] = None

    def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Compute the Dice score of a batch and reduce it according to
        ``self.reduction``. As a side effect ``self.not_nans`` is set to the
        number of valid (non-NaN) Dice entries that contributed to the result.

        Args:
            y_pred: input data to compute, typical segmentation model output.
                it must be one-hot format and first dim is batch.
            y: ground truth to compute mean dice metric, the first dim is batch.

        Returns:
            The reduced Dice score. With ``"none"`` the per-item, per-channel
            scores are returned with NaN entries replaced by 0.

        Raises:
            ValueError: When ``self.reduction`` is not one of
                ["mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel", "none"].

        """
        # compute dice (BxC) for each channel for each batch
        f = compute_meandice(
            y_pred=y_pred,
            y=y,
            include_background=self.include_background,
            to_onehot_y=self.to_onehot_y,
            mutually_exclusive=self.mutually_exclusive,
            sigmoid=self.sigmoid,
            other_act=self.other_act,
            logit_thresh=self.logit_thresh,
        )

        # some dice elements might be NaN (if ground truth y was missing (zeros));
        # zero them and remember which entries were valid so that the averages
        # below divide by the number of valid entries only
        nans = torch.isnan(f)
        not_nans = (~nans).float()
        f[nans] = 0

        # fallback value used wherever there is no valid entry to average over
        t_zero = torch.zeros(1, device=f.device, dtype=torch.float)

        if self.reduction == MetricReduction.MEAN:
            # 2 steps, first, mean by channel (accounting for nans), then by batch
            not_nans = not_nans.sum(dim=1)
            f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero)  # channel average

            # only batch items with at least one valid channel count towards the mean
            not_nans = (not_nans > 0).float().sum()
            f = torch.where(not_nans > 0, f.sum() / not_nans, t_zero)  # batch average
        elif self.reduction == MetricReduction.SUM:
            not_nans = not_nans.sum()
            f = torch.sum(f)  # sum over the batch and channel dims
        elif self.reduction == MetricReduction.MEAN_BATCH:
            not_nans = not_nans.sum(dim=0)
            f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero)  # batch average
        elif self.reduction == MetricReduction.SUM_BATCH:
            not_nans = not_nans.sum(dim=0)
            f = f.sum(dim=0)  # the batch sum
        elif self.reduction == MetricReduction.MEAN_CHANNEL:
            not_nans = not_nans.sum(dim=1)
            f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero)  # channel average
        elif self.reduction == MetricReduction.SUM_CHANNEL:
            not_nans = not_nans.sum(dim=1)
            f = f.sum(dim=1)  # the channel sum
        elif self.reduction == MetricReduction.NONE:
            pass
        else:
            raise ValueError(
                f"Unsupported reduction: {self.reduction}, available options are "
                '["mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel", "none"].'
            )

        # save not_nans since we may need it later to know how many elements were valid
        self.not_nans = not_nans

        return f


def compute_meandice(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    include_background: bool = True,
    to_onehot_y: bool = False,
    mutually_exclusive: bool = False,
    sigmoid: bool = False,
    other_act: Optional[Callable] = None,
    logit_thresh: float = 0.5,
) -> torch.Tensor:
    """Computes Dice score metric from full size Tensor and collects average.

    Args:
        y_pred: input data to compute, typical segmentation model output.
            it must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32].
        y: ground truth to compute mean dice metric, the first dim is batch.
            example shape: [16, 1, 32, 32] will be converted into [16, 3, 32, 32].
            alternative shape: [16, 3, 32, 32] and set `to_onehot_y=False` to use 3-class labels directly.
        include_background: whether to include the first channel of the predicted
            output in the Dice computation; when False, channel 0 of both `y` and
            `y_pred` is dropped (only if more than one channel is present).
            Defaults to True.
        to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
        mutually_exclusive: if True, `y_pred` will be converted into a binary matrix using
            a combination of argmax and to_onehot. Defaults to False.
        sigmoid: whether to add sigmoid function to y_pred before computation. Defaults to False.
        other_act: callable function to replace `sigmoid` as activation layer if needed, Defaults to ``None``.
            for example: `other_act = torch.tanh`.
        logit_thresh: the threshold value used to convert (for example, after sigmoid if `sigmoid=True`)
            `y_pred` into a binary matrix (values ``>= logit_thresh`` become 1). Defaults to 0.5.

    Raises:
        ValueError: When ``sigmoid=True`` and ``other_act is not None``. Incompatible values.
        TypeError: When ``other_act`` is not an ``Optional[Callable]``.
        ValueError: When ``sigmoid=True`` and ``mutually_exclusive=True`` with a
            multi-channel `y_pred`. Incompatible values.
        AssertionError: When `y` and `y_pred` do not have the same shape after
            the conversions above.

    Returns:
        Dice scores per batch and per class, (shape [batch_size, n_classes]).
        An entry is NaN when the ground truth for that batch item and class is empty.

    Note:
        This method provides two options to convert `y_pred` into a binary matrix
            (1) when `mutually_exclusive` is True, it uses a combination of ``argmax`` and ``to_onehot``,
            (2) when `mutually_exclusive` is False, it uses a threshold ``logit_thresh``
                (optionally with a ``sigmoid`` function before thresholding).

    """
    n_classes = y_pred.shape[1]
    n_len = len(y_pred.shape)

    if sigmoid and other_act is not None:
        raise ValueError("Incompatible values: sigmoid=True and other_act is not None.")
    if sigmoid:
        y_pred = y_pred.float().sigmoid()

    if other_act is not None:
        if not callable(other_act):
            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
        y_pred = other_act(y_pred)

    if n_classes == 1:
        # single-channel prediction: the multi-class options have no effect
        if mutually_exclusive:
            warnings.warn("y_pred has only one class, mutually_exclusive=True ignored.")
        if to_onehot_y:
            warnings.warn("y_pred has only one channel, to_onehot_y=True ignored.")
        if not include_background:
            warnings.warn("y_pred has only one channel, include_background=False ignored.")
        # make both y and y_pred binary
        y_pred = (y_pred >= logit_thresh).float()
        y = (y > 0).float()
    else:  # multi-channel y_pred
        # make both y and y_pred binary
        if mutually_exclusive:
            if sigmoid:
                raise ValueError("Incompatible values: sigmoid=True and mutually_exclusive=True.")
            y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
            y_pred = one_hot(y_pred, num_classes=n_classes)
        else:
            y_pred = (y_pred >= logit_thresh).float()
        if to_onehot_y:
            y = one_hot(y, num_classes=n_classes)

    if not include_background:
        # drop channel 0, unless it is the only channel
        y = y[:, 1:] if y.shape[1] > 1 else y
        y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred

    # Raised explicitly rather than through an `assert` statement so the check is
    # not stripped when Python runs with -O; the exception type is unchanged.
    if y.shape != y_pred.shape:
        raise AssertionError(
            "Ground truth one-hot has differing shape (%r) from source (%r)" % (y.shape, y_pred.shape)
        )
    y = y.float()
    y_pred = y_pred.float()

    # reducing only spatial dimensions (not batch nor channels)
    reduce_axis = list(range(2, n_len))
    intersection = torch.sum(y * y_pred, dim=reduce_axis)

    y_o = torch.sum(y, reduce_axis)
    y_pred_o = torch.sum(y_pred, dim=reduce_axis)
    denominator = y_o + y_pred_o

    # Dice is undefined (NaN) where the ground truth is empty
    f = torch.where(y_o > 0, (2.0 * intersection) / denominator, torch.tensor(float("nan"), device=y_o.device))
    return f  # returns array of Dice shape: [batch, n_classes]