Source code for monai.metrics.occlusion_sensitivity

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Sequence
from functools import partial
from typing import Optional, Union

import numpy as np
import torch
import torch.nn as nn

try:
    from tqdm import trange

    trange = partial(trange, desc="Computing occlusion sensitivity")
except (ImportError, AttributeError):
    trange = range


def _check_input_image(image):
    """Check that the input image is as expected."""
    # Only accept batch size of 1
    if image.shape[0] > 1:
        raise RuntimeError("Expected batch size of 1.")
    return image


def _check_input_label(label, image):
    """Check that the input label is as expected."""
    # If necessary turn the label into a 1-element tensor
    if isinstance(label, int):
        label = torch.tensor([[label]], dtype=torch.int64).to(image.device)
    # If the label is a tensor, make sure there are as many elements as batches
    elif label.numel() != image.shape[0]:
        raise RuntimeError("Expected as many labels as batches.")
    return label
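
# Illustrative example (assumed shapes, not part of the module): for an image of
# shape (1, 1, 64, 64), _check_input_label(3, image) returns tensor([[3]]) on
# image.device, while a tensor label is returned unchanged as long as it holds
# one element per batch item.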


def _check_input_bounding_box(b_box, im_shape):
    """Check that the bounding box (if supplied) is as expected."""
    # If no bounding box has been supplied, set min and max to None
    if b_box is None:
        b_box_min = b_box_max = None

    # Bounding box has been supplied
    else:
        # Should be twice as many elements in `b_box` as `im_shape`
        if len(b_box) != 2 * len(im_shape):
            raise ValueError("Bounding box should contain upper and lower for all dimensions (except batch number)")

        # If any mins or maxes are negative, set them to 0 and im_shape - 1, respectively.
        b_box_min = np.array(b_box[::2])
        b_box_max = np.array(b_box[1::2])
        b_box_min[b_box_min < 0] = 0
        b_box_max[b_box_max < 0] = im_shape[b_box_max < 0] - 1
        # Check all max's are < im_shape
        if np.any(b_box_max >= im_shape):
            raise ValueError("Max bounding box should be < image size for all values")
        # Check all min's are <= max's
        if np.any(b_box_min > b_box_max):
            raise ValueError("Min bounding box should be <= max for all values")

    return b_box_min, b_box_max
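
# Illustrative example (assumed shapes, not part of the module): for
# im_shape = (1, 64, 64), a bounding box of [-1, -1, 10, 20, 10, 20] yields
# b_box_min = [0, 10, 10] and b_box_max = [0, 20, 20]; negative entries default
# to 0 (mins) and im_shape - 1 (maxes), and both bounds are inclusive.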


def _append_to_sensitivity_im(model, batch_images, batch_ids, sensitivity_im):
    """For given number of images, get probability of predicting
    a given label. Append to previous evaluations."""
    batch_images = torch.cat(batch_images, dim=0)
    batch_ids = torch.LongTensor(batch_ids).unsqueeze(1).to(sensitivity_im.device)
    scores = model(batch_images).detach().gather(1, batch_ids)
    return torch.cat((sensitivity_im, scores))
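
# Shape sketch (assumed sizes, not part of the module): with N occluded copies
# of a (1, C, H, W) image in batch_images and N matching labels in batch_ids,
# the model output is gathered into scores of shape (N, 1), so sensitivity_im
# accumulates one score per occluded image across calls.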


def compute_occlusion_sensitivity(
    model: nn.Module,
    image: torch.Tensor,
    label: Union[int, torch.Tensor],
    pad_val: float = 0.0,
    margin: Union[int, Sequence] = 2,
    n_batch: int = 128,
    b_box: Optional[Sequence] = None,
    stride: Union[int, Sequence] = 1,
    upsample_mode: str = "nearest",
) -> np.ndarray:
    """
    This function computes the occlusion sensitivity for a model's prediction of a given image.
    By occlusion sensitivity, we mean how the probability of a given prediction changes as the
    occluded section of an image changes. This can be useful to understand why a network is
    making certain decisions.

    The result is given as ``baseline`` (the probability of a certain output) minus the probability
    of the output with the occluded area. Therefore, higher values in the output image mean there
    was a greater drop in certainty, indicating the occluded region was more important in the
    decision process.

    See: R. R. Selvaraju et al. Grad-CAM: Visual Explanations from Deep Networks via
    Gradient-based Localization. https://doi.org/10.1109/ICCV.2017.74

    Args:
        model: classification model to use for inference.
        image: image to test. Should be a tensor consisting of 1 batch, and can be 2D or 3D.
        label: classification label to check for changes (normally the true label, but doesn't have to be).
        pad_val: when occluding part of the image, which values should we put in the image?
        margin: we'll create a cuboid/cube around the voxel to be occluded. If ``margin==2``, then
            we'll create a cube that is +/- 2 voxels in all directions (i.e., a cube of
            5 x 5 x 5 voxels). A ``Sequence`` can be supplied to have a margin of different sizes
            (i.e., create a cuboid).
        n_batch: number of images in a batch before inference.
        b_box: Bounding box on which to perform the analysis. The output image will also match in
            size. There should be a minimum and maximum for all dimensions except batch:
            ``[min1, max1, min2, max2,...]``.

            * By default, the whole image will be used. Decreasing the size will speed the
              analysis up, which might be useful for larger images.
            * Min and max are inclusive, so ``[0, 63, ...]`` will have size ``(64, ...)``.
            * Use negative values to use 0 for min values and ``im.shape[x]-1`` for the xth dimension.
        stride: Stride for performing occlusions. Can be a single value or a sequence (for varying
            stride in the different directions). Should be >= 1.
        upsample_mode: If ``stride != 1`` is used, we'll upsample such that the size of the voxels
            in the output image matches the input. Upsampling is done with ``torch.nn.Upsample``,
            and mode can be set to:

            * ``nearest``, ``linear``, ``bilinear``, ``bicubic`` and ``trilinear``
            * default is ``nearest``.

    Returns:
        Numpy array. If no bounding box is supplied, this will be the same size as the input image.
        If a bounding box is used, the output image will be cropped to this size.
    """

    # Check input arguments
    image = _check_input_image(image)
    label = _check_input_label(label, image)
    im_shape = np.array(image.shape[1:])
    b_box_min, b_box_max = _check_input_bounding_box(b_box, im_shape)

    # Get baseline probability
    baseline = model(image).detach()[0, label].item()

    # Create some lists
    batch_images = []
    batch_ids = []

    sensitivity_im = torch.empty(0, dtype=torch.float32, device=image.device)

    # If no bounding box supplied, output shape is same as input shape.
    # If bounding box is present, shape is max - min + 1
    output_im_shape = im_shape if b_box is None else b_box_max - b_box_min + 1

    # Calculate the downsampled shape
    if not isinstance(stride, Sequence):
        stride_np = np.full_like(im_shape, stride, dtype=np.int32)
        stride_np[0] = 1  # always do stride 1 in channel dimension
    else:
        # Convert to numpy array and check dimensions match
        stride_np = np.array(stride, dtype=np.int32)
        if stride_np.size != im_shape.size:
            raise ValueError("Sizes of image shape and stride should match.")

    # Obviously if stride = 1, downsampled_im_shape == output_im_shape
    downsampled_im_shape = np.floor(output_im_shape / stride_np).astype(np.int32)
    downsampled_im_shape[downsampled_im_shape == 0] = 1  # make sure dimension sizes are >= 1
    num_required_predictions = np.prod(downsampled_im_shape)

    # Loop 1D over image
    for i in trange(num_required_predictions):
        # Get corresponding ND index
        idx = np.unravel_index(i, downsampled_im_shape)
        # Multiply by stride
        idx *= stride_np
        # If a bounding box is being used, we need to add on
        # the min to shift to start of region of interest
        if b_box_min is not None:
            idx += b_box_min

        # Get min and max index of box to occlude
        min_idx = [max(0, i - margin) for i in idx]
        max_idx = [min(j, i + margin) for i, j in zip(idx, im_shape)]

        # Clone and replace target area with `pad_val`
        occlu_im = image.clone()
        occlu_im[(...,) + tuple(slice(i, j) for i, j in zip(min_idx, max_idx))] = pad_val

        # Add to list
        batch_images.append(occlu_im)
        batch_ids.append(label)

        # Once the batch is complete (or on last iteration)
        if len(batch_images) == n_batch or i == num_required_predictions - 1:
            # Do the predictions and append to sensitivity map
            sensitivity_im = _append_to_sensitivity_im(model, batch_images, batch_ids, sensitivity_im)
            # Clear lists
            batch_images = []
            batch_ids = []

    # Subtract from baseline
    sensitivity_im = baseline - sensitivity_im

    # Reshape to match downsampled image
    sensitivity_im = sensitivity_im.reshape(tuple(downsampled_im_shape))

    # If necessary, upsample
    if np.any(stride_np != 1):
        output_im_shape = tuple(output_im_shape[1:])  # needs to be given as 3D tuple
        upsampler = nn.Upsample(size=output_im_shape, mode=upsample_mode)
        sensitivity_im = upsampler(sensitivity_im.unsqueeze(0))

    # Convert tensor to numpy
    sensitivity_im = sensitivity_im.cpu().numpy()

    # Squeeze and return
    return np.squeeze(sensitivity_im)
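

# Minimal usage sketch (illustrative only): the toy classifier, image size and
# parameter values below are assumptions for demonstration, not part of the API.
if __name__ == "__main__":
    # A trivial 2D "classifier": flattens a 1-channel 64 x 64 image into 10 class scores.
    toy_model = nn.Sequential(nn.Flatten(), nn.Linear(64 * 64, 10)).eval()
    toy_image = torch.rand(1, 1, 64, 64)  # a batch containing exactly one image
    sens_map = compute_occlusion_sensitivity(toy_model, toy_image, label=0, n_batch=32)
    # With stride=1 and no bounding box, the map matches the input spatial size;
    # higher values mark regions whose occlusion most reduced the score for label 0.
    print(sens_map.shape)  # expected: (64, 64)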