Source code for monai.metrics.utils

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Tuple, Union

import numpy as np
import torch

from monai.transforms.croppad.array import SpatialCrop
from monai.transforms.utils import generate_spatial_bounding_box
from monai.utils import MetricReduction, convert_data_type, look_up_option, optional_import

# The scipy morphology routines are resolved through `optional_import` instead of
# a plain `import`. Each call returns a pair; only the first element (the
# requested callable) is kept, the second one is discarded.
binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion")
distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt")
distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")

# Names exported by `from <this module> import *`.
__all__ = ["ignore_background", "do_metric_reduction", "get_mask_edges", "get_surface_distance", "is_binary_tensor"]


def ignore_background(y_pred: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor]):
    """
    Drop the background channel (index 0 along dim 1) from `y_pred` and `y`.

    An input whose second dimension has a single entry is returned untouched,
    since there is no separate background channel to remove.

    Args:
        y_pred: predictions. For classification tasks the expected shape is [BN]
            with N larger than 1. For segmentation tasks the expected shape is
            [BNHW] or [BNHWD].
        y: ground truth, the first dim is batch.

    Returns:
        The pair ``(y_pred, y)`` with the first channel sliced away where applicable.
    """
    if y.shape[1] > 1:
        y = y[:, 1:]
    if y_pred.shape[1] > 1:
        y_pred = y_pred[:, 1:]
    return y_pred, y
def do_metric_reduction(f: torch.Tensor, reduction: Union[MetricReduction, str] = MetricReduction.MEAN):
    """
    Reduce the `not-nan` metric values computed for every sample and every class.

    Besides the reduced metric, the function returns `not_nans`, which counts
    the number of not-nan entries that contributed to each reduced value.
    The input tensor is never modified.

    Args:
        f: a tensor that contains the calculated metric scores per batch and
            per class. The first two dims should be batch and class.
        reduction: define the mode to reduce metrics, will only apply reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``.
            if "none", return the input f tensor (NaN entries included) and not_nans.

    Returns:
        A tuple ``(f, not_nans)``: the reduced metric and the matching not-nan counts.

    Raises:
        ValueError: When ``reduction`` is not one of
            ["mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel", "none"].

    """

    # some elements might be NaN (if ground truth y was missing (zeros)),
    # we need to account for it
    nans = torch.isnan(f)
    not_nans = (~nans).float()

    t_zero = torch.zeros(1, device=f.device, dtype=f.dtype)
    reduction = look_up_option(reduction, MetricReduction)
    if reduction == MetricReduction.NONE:
        return f, not_nans

    # Zero the NaN entries on a copy: writing into `f` in place would silently
    # alter the caller's tensor (e.g. a metric buffer that is reduced again later).
    f = f.masked_fill(nans, 0)

    if reduction == MetricReduction.MEAN:
        # 2 steps, first, mean by channel (accounting for nans), then by batch
        not_nans = not_nans.sum(dim=1)
        f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero)  # channel average

        not_nans = (not_nans > 0).float().sum(dim=0)
        f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero)  # batch average

    elif reduction == MetricReduction.SUM:
        not_nans = not_nans.sum(dim=[0, 1])
        f = torch.sum(f, dim=[0, 1])  # sum over the batch and channel dims
    elif reduction == MetricReduction.MEAN_BATCH:
        not_nans = not_nans.sum(dim=0)
        f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero)  # batch average
    elif reduction == MetricReduction.SUM_BATCH:
        not_nans = not_nans.sum(dim=0)
        f = f.sum(dim=0)  # the batch sum
    elif reduction == MetricReduction.MEAN_CHANNEL:
        not_nans = not_nans.sum(dim=1)
        f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero)  # channel average
    elif reduction == MetricReduction.SUM_CHANNEL:
        not_nans = not_nans.sum(dim=1)
        f = f.sum(dim=1)  # the channel sum
    else:
        # "none" has already returned above, so anything landing here is unsupported
        raise ValueError(
            f"Unsupported reduction: {reduction}, available options are "
            '["mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel", "none"].'
        )
    return f, not_nans
def get_mask_edges(seg_pred, seg_gt, label_idx: int = 1, crop: bool = True) -> Tuple[np.ndarray, np.ndarray]:
    """
    Extract the edges of two masks by XOR-ing each mask with its binary erosion.

    The result is the usual starting point for surface based metrics such as
    Average Surface Distance and Hausdorff Distance.

    Both binary and labelfield images are accepted. A non-boolean input is
    binarized with ``image == label_idx``. The erosion itself is delegated to
    `scipy`'s ``binary_erosion``.

    To reduce the amount of work, the images are cropped to their common
    foreground before the erosion unless ``crop=False`` is given.
    The two images must have the same shape and are assumed to occupy the same
    space (spacing, orientation, etc.).

    Args:
        seg_pred: the predicted binary or labelfield image.
        seg_gt: the actual binary or labelfield image.
        label_idx: for labelfield images, convert to binary with
            `seg_pred = seg_pred == label_idx`.
        crop: crop input images and only keep the foregrounds. In order to
            maintain two inputs' shapes, here the bounding box is achieved
            by ``(seg_pred | seg_gt)`` which represents the union set of two
            images. Defaults to ``True``.

    Returns:
        The pair ``(edges_pred, edges_gt)`` of boolean edge masks.

    Raises:
        ValueError: When the two inputs have different shapes.

    """
    # torch tensors are moved to host memory and handled as numpy arrays
    pred_mask = seg_pred.detach().cpu().numpy() if isinstance(seg_pred, torch.Tensor) else seg_pred
    gt_mask = seg_gt.detach().cpu().numpy() if isinstance(seg_gt, torch.Tensor) else seg_gt

    if pred_mask.shape != gt_mask.shape:
        raise ValueError(f"seg_pred and seg_gt should have same shapes, got {pred_mask.shape} and {gt_mask.shape}.")

    # labelfields are turned into boolean masks of the requested label
    if pred_mask.dtype != bool:
        pred_mask = pred_mask == label_idx
    if gt_mask.dtype != bool:
        gt_mask = gt_mask == label_idx

    if crop:
        foreground = pred_mask | gt_mask
        if not np.any(foreground):
            # nothing to crop around: both edge maps are empty
            return np.zeros_like(pred_mask), np.zeros_like(gt_mask)

        # the cropping utilities are called on channel-first data, so a leading
        # singleton axis is added before the crop and removed right after it
        channel_dim = 0
        pred_mask = np.expand_dims(pred_mask, axis=channel_dim)
        gt_mask = np.expand_dims(gt_mask, axis=channel_dim)
        foreground = np.expand_dims(foreground, axis=channel_dim)

        box_start, box_end = generate_spatial_bounding_box(np.asarray(foreground))
        cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)

        cropped_pred = np.squeeze(cropper(pred_mask), axis=channel_dim)
        pred_mask = convert_data_type(cropped_pred, np.ndarray)[0]
        cropped_gt = np.squeeze(cropper(gt_mask), axis=channel_dim)
        gt_mask = convert_data_type(cropped_gt, np.ndarray)[0]

    # a pixel belongs to the edge when the erosion removes it from the mask
    edges_pred = binary_erosion(pred_mask) ^ pred_mask
    edges_gt = binary_erosion(gt_mask) ^ gt_mask
    return edges_pred, edges_gt
def get_surface_distance(seg_pred: np.ndarray, seg_gt: np.ndarray, distance_metric: str = "euclidean") -> np.ndarray:
    """
    Compute the surface distances from `seg_pred` to `seg_gt`.

    A distance map to the ground-truth edge is built and then sampled at the
    positions of the predicted edge.

    Args:
        seg_pred: the edge of the predictions.
        seg_gt: the edge of the ground truth.
        distance_metric: [``"euclidean"``, ``"chessboard"``, ``"taxicab"``]
            the metric used to compute surface distance. Defaults to ``"euclidean"``.

            - ``"euclidean"``, uses Exact Euclidean distance transform.
            - ``"chessboard"``, uses `chessboard` metric in chamfer type of transform.
            - ``"taxicab"``, uses `taxicab` metric in chamfer type of transform.

    Returns:
        A 1-D array with one distance per selected edge element.

    Raises:
        ValueError: When both edges are non-empty and ``distance_metric`` is not supported.

    Note:
        If seg_pred or seg_gt is all 0, may result in nan/inf distance.

    """
    if not np.any(seg_gt):
        # no reference edge: every predicted edge element is infinitely far away
        dist_map = np.inf * np.ones_like(seg_gt)
        return np.asarray(dist_map[seg_pred])

    if not np.any(seg_pred):
        # no predicted edge: report one infinite distance per reference edge element
        dist_map = np.inf * np.ones_like(seg_gt)
        return np.asarray(dist_map[seg_gt])

    # the transforms measure the distance to the nearest zero element, hence the inversion
    if distance_metric == "euclidean":
        dist_map = distance_transform_edt(~seg_gt)
    elif distance_metric in {"chessboard", "taxicab"}:
        dist_map = distance_transform_cdt(~seg_gt, metric=distance_metric)
    else:
        raise ValueError(f"distance_metric {distance_metric} is not implemented.")

    return np.asarray(dist_map[seg_pred])
def is_binary_tensor(input: torch.Tensor, name: str):
    """Warn when the input tensor holds values other than 0 and 1.

    The function does not return a verdict: a tensor that is not binary only
    triggers a warning through ``warnings.warn``. An empty tensor has no
    offending value and therefore passes silently.

    Args:
        input (torch.Tensor): tensor to validate.
        name (str): name of the tensor being checked, used in the messages.

    Raises:
        ValueError: if `input` is not a PyTorch Tensor.

    Returns:
        None.
    """
    if not isinstance(input, torch.Tensor):
        raise ValueError(f"{name} must be of type PyTorch Tensor.")
    # Compare against 0 and 1 directly: this avoids casting arbitrary floats
    # (negative, huge, NaN) to uint8 and does not need max()/min(), which fail
    # on empty tensors.
    if not torch.all((input == 0) | (input == 1)):
        warnings.warn(f"{name} should be a binarized tensor.")