Source code for monai.visualize.class_activation_maps

# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Callable, Dict, Sequence, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.transforms import ScaleIntensity
from monai.utils import ensure_tuple, get_torch_version_tuple
from monai.visualize.visualizer import default_upsampler

# Names exported by `from monai.visualize.class_activation_maps import *` (CAMBase is not listed).
__all__ = ["CAM", "GradCAM", "GradCAMpp", "ModelWithHooks", "default_normalizer"]


def default_normalizer(x) -> np.ndarray:
    """
    A linear intensity scaling by mapping the (min, max) to (1, 0).

    Each item along the first dimension of ``x`` is scaled independently, and the
    results are stacked back together along axis 0.

    N.B.: This will flip magnitudes (i.e., smallest will become biggest and vice versa).

    Args:
        x: a batch of maps; a ``torch.Tensor`` is detached and converted to a numpy array on CPU first.

    Returns:
        numpy array whose leading dimension equals the number of items in ``x``.
    """
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    scaler = ScaleIntensity(minv=1.0, maxv=0.0)
    # scale every batch item on its own (distinct loop name: do not shadow `x`)
    scaled = [scaler(item) for item in x]
    return np.stack(scaled, axis=0)
class ModelWithHooks:
    """
    A model wrapper to run model forward/backward steps and storing some intermediate feature/gradient information.
    """

    def __init__(
        self,
        nn_module,
        target_layer_names: Union[str, Sequence[str]],
        register_forward: bool = False,
        register_backward: bool = False,
    ):
        """

        Args:
            nn_module: the model to be wrapped.
            target_layer_names: the names of the layer to cache.
            register_forward: whether to cache the forward pass output corresponding to `target_layer_names`.
            register_backward: whether to cache the backward pass output corresponding to `target_layer_names`.

        Note:
            When ``register_backward`` is set (torch >= 1.8), a target module whose instance
            attribute ``inplace`` is truthy has it switched to ``False`` on the wrapped model.
        """
        self.model = nn_module
        self.target_layers = ensure_tuple(target_layer_names)

        self.gradients: Dict[str, torch.Tensor] = {}
        self.activations: Dict[str, torch.Tensor] = {}
        self.score = None
        self.class_idx = None
        self.register_backward = register_backward
        self.register_forward = register_forward

        _registered = []
        for name, mod in nn_module.named_modules():
            if name not in self.target_layers:
                continue
            _registered.append(name)
            if self.register_backward:
                if get_torch_version_tuple() < (1, 8):
                    mod.register_backward_hook(self.backward_hook(name))
                else:
                    if "inplace" in mod.__dict__ and mod.__dict__["inplace"]:
                        # inplace=True causes errors for register_full_backward_hook
                        mod.__dict__["inplace"] = False
                    mod.register_full_backward_hook(self.backward_hook(name))
            if self.register_forward:
                mod.register_forward_hook(self.forward_hook(name))
        if len(_registered) != len(self.target_layers):
            warnings.warn(f"Not all target_layers exist in the network module: targets: {self.target_layers}.")

    def backward_hook(self, name):
        """Return a backward hook that stores the first gradient w.r.t. the module output under ``name``."""

        def _hook(_module, _grad_input, grad_output):
            self.gradients[name] = grad_output[0]

        return _hook

    def forward_hook(self, name):
        """Return a forward hook that stores the module output under ``name``."""

        def _hook(_module, _input, output):
            self.activations[name] = output

        return _hook

    def get_layer(self, layer_id: Union[str, Callable]):
        """

        Args:
            layer_id: a layer name string or a callable. If it is a callable such as `lambda m: m.fc`,
                this method will return the module `self.model.fc`.

        Returns:
            a submodule from self.model.

        Raises:
            NotImplementedError: when ``layer_id`` is neither a callable nor a name found in ``self.model``.
        """
        if callable(layer_id):
            return layer_id(self.model)
        if isinstance(layer_id, str):
            for name, mod in self.model.named_modules():
                if name == layer_id:
                    return mod
        raise NotImplementedError(f"Could not find {layer_id}.")

    def class_score(self, logits, class_idx):
        """
        Select the score of the class of interest from ``logits``.

        Args:
            logits: network output; dimension 1 indexes the classes.
            class_idx: either a single class index shared by the whole batch, or a tensor/sequence
                holding one class index per sample (e.g. the argmax computed in ``__call__``).

        Returns:
            the selected scores, squeezed (shape ``(batch,)``, or a 0-dim tensor when batch is 1).
        """
        if isinstance(class_idx, (torch.Tensor, list, tuple)):
            idx = torch.as_tensor(class_idx, dtype=torch.long, device=logits.device)
            if idx.dim() == 1 and idx.numel() == logits.shape[0]:
                # one index per sample: pick logits[i, idx[i]] rather than the
                # (batch, batch) cross product that `logits[:, idx]` would produce
                return logits.gather(1, idx.unsqueeze(1)).squeeze()
        return logits[:, class_idx].squeeze()

    def __call__(self, x, class_idx=None, retain_graph=False):
        """
        Run the wrapped model in eval mode and collect the cached activations/gradients.

        Args:
            x: input to the wrapped model.
            class_idx: class to back-propagate from; defaults to the per-sample argmax of the logits.
            retain_graph: forwarded to ``backward``.

        Returns:
            ``(logits, activations, gradients)``; ``activations``/``gradients`` are tuples ordered like
            ``self.target_layers`` or ``None`` when the corresponding hooks were not registered.
        """
        train = self.model.training
        self.model.eval()
        try:
            logits = self.model(x)
            self.class_idx = logits.max(1)[-1] if class_idx is None else class_idx
            acti, grad = None, None
            if self.register_forward:
                acti = tuple(self.activations[layer] for layer in self.target_layers)
            if self.register_backward:
                self.score = self.class_score(logits, self.class_idx)
                self.model.zero_grad()
                self.score.sum().backward(retain_graph=retain_graph)
                grad = tuple(self.gradients[layer] for layer in self.target_layers)
        finally:
            # restore the caller's mode even if the forward/backward pass raised
            if train:
                self.model.train()
        return logits, acti, grad

    def get_wrapped_net(self):
        """Return the wrapped model."""
        return self.model
class CAMBase:
    """
    Base class for CAM methods.

    Subclasses override ``compute_map`` (raw maps) and ``__call__``; the shared helper
    ``_upsample_and_post_process`` applies the optional upsampler and post-processing.
    """

    def __init__(
        self,
        nn_module: nn.Module,
        target_layers: str,
        upsampler: Callable = default_upsampler,
        postprocessing: Callable = default_normalizer,
        register_backward: bool = True,
    ) -> None:
        """
        Args:
            nn_module: the network to inspect. Used as-is if it is already a ``ModelWithHooks``,
                otherwise wrapped in one with forward hooks enabled.
            target_layers: name(s) of the layer(s) to hook; only used when wrapping ``nn_module``.
            upsampler: factory taking the input's spatial shape and returning a resize callable;
                a falsy value disables upsampling.
            postprocessing: callable applied after upsampling; a falsy value disables it.
            register_backward: whether the new wrapper also caches gradients; only used when wrapping.
        """
        if isinstance(nn_module, ModelWithHooks):
            self.nn_module = nn_module
        else:
            # Convert to model with hooks
            self.nn_module = ModelWithHooks(
                nn_module, target_layers, register_forward=True, register_backward=register_backward
            )
        self.upsampler = upsampler
        self.postprocessing = postprocessing

    def feature_map_size(self, input_size, device="cpu", layer_idx=-1):
        """
        Computes the actual feature map size given `nn_module` and the target_layer name.

        Args:
            input_size: shape of the input tensor
            device: the device used to initialise the input tensor
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.

        Returns:
            shape of the actual feature map.
        """
        dummy_input = torch.zeros(*input_size, device=device)
        raw_map = self.compute_map(dummy_input, layer_idx=layer_idx)
        return raw_map.shape

    def compute_map(self, x, class_idx=None, layer_idx=-1):
        """
        Compute the actual feature map with input tensor `x`.

        Args:
            x: input to `nn_module`.
            class_idx: index of the class to be visualized. Default to `None` (computing `class_idx` from `argmax`)
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.

        Returns:
            activation maps (raw outputs without upsampling/post-processing.)
        """
        raise NotImplementedError()

    def _upsample_and_post_process(self, acti_map, x):
        """Resize ``acti_map`` to the spatial size of ``x`` (dims 2 onward), then post-process it."""
        if self.upsampler:
            resize = self.upsampler(x.shape[2:])
            acti_map = resize(acti_map)
        if self.postprocessing:
            acti_map = self.postprocessing(acti_map)
        return acti_map

    def __call__(self):
        raise NotImplementedError()
class CAM(CAMBase):
    """
    Compute class activation map from the last fully-connected layers before the spatial pooling.
    This implementation is based on:

        Zhou et al., Learning Deep Features for Discriminative Localization. CVPR '16,
        https://arxiv.org/abs/1512.04150

    Examples

    .. code-block:: python

        # densenet 2d
        from monai.networks.nets import DenseNet121
        from monai.visualize import CAM

        model_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
        cam = CAM(nn_module=model_2d, target_layers="class_layers.relu", fc_layers="class_layers.out")
        result = cam(x=torch.rand((1, 1, 48, 64)))

        # resnet 2d
        from monai.networks.nets import se_resnet50
        from monai.visualize import CAM

        model_2d = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4)
        cam = CAM(nn_module=model_2d, target_layers="layer4", fc_layers="last_linear")
        result = cam(x=torch.rand((2, 3, 48, 64)))

    N.B.: To help select the target layer, it may be useful to list all layers:

    .. code-block:: python

        for name, _ in model.named_modules(): print(name)

    See Also:

        - :py:class:`monai.visualize.class_activation_maps.GradCAM`

    """

    def __init__(
        self,
        nn_module: nn.Module,
        target_layers: str,
        fc_layers: Union[str, Callable] = "fc",
        upsampler: Callable = default_upsampler,
        postprocessing: Callable = default_normalizer,
    ) -> None:
        """
        Args:
            nn_module: the model to be visualized
            target_layers: name of the model layer to generate the feature map.
            fc_layers: a string or a callable used to get the fully-connected layer. That layer is
                evaluated at every spatial location of the `target_layers` output (without pooling)
                to compute the activation map.
            upsampler: An upsampling method to upsample the output image. Default is
                N dimensional linear (bilinear, trilinear, etc.) depending on num spatial
                dimensions of input.
            postprocessing: a callable that applies on the upsampled output image.
                Default is normalizing between min=1 and max=0 (i.e., largest input will become 0 and
                smallest input will become 1).
        """
        super().__init__(
            nn_module=nn_module,
            target_layers=target_layers,
            upsampler=upsampler,
            postprocessing=postprocessing,
            register_backward=False,
        )
        self.fc_layers = fc_layers

    def compute_map(self, x, class_idx=None, layer_idx=-1):
        """
        Compute the raw class activation map (no upsampling / post-processing).

        Args:
            x: input tensor, shape must be compatible with `nn_module`.
            class_idx: class to visualize. An int is used for every sample; a tensor/sequence gives
                one index per sample. Defaults to the per-sample argmax of the logits.
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.

        Returns:
            tensor of shape ``(batch, 1, *spatial)`` with ``spatial`` the target layer's spatial shape.
        """
        logits, acti, _ = self.nn_module(x)
        acti = acti[layer_idx]
        if class_idx is None:
            class_idx = logits.max(1)[-1]
        batch, channels, *spatial = acti.shape
        # make the spatial dims 1D, then split into one (batch, channels, 1) slice per location
        locations = torch.split(acti.reshape(batch, channels, -1), 1, dim=2)
        fc_layers = self.nn_module.get_layer(self.fc_layers)
        # evaluate the fc layer at every location: (batch, num_classes, num_locations)
        output = torch.stack([fc_layers(loc[..., 0]) for loc in locations], dim=2)
        # one class index per sample; a single index is shared by the whole batch
        idx = torch.as_tensor(class_idx, dtype=torch.long, device=output.device).reshape(-1)
        if idx.numel() == 1:
            idx = idx.expand(batch)
        selected = output[torch.arange(batch, device=output.device), idx]  # (batch, num_locations)
        return selected.reshape(batch, 1, *spatial)  # resume the spatial dims on the selected class

    def __call__(self, x, class_idx=None, layer_idx=-1):
        """
        Compute the activation map with upsampling and postprocessing.

        Args:
            x: input tensor, shape must be compatible with `nn_module`.
            class_idx: index of the class to be visualized (an int for the whole batch, or one
                index per sample). Default to argmax(logits)
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.

        Returns:
            activation maps
        """
        acti_map = self.compute_map(x, class_idx, layer_idx)
        return self._upsample_and_post_process(acti_map, x)
class GradCAM(CAMBase):
    """
    Computes Gradient-weighted Class Activation Mapping (Grad-CAM).
    This implementation is based on:

        Selvaraju et al., Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization,
        https://arxiv.org/abs/1610.02391

    Examples

    .. code-block:: python

        # densenet 2d
        from monai.networks.nets import DenseNet121
        from monai.visualize import GradCAM

        model_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
        cam = GradCAM(nn_module=model_2d, target_layers="class_layers.relu")
        result = cam(x=torch.rand((1, 1, 48, 64)))

        # resnet 2d
        from monai.networks.nets import se_resnet50
        from monai.visualize import GradCAM

        model_2d = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4)
        cam = GradCAM(nn_module=model_2d, target_layers="layer4")
        result = cam(x=torch.rand((2, 3, 48, 64)))

    N.B.: To help select the target layer, it may be useful to list all layers:

    .. code-block:: python

        for name, _ in model.named_modules(): print(name)

    See Also:

        - :py:class:`monai.visualize.class_activation_maps.CAM`

    """

    def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1):
        """
        Compute the raw Grad-CAM map (no upsampling / post-processing).

        Every channel of the target layer's activation is weighted by the spatial mean of its
        gradient; the weighted channels are summed into a single channel and passed through a ReLU.

        Args:
            x: input to `nn_module`.
            class_idx: class to back-propagate from. Default to argmax(logits)
            retain_graph: whether to retain_graph for torch module backward call.
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.

        Returns:
            single-channel activation maps, one per batch item.
        """
        _, activations, gradients = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph)
        acti = activations[layer_idx]
        grad = gradients[layer_idx]
        batch, channels, *spatial = grad.shape
        # (batch, channels, 1, ..., 1) so the channel weights broadcast over all spatial locations
        weight_shape = (batch, channels, *[1] * len(spatial))
        weights = grad.view(batch, channels, -1).mean(2).view(*weight_shape)
        weighted_sum = torch.sum(weights * acti, dim=1, keepdim=True)
        return F.relu(weighted_sum)

    def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False):
        """
        Compute the activation map with upsampling and postprocessing.

        Args:
            x: input tensor, shape must be compatible with `nn_module`.
            class_idx: index of the class to be visualized. Default to argmax(logits)
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.
            retain_graph: whether to retain_graph for torch module backward call.

        Returns:
            activation maps
        """
        raw_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph, layer_idx=layer_idx)
        return self._upsample_and_post_process(raw_map, x)
class GradCAMpp(GradCAM):
    """
    Computes Gradient-weighted Class Activation Mapping (Grad-CAM++).
    This implementation is based on:

        Chattopadhyay et al., Grad-CAM++: Improved Visual Explanations for Deep Convolutional Networks,
        https://arxiv.org/abs/1710.11063

    See Also:

        - :py:class:`monai.visualize.class_activation_maps.GradCAM`

    """

    def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1):
        """
        Compute the raw Grad-CAM++ map (no upsampling / post-processing).

        Args:
            x: input to `nn_module`.
            class_idx: class to back-propagate from. Default to argmax(logits)
            retain_graph: whether to retain_graph for torch module backward call.
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.

        Returns:
            single-channel activation maps, one per batch item.
        """
        _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph)
        acti, grad = acti[layer_idx], grad[layer_idx]
        b, c, *spatial = grad.shape
        # (b, c, 1, ..., 1): per-channel quantities broadcast over every spatial location
        channel_shape = (b, c, *[1] * len(spatial))
        alpha_nr = grad.pow(2)
        alpha_dr = alpha_nr.mul(2) + acti.mul(grad.pow(3)).reshape(b, c, -1).sum(-1).reshape(channel_shape)
        # replace zero denominators by 1 to avoid dividing by zero
        alpha_dr = torch.where(alpha_dr != 0.0, alpha_dr, torch.ones_like(alpha_dr))
        alpha = alpha_nr.div(alpha_dr + 1e-7)
        score = self.nn_module.score
        if score.numel() == b:
            # one score per sample: align it with the batch axis instead of letting it
            # broadcast against the last spatial axis of `grad`
            score = score.reshape(b, *[1] * (grad.dim() - 1))
        relu_grad = F.relu(score.exp() * grad)
        weights = (alpha * relu_grad).reshape(b, c, -1).sum(-1).reshape(channel_shape)
        acti_map = (weights * acti).sum(1, keepdim=True)
        return F.relu(acti_map)