# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Callable, Dict, Sequence, Union
import numpy as np
import torch
import torch.nn.functional as F
from monai.transforms import ScaleIntensity
from monai.utils import InterpolateMode, ensure_tuple
__all__ = ["ModelWithHooks", "default_upsampler", "default_normalizer", "CAM", "GradCAM", "GradCAMpp"]
class ModelWithHooks:
    """
    A model wrapper to run model forward/backward steps and storing some intermediate feature/gradient information.

    The wrapper registers hooks on the submodules named in `target_layer_names` and caches, per layer name,
    the forward output (`self.activations`) and/or the first element of the backward `grad_output`
    (`self.gradients`). The score and class index of the latest backward pass are kept in
    `self.score` and `self.class_idx`.
    """

    def __init__(
        self,
        nn_module,
        target_layer_names: Union[str, Sequence[str]],
        register_forward: bool = False,
        register_backward: bool = False,
    ):
        """
        Args:
            nn_module: the model to be wrapped.
            target_layer_names: the names of the layer to cache.
            register_forward: whether to cache the forward pass output corresponding to `target_layer_names`.
            register_backward: whether to cache the backward pass output corresponding to `target_layer_names`.
        """
        self.model = nn_module
        self.target_layers = ensure_tuple(target_layer_names)

        # caches filled by the hooks, keyed by layer name
        self.gradients: Dict[str, torch.Tensor] = {}
        self.activations: Dict[str, torch.Tensor] = {}
        # set by `__call__` whenever a backward pass is run
        self.score = None
        self.class_idx = None
        self.register_backward = register_backward
        self.register_forward = register_forward

        _registered = []
        for name, mod in nn_module.named_modules():
            if name not in self.target_layers:
                continue
            _registered.append(name)
            if self.register_backward:
                mod.register_backward_hook(self.backward_hook(name))
            if self.register_forward:
                mod.register_forward_hook(self.forward_hook(name))
        # a name that matches no submodule is reported but is not an error at construction time
        if len(_registered) != len(self.target_layers):
            warnings.warn(f"Not all target_layers exist in the network module: targets: {self.target_layers}.")

    def backward_hook(self, name):
        """Build a backward hook that stores `grad_output[0]` under `name` in `self.gradients`."""

        def _hook(_module, _grad_input, grad_output):
            self.gradients[name] = grad_output[0]

        return _hook

    def forward_hook(self, name):
        """Build a forward hook that stores the module output under `name` in `self.activations`."""

        def _hook(_module, _input, output):
            self.activations[name] = output

        return _hook

    def get_layer(self, layer_id: Union[str, Callable]):
        """

        Args:
            layer_id: a layer name string or a callable. If it is a callable such as `lambda m: m.fc`,
                this method will return the module `self.model.fc`.

        Returns:
            a submodule from self.model.

        Raises:
            NotImplementedError: when `layer_id` is a string that matches no submodule name,
                or is neither a callable nor a string.
        """
        if callable(layer_id):
            return layer_id(self.model)
        if isinstance(layer_id, str):
            for name, mod in self.model.named_modules():
                if name == layer_id:
                    return mod
        raise NotImplementedError(f"Could not find {layer_id}.")

    def class_score(self, logits, class_idx=None):
        """
        Select the score of the class of interest from `logits`.

        Args:
            logits: the model output; the class axis is dim 1.
            class_idx: the class to score. Defaults to the argmax of `logits` along dim 1.
                A 1D tensor with one entry per sample selects a (possibly different) class for each sample.

        Returns:
            a tuple `(score, class_idx)` where `score` is the squeezed selection from `logits`
            and `class_idx` is the index that was used.
        """
        if class_idx is None:
            class_idx = logits.max(1)[-1]
        if (
            isinstance(class_idx, torch.Tensor)
            and logits.dim() == 2
            and class_idx.dim() == 1
            and class_idx.shape[0] == logits.shape[0]
        ):
            # One class per sample: pick logits[i, class_idx[i]]. Plain `logits[:, class_idx]` would build a
            # (B, B) matrix, and summing it for the backward pass would leak the classes of the other
            # samples in the batch into every sample's gradient.
            index = class_idx.to(device=logits.device, dtype=torch.long).unsqueeze(1)
            return logits.gather(1, index).squeeze(), class_idx
        return logits[:, class_idx].squeeze(), class_idx

    def __call__(self, x, class_idx=None, retain_graph=False):
        """
        Run the wrapped model on `x` and collect the cached activations/gradients.

        Args:
            x: input passed to the wrapped model.
            class_idx: class used to compute the score for the backward pass, see `class_score`.
            retain_graph: forwarded to the `backward` call.

        Returns:
            a tuple `(logits, acti, grad)`; `acti` and `grad` are tuples ordered like `self.target_layers`,
            or None when the corresponding hooks were not requested.
        """
        logits = self.model(x)
        acti, grad = None, None
        if self.register_forward:
            acti = tuple(self.activations[layer] for layer in self.target_layers)
        if self.register_backward:
            score, class_idx = self.class_score(logits, class_idx)
            self.model.zero_grad()
            self.score, self.class_idx = score, class_idx
            # summing gives a scalar to differentiate; each sample contributes only its own score
            score.sum().backward(retain_graph=retain_graph)
            grad = tuple(self.gradients[layer] for layer in self.target_layers)
        return logits, acti, grad
def default_upsampler(spatial_size) -> Callable[[torch.Tensor], torch.Tensor]:
    """
    Build a linear-interpolation upsampler for feature maps.

    The returned callable `func` resizes its input with `F.interpolate` so that
    `func(activation_map)` has the spatial size `spatial_size`. The interpolation mode
    (linear, bilinear or trilinear) is chosen from the number of entries in `spatial_size`.
    """

    def up(acti_map):
        # index 0/1/2 corresponds to 1/2/3 spatial dimensions
        modes_by_rank = (InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR)
        selected = modes_by_rank[len(spatial_size) - 1]
        mode_name = str(selected.value)
        return F.interpolate(acti_map, size=spatial_size, mode=mode_name, align_corners=False)

    return up
def default_normalizer(acti_map) -> np.ndarray:
    """
    Linearly rescale the intensities of each item of `acti_map`, mapping the (min, max) to (1, 0).

    A torch tensor input is first detached and converted to a numpy array on the CPU.
    Every item along the first axis is passed separately through
    `ScaleIntensity(minv=1.0, maxv=0.0)` and the results are stacked back along axis 0.
    """
    rescale = ScaleIntensity(minv=1.0, maxv=0.0)
    items = acti_map.detach().cpu().numpy() if isinstance(acti_map, torch.Tensor) else acti_map
    return np.stack([rescale(item) for item in items], axis=0)
class CAM:
    """
    Compute class activation map from the last fully-connected layers before the spatial pooling.

    Examples

    .. code-block:: python

        # densenet 2d
        from monai.networks.nets import densenet121
        from monai.visualize import CAM

        model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3)
        cam = CAM(nn_module=model_2d, target_layers="class_layers.relu", fc_layers="class_layers.out")
        result = cam(x=torch.rand((1, 1, 48, 64)))

        # resnet 2d
        from monai.networks.nets import se_resnet50
        from monai.visualize import CAM

        model_2d = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4)
        cam = CAM(nn_module=model_2d, target_layers="layer4", fc_layers="last_linear")
        result = cam(x=torch.rand((2, 3, 48, 64)))

    See Also:

        - :py:class:`monai.visualize.class_activation_maps.GradCAM`

    """

    def __init__(
        self,
        nn_module,
        target_layers: Union[str, Sequence[str]],
        fc_layers: Union[str, Callable] = "fc",
        upsampler=default_upsampler,
        postprocessing: Callable = default_normalizer,
    ):
        """
        Args:
            nn_module: the model to be visualised. An existing `ModelWithHooks` instance is used as-is,
                in which case `target_layers` is ignored.
            target_layers: name (or names) of the model layer to generate the feature map.
            fc_layers: a string or a callable used to get fully-connected weights to compute activation map
                from the target_layers (without pooling), and evaluate it at every spatial location.
            upsampler: an upsampling method to upsample the feature map.
            postprocessing: a callable that applies on the upsampled feature map.
        """
        if not isinstance(nn_module, ModelWithHooks):
            # only the forward activations are needed for CAM
            self.net = ModelWithHooks(nn_module, target_layers, register_forward=True)
        else:
            self.net = nn_module

        self.upsampler = upsampler
        self.postprocessing = postprocessing
        self.fc_layers = fc_layers

    def compute_map(self, x, class_idx=None, layer_idx=-1):
        """
        Compute the actual feature map with input tensor `x`.

        Args:
            x: input tensor, shape must be compatible with `nn_module`.
            class_idx: index of the class to be visualised. Either a single index (int or 0-d tensor)
                applied to every sample, or an iterable with one index per sample.
                Default to argmax(logits).
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.

        Returns:
            a tensor of shape `(batch, 1, *spatial)` where `spatial` is the spatial shape of the
            selected target layer's output.
        """
        logits, acti, _ = self.net(x)
        acti = acti[layer_idx]
        b, c, *spatial = acti.shape
        if class_idx is None:
            class_idx = logits.max(1)[-1]
        elif isinstance(class_idx, (int, np.integer)) or (
            isinstance(class_idx, torch.Tensor) and class_idx.dim() == 0
        ):
            # a single class index is not iterable: use the same class for every sample in the batch
            class_idx = [int(class_idx)] * b
        acti = torch.split(acti.reshape(b, c, -1), 1, dim=2)  # make the spatial dims 1D
        fc_layers = self.net.get_layer(self.fc_layers)
        # evaluate the fully-connected layers at every spatial location, stacking the locations on dim 2
        output = torch.stack([fc_layers(a[..., 0]) for a in acti], dim=2)
        # keep, for each sample, only the row of the selected class
        output = torch.stack([output[i, k : k + 1] for i, k in enumerate(class_idx)], dim=0)
        return output.reshape(b, 1, *spatial)  # resume the spatial dims on the selected class

    def feature_map_size(self, input_size, device="cpu", layer_idx=-1):
        """
        Computes the actual feature map size given `nn_module` and the target_layer name.

        Args:
            input_size: shape of the input tensor
            device: the device used to initialise the input tensor
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.

        Returns:
            shape of the actual feature map.
        """
        # the size is measured by running an all-zero input through `compute_map`
        return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape

    def __call__(self, x, class_idx=None, layer_idx=-1):
        """
        Compute the activation map with upsampling and postprocessing.

        Args:
            x: input tensor, shape must be compatible with `nn_module`.
            class_idx: index of the class to be visualised. Default to argmax(logits)
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.

        Returns:
            activation maps
        """
        acti_map = self.compute_map(x, class_idx, layer_idx)

        # upsampling and postprocessing
        if self.upsampler:
            img_spatial = x.shape[2:]
            acti_map = self.upsampler(img_spatial)(acti_map)
        if self.postprocessing:
            acti_map = self.postprocessing(acti_map)
        return acti_map
class GradCAM:
    """
    Computes Gradient-weighted Class Activation Mapping (Grad-CAM).
    This implementation is based on:

        Selvaraju et al., Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization,
        https://arxiv.org/abs/1610.02391

    Examples

    .. code-block:: python

        # densenet 2d
        from monai.networks.nets import densenet121
        from monai.visualize import GradCAM

        model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3)
        cam = GradCAM(nn_module=model_2d, target_layers="class_layers.relu")
        result = cam(x=torch.rand((1, 1, 48, 64)))

        # resnet 2d
        from monai.networks.nets import se_resnet50
        from monai.visualize import GradCAM

        model_2d = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4)
        cam = GradCAM(nn_module=model_2d, target_layers="layer4")
        result = cam(x=torch.rand((2, 3, 48, 64)))

    See Also:

        - :py:class:`monai.visualize.class_activation_maps.CAM`

    """

    def __init__(self, nn_module, target_layers: str, upsampler=default_upsampler, postprocessing=default_normalizer):
        """
        Args:
            nn_module: the model to be used to generate the visualisations. A `ModelWithHooks`
                instance is used directly; anything else is wrapped in one.
            target_layers: name of the model layer to generate the feature map.
            upsampler: an upsampling method to upsample the feature map.
            postprocessing: a callable that applies on the upsampled feature map.
        """
        self.upsampler = upsampler
        self.postprocessing = postprocessing
        # Grad-CAM needs both the activations and the gradients of the target layer(s)
        self.net = (
            nn_module
            if isinstance(nn_module, ModelWithHooks)
            else ModelWithHooks(nn_module, target_layers, register_forward=True, register_backward=True)
        )

    def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1):
        """
        Compute the actual feature map with input tensor `x`.

        The gradients of the selected layer are averaged over the spatial positions to obtain one
        weight per channel; the map is the ReLU of the channel-weighted sum of the activations.
        """
        _, activations, gradients = self.net(x, class_idx=class_idx, retain_graph=retain_graph)
        feature = activations[layer_idx]
        feature_grad = gradients[layer_idx]
        n_batch, n_chns, *spatial_dims = feature_grad.shape
        # average over the flattened spatial positions, then restore singleton spatial axes for broadcasting
        channel_weights = feature_grad.view(n_batch, n_chns, -1).mean(2)
        channel_weights = channel_weights.view(n_batch, n_chns, *([1] * len(spatial_dims)))
        weighted_sum = torch.sum(channel_weights * feature, dim=1, keepdim=True)
        return F.relu(weighted_sum)

    def feature_map_size(self, input_size, device="cpu", layer_idx=-1):
        """
        Computes the actual feature map size given `nn_module` and the target_layer name.

        Args:
            input_size: shape of the input tensor
            device: the device used to initialise the input tensor
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.

        Returns:
            shape of the actual feature map.
        """
        # an all-zero input of the requested shape is pushed through `compute_map` to read off the shape
        probe = torch.zeros(*input_size, device=device)
        return self.compute_map(probe, layer_idx=layer_idx).shape

    def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False):
        """
        Compute the activation map with upsampling and postprocessing.

        Args:
            x: input tensor, shape must be compatible with `nn_module`.
            class_idx: index of the class to be visualised. Default to argmax(logits)
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.
            retain_graph: whether to retain_graph for torch module backward call.

        Returns:
            activation maps
        """
        result = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph, layer_idx=layer_idx)

        # both steps are optional: a falsy upsampler/postprocessing is skipped
        if self.upsampler:
            result = self.upsampler(x.shape[2:])(result)
        if self.postprocessing:
            result = self.postprocessing(result)
        return result
class GradCAMpp(GradCAM):
    """
    Computes Gradient-weighted Class Activation Mapping (Grad-CAM++).
    This implementation is based on:

        Chattopadhyay et al., Grad-CAM++: Improved Visual Explanations for Deep Convolutional Networks,
        https://arxiv.org/abs/1710.11063

    See Also:

        - :py:class:`monai.visualize.class_activation_maps.GradCAM`

    """

    def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1):
        """
        Compute the actual feature map with input tensor `x`.

        Args:
            x: input tensor, shape must be compatible with `nn_module`.
            class_idx: index of the class to be visualised. Default to argmax(logits)
            retain_graph: whether to retain_graph for torch module backward call.
            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.

        Returns:
            a tensor with a single channel and the spatial shape of the selected target layer's output.
        """
        _, acti, grad = self.net(x, class_idx=class_idx, retain_graph=retain_graph)
        acti, grad = acti[layer_idx], grad[layer_idx]
        b, c, *spatial = grad.shape
        ones = [1] * len(spatial)  # singleton spatial axes used for broadcasting

        # alpha = grad^2 / (2 * grad^2 + sum_over_space(acti * grad^3))
        alpha_nr = grad.pow(2)
        alpha_dr = alpha_nr.mul(2) + acti.mul(grad.pow(3)).view(b, c, -1).sum(-1).view(b, c, *ones)
        # avoid dividing by an exactly-zero denominator
        alpha_dr = torch.where(alpha_dr != 0.0, alpha_dr, torch.ones_like(alpha_dr))
        alpha = alpha_nr.div(alpha_dr + 1e-7)

        score = self.net.score.exp()
        if score.numel() == b:
            # One score per sample: align it with the batch axis. Without this, a (B,) score would be
            # broadcast against the trailing spatial axis of the (B, C, *spatial) gradient.
            score = score.reshape(b, 1, *ones)
        relu_grad = F.relu(score * grad)
        weights = (alpha * relu_grad).view(b, c, -1).sum(-1).view(b, c, *ones)
        acti_map = (weights * acti).sum(1, keepdim=True)
        return F.relu(acti_map)