# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from functools import partial
from typing import Callable, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from monai.networks.utils import eval_mode
from monai.visualize.visualizer import default_upsampler
# tqdm is optional. When it can be imported, ``trange`` is wrapped so that the progress bar
# always carries the same description; otherwise ``trange`` falls back to the built-in ``range``.
try:
    from tqdm import trange

    trange = partial(trange, desc="Computing occlusion sensitivity")
except (ImportError, AttributeError):
    trange = range

# For stride two (for example),
# if input array is:     |0|1|2|3|4|5|6|7|
# downsampled output is: | 0 | 1 | 2 | 3 |
# So the upsampling should do it by the corners of the image, not their centres.
# The imported ``default_upsampler`` is rebound here with ``align_corners=True`` pre-applied.
default_upsampler = partial(default_upsampler, align_corners=True)
def _check_input_image(image):
"""Check that the input image is as expected."""
# Only accept batch size of 1
if image.shape[0] > 1:
raise RuntimeError("Expected batch size of 1.")
def _check_input_bounding_box(b_box, im_shape):
"""Check that the bounding box (if supplied) is as expected."""
# If no bounding box has been supplied, set min and max to None
if b_box is None:
b_box_min = b_box_max = None
# Bounding box has been supplied
else:
# Should be twice as many elements in `b_box` as `im_shape`
if len(b_box) != 2 * len(im_shape):
raise ValueError("Bounding box should contain upper and lower for all dimensions (except batch number)")
# If any min's or max's are -ve, set them to 0 and im_shape-1, respectively.
b_box_min = np.array(b_box[::2])
b_box_max = np.array(b_box[1::2])
b_box_min[b_box_min < 0] = 0
b_box_max[b_box_max < 0] = im_shape[b_box_max < 0] - 1
# Check all max's are < im_shape
if np.any(b_box_max >= im_shape):
raise ValueError("Max bounding box should be < image size for all values")
# Check all min's are <= max's
if np.any(b_box_min > b_box_max):
raise ValueError("Min bounding box should be <= max for all values")
return b_box_min, b_box_max
def _append_to_sensitivity_ims(model, batch_images, sensitivity_ims):
"""Infer given images. Append to previous evaluations. Store each class separately."""
batch_images = torch.cat(batch_images, dim=0)
scores = model(batch_images).detach()
for i in range(scores.shape[1]):
sensitivity_ims[i] = torch.cat((sensitivity_ims[i], scores[:, i]))
return sensitivity_ims
def _get_as_np_array(val, numel):
# If not a sequence, then convert scalar to numpy array
if not isinstance(val, Sequence):
out = np.full(numel, val, dtype=np.int32)
out[0] = 1 # mask_size and stride always 1 in channel dimension
else:
# Convert to numpy array and check dimensions match
out = np.array(val, dtype=np.int32)
# Add stride of 1 to the channel direction (since user input was only for spatial dimensions)
out = np.insert(out, 0, 1)
if out.size != numel:
raise ValueError(
"If supplying stride/mask_size as sequence, number of elements should match number of spatial dimensions."
)
return out
class OcclusionSensitivity:
    """
    This class computes the occlusion sensitivity for a model's prediction of a given image. By occlusion sensitivity,
    we mean how the probability of a given prediction changes as the occluded section of an image changes. This can be
    useful to understand why a network is making certain decisions.

    As important parts of the image are occluded, the probability of classifying the image correctly will decrease.
    Hence, more negative values imply the corresponding occluded volume was more important in the decision process.

    Two ``torch.Tensor`` will be returned by the ``__call__`` method: an occlusion map and an image of the most probable
    class. Both images will be cropped if a bounding box used, but voxel sizes will always match the input.

    The occlusion map shows the inference probabilities when the corresponding part of the image is occluded. Hence,
    more -ve values imply that region was important in the decision process. The map will have shape ``BCHW(D)N``,
    where ``N`` is the number of classes to be inferred by the network. Hence, the occlusion for class ``i`` can
    be seen with ``map[...,i]``.

    The most probable class is an image of the probable class when the corresponding part of the image is occluded
    (equivalent to ``occ_map.argmax(dim=-1)``).

    See: R. R. Selvaraju et al. Grad-CAM: Visual Explanations from Deep Networks via
    Gradient-based Localization. https://doi.org/10.1109/ICCV.2017.74.

    Examples:

    .. code-block:: python

        # densenet 2d
        from monai.networks.nets import DenseNet121
        from monai.visualize import OcclusionSensitivity

        model_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
        occ_sens = OcclusionSensitivity(nn_module=model_2d)
        occ_map, most_probable_class = occ_sens(x=torch.rand((1, 1, 48, 64)), b_box=[-1, -1, 2, 40, 1, 62])

        # densenet 3d
        from monai.networks.nets import DenseNet
        from monai.visualize import OcclusionSensitivity

        model_3d = DenseNet(spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,))
        occ_sens = OcclusionSensitivity(nn_module=model_3d, n_batch=10, stride=3)
        occ_map, most_probable_class = occ_sens(torch.rand(1, 1, 6, 6, 6), b_box=[-1, -1, 1, 3, -1, -1, -1, -1])

    See Also:

        - :py:class:`monai.visualize.occlusion_sensitivity.OcclusionSensitivity`
    """

    def __init__(
        self,
        nn_module: nn.Module,
        pad_val: Optional[float] = None,
        mask_size: Union[int, Sequence] = 15,
        n_batch: int = 128,
        stride: Union[int, Sequence] = 1,
        upsampler: Optional[Callable] = default_upsampler,
        verbose: bool = True,
    ) -> None:
        """Occlusion sensitivity constructor.

        Args:
            nn_module: Classification model to use for inference
            pad_val: When occluding part of the image, which values should we put
                in the image? If ``None`` is used, then the average of the image will be used.
            mask_size: Size of box to be occluded, centred on the central voxel. To ensure that the occluded area
                is correctly centred, ``mask_size`` and ``stride`` should both be odd or even.
            n_batch: Number of images in a batch for inference.
            stride: Stride in spatial directions for performing occlusions. Can be single
                value or sequence (for varying stride in the different directions).
                Should be >= 1. Striding in the channel direction will always be 1.
            upsampler: An upsampling method to upsample the output image. Default is
                N-dimensional linear (bilinear, trilinear, etc.) depending on num spatial
                dimensions of input.
            verbose: Use ``tqdm.trange`` output (if available).
        """
        self.nn_module = nn_module
        self.upsampler = upsampler
        self.pad_val = pad_val
        self.mask_size = mask_size
        self.n_batch = n_batch
        self.stride = stride
        self.verbose = verbose

    def _compute_occlusion_sensitivity(self, x, b_box):
        """Occlude ``x`` patch by patch and collect the model output for every occlusion.

        Args:
            x: Input image tensor with a leading batch dimension of size 1.
            b_box: Optional bounding box, see ``__call__``.

        Returns:
            Tuple ``(sensitivity_ims, output_im_shape)``: a list with one tensor per class
            (shaped like the strided region of interest, with a leading batch dimension),
            and the shape of the region of interest (without batch) before striding.

        Raises:
            ValueError: If the stride is not a factor of the region of interest, or if stride
                and mask size do not have matching parity in every dimension.
        """
        # Get bounding box
        im_shape = np.array(x.shape[1:])
        b_box_min, b_box_max = _check_input_bounding_box(b_box, im_shape)

        # Get the number of prediction classes. Only the element count is needed,
        # so no autograd graph has to be recorded for this probe.
        with torch.no_grad():
            num_classes = self.nn_module(x).numel()

        # If pad val not supplied, get the mean of the image
        pad_val = x.mean() if self.pad_val is None else self.pad_val

        # List containing a batch of images to be inferred
        batch_images = []

        # List of sensitivity images, one (distinct) tensor for each inferred class
        sensitivity_ims = [torch.empty(0, dtype=torch.float32, device=x.device) for _ in range(num_classes)]

        # If no bounding box supplied, output shape is same as input shape.
        # If bounding box is present, shape is max - min + 1
        output_im_shape = im_shape if b_box is None else b_box_max - b_box_min + 1

        # Get the stride and mask_size as numpy arrays. These are kept local: overwriting
        # ``self.stride`` / ``self.mask_size`` would tie the instance to the dimensionality
        # of the first image it was called with.
        stride = _get_as_np_array(self.stride, len(im_shape))
        mask_size = _get_as_np_array(self.mask_size, len(im_shape))

        # For each dimension, ...
        for o, s in zip(output_im_shape, stride):
            # if the size is > 1, then check that the stride is a factor of the output image shape
            if o > 1 and o % s != 0:
                raise ValueError(
                    "Stride should be a factor of the image shape. Im shape "
                    + f"(taking bounding box into account): {output_im_shape}, stride: {stride}"
                )

        # to ensure the occluded area is nicely centred if stride is even, ensure that so is the mask_size
        if np.any(mask_size % 2 != stride % 2):
            raise ValueError(
                "Stride and mask size should both be odd or even (element-wise). "
                + f"``stride={stride}``, ``mask_size={mask_size}``"
            )

        downsampled_im_shape = (output_im_shape / stride).astype(np.int32)
        downsampled_im_shape[downsampled_im_shape == 0] = 1  # make sure dimension sizes are >= 1
        num_required_predictions = int(np.prod(downsampled_im_shape))

        # Get bottom left and top right corners of occluded region
        lower_corner = (stride - mask_size) // 2
        upper_corner = (stride + mask_size) // 2

        # Loop 1D over image
        verbose_range = trange if self.verbose else range
        for i in verbose_range(num_required_predictions):
            # Get corresponding ND index and multiply by stride
            idx = np.array(np.unravel_index(i, downsampled_im_shape)) * stride
            # If a bounding box is being used, we need to add on
            # the min to shift to start of region of interest
            if b_box_min is not None:
                idx += b_box_min

            # Get min and max index of box to occlude (and make sure it's in bounds)
            min_idx = np.maximum(idx + lower_corner, 0)
            max_idx = np.minimum(idx + upper_corner, im_shape)

            # Clone and replace target area with `pad_val`
            occlu_im = x.detach().clone()
            occluded_region = tuple(slice(lo, hi) for lo, hi in zip(min_idx, max_idx))
            occlu_im[(...,) + occluded_region] = pad_val

            # Add to list
            batch_images.append(occlu_im)

            # Once the batch is complete (or on last iteration)
            if len(batch_images) == self.n_batch or i == num_required_predictions - 1:
                # Do the predictions and append to sensitivity maps
                sensitivity_ims = _append_to_sensitivity_ims(self.nn_module, batch_images, sensitivity_ims)
                # Clear lists
                batch_images = []

        # Reshape to match downsampled image, and unsqueeze to add batch dimension back in
        for i in range(num_classes):
            sensitivity_ims[i] = sensitivity_ims[i].reshape(tuple(downsampled_im_shape)).unsqueeze(0)

        return sensitivity_ims, output_im_shape

    def __call__(  # type: ignore
        self,
        x: torch.Tensor,
        b_box: Optional[Sequence] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Image to use for inference. Should be a tensor consisting of 1 batch.
            b_box: Bounding box on which to perform the analysis. The output image will be limited to this size.
                There should be a minimum and maximum for all dimensions except batch: ``[min1, max1, min2, max2,...]``.

                * By default, the whole image will be used. Decreasing the size will speed the analysis up, which might
                  be useful for larger images.
                * Min and max are inclusive, so ``[0, 63, ...]`` will have size ``(64, ...)``.
                * Use -ve to use ``min=0`` and ``max=im.shape[x]-1`` for xth dimension.

        Returns:
            * Occlusion map:

                * Shows the inference probabilities when the corresponding part of the image is occluded.
                  Hence, more -ve values imply that region was important in the decision process.
                * The map will have shape ``BCHW(D)N``, where N is the number of classes to be inferred by the
                  network. Hence, the occlusion for class ``i`` can be seen with ``map[...,i]``.

            * Most probable class:

                * The most probable class when the corresponding part of the image is occluded (``argmax(dim=-1)``).

            Both images will be cropped if a bounding box used, but voxel sizes will always match the input.

        Raises:
            RuntimeError: If ``x`` has a batch size greater than 1.
            ValueError: If the bounding box, stride or mask size are inconsistent with the image.
            AssertionError: If a sensitivity image does not have as many dimensions as ``x``.
        """
        with eval_mode(self.nn_module):
            # Check input arguments
            _check_input_image(x)

            # Generate sensitivity images
            sensitivity_ims_list, output_im_shape = self._compute_occlusion_sensitivity(x, b_box)

            # Loop over image for each classification
            for i, sens_i in enumerate(sensitivity_ims_list):
                # upsample
                if self.upsampler is not None:
                    if len(sens_i.shape) != len(x.shape):
                        raise AssertionError(
                            f"Sensitivity image has {len(sens_i.shape)} dimensions, expected {len(x.shape)}."
                        )
                    if sens_i.shape != x.shape:
                        # Hand plain Python ints (not NumPy scalars) to the upsampler
                        img_spatial = tuple(int(s) for s in output_im_shape[1:])
                        sensitivity_ims_list[i] = self.upsampler(img_spatial)(sens_i)

            # Convert list of tensors to tensor
            sensitivity_ims = torch.stack(sensitivity_ims_list, dim=-1)

            # The most probable class is the max in the classification dimension (last)
            most_probable_class = sensitivity_ims.argmax(dim=-1)

            return sensitivity_ims, most_probable_class