# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
import torch
import torch.nn as nn
from monai.utils import optional_import
from monai.utils.enums import StrEnum
LPIPS, _ = optional_import("lpips", name="LPIPS")
torchvision, _ = optional_import("torchvision")
class PercetualNetworkType(StrEnum):
alex = "alex"
vgg = "vgg"
squeeze = "squeeze"
radimagenet_resnet50 = "radimagenet_resnet50"
medicalnet_resnet10_23datasets = "medicalnet_resnet10_23datasets"
medical_resnet50_23datasets = "medical_resnet50_23datasets"
resnet50 = "resnet50"
[docs]
class PerceptualLoss(nn.Module):
"""
Perceptual loss using features from pretrained deep neural networks trained. The function supports networks
pretrained on: ImageNet that use the LPIPS approach from Zhang, et al. "The unreasonable effectiveness of deep
features as a perceptual metric." https://arxiv.org/abs/1801.03924 ; RadImagenet from Mei, et al. "RadImageNet: An
Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"
https://pubs.rsna.org/doi/full/10.1148/ryai.210315 ; MedicalNet from Chen et al. "Med3D: Transfer Learning for
3D Medical Image Analysis" https://arxiv.org/abs/1904.00625 ;
and ResNet50 from Torchvision: https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html .
The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual loss on slices from all
three axes and average. The full 3D approach uses a 3D network to calculate the perceptual loss.
Args:
spatial_dims: number of spatial dimensions.
network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``, ``"radimagenet_resnet50"``,
``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``, ``"resnet50"``}
Specifies the network architecture to use. Defaults to ``"alex"``.
is_fake_3d: if True use 2.5D approach for a 3D perceptual loss.
fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach.
cache_dir: path to cache directory to save the pretrained network weights.
pretrained: whether to load pretrained weights. This argument only works when using networks from
LIPIS or Torchvision. Defaults to ``"True"``.
pretrained_path: if `pretrained` is `True`, users can specify a weights file to be loaded
via using this argument. This argument only works when ``"network_type"`` is "resnet50".
Defaults to `None`.
pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to
extract the expected state dict. This argument only works when ``"network_type"`` is "resnet50".
Defaults to `None`.
"""
def __init__(
self,
spatial_dims: int,
network_type: str = PercetualNetworkType.alex,
is_fake_3d: bool = True,
fake_3d_ratio: float = 0.5,
cache_dir: str | None = None,
pretrained: bool = True,
pretrained_path: str | None = None,
pretrained_state_dict_key: str | None = None,
):
super().__init__()
if spatial_dims not in [2, 3]:
raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.")
if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type:
raise ValueError(
"MedicalNet networks are only compatible with ``spatial_dims=3``."
"Argument is_fake_3d must be set to False."
)
if network_type.lower() not in list(PercetualNetworkType):
raise ValueError(
"Unrecognised criterion entered for Adversarial Loss. Must be one in: %s"
% ", ".join(PercetualNetworkType)
)
if cache_dir:
torch.hub.set_dir(cache_dir)
# raise a warning that this may change the default cache dir for all torch.hub calls
warnings.warn(
f"Setting cache_dir to {cache_dir}, this may change the default cache dir for all torch.hub calls."
)
self.spatial_dims = spatial_dims
self.perceptual_function: nn.Module
if spatial_dims == 3 and is_fake_3d is False:
self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
elif "radimagenet_" in network_type:
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
elif network_type == "resnet50":
self.perceptual_function = TorchvisionModelPerceptualSimilarity(
net=network_type,
pretrained=pretrained,
pretrained_path=pretrained_path,
pretrained_state_dict_key=pretrained_state_dict_key,
)
else:
self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False)
self.is_fake_3d = is_fake_3d
self.fake_3d_ratio = fake_3d_ratio
def _calculate_axis_loss(self, input: torch.Tensor, target: torch.Tensor, spatial_axis: int) -> torch.Tensor:
"""
Calculate perceptual loss in one of the axis used in the 2.5D approach. After the slices of one spatial axis
is transformed into different instances in the batch, we compute the loss using the 2D approach.
Args:
input: input 5D tensor. BNHWD
target: target 5D tensor. BNHWD
spatial_axis: spatial axis to obtain the 2D slices.
"""
def batchify_axis(x: torch.Tensor, fake_3d_perm: tuple) -> torch.Tensor:
"""
Transform slices from one spatial axis into different instances in the batch.
"""
slices = x.float().permute((0,) + fake_3d_perm).contiguous()
slices = slices.view(-1, x.shape[fake_3d_perm[1]], x.shape[fake_3d_perm[2]], x.shape[fake_3d_perm[3]])
return slices
preserved_axes = [2, 3, 4]
preserved_axes.remove(spatial_axis)
channel_axis = 1
input_slices = batchify_axis(x=input, fake_3d_perm=(spatial_axis, channel_axis) + tuple(preserved_axes))
indices = torch.randperm(input_slices.shape[0])[: int(input_slices.shape[0] * self.fake_3d_ratio)].to(
input_slices.device
)
input_slices = torch.index_select(input_slices, dim=0, index=indices)
target_slices = batchify_axis(x=target, fake_3d_perm=(spatial_axis, channel_axis) + tuple(preserved_axes))
target_slices = torch.index_select(target_slices, dim=0, index=indices)
axis_loss = torch.mean(self.perceptual_function(input_slices, target_slices))
return axis_loss
[docs]
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Args:
input: the shape should be BNHW[D].
target: the shape should be BNHW[D].
"""
if target.shape != input.shape:
raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")
if self.spatial_dims == 3 and self.is_fake_3d:
# Compute 2.5D approach
loss_sagittal = self._calculate_axis_loss(input, target, spatial_axis=2)
loss_coronal = self._calculate_axis_loss(input, target, spatial_axis=3)
loss_axial = self._calculate_axis_loss(input, target, spatial_axis=4)
loss = loss_sagittal + loss_axial + loss_coronal
else:
# 2D and real 3D cases
loss = self.perceptual_function(input, target)
return torch.mean(loss)
class MedicalNetPerceptualSimilarity(nn.Module):
"""
Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer
Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from
"Warvito/MedicalNet-models".
Args:
net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``.
verbose: if false, mute messages from torch Hub load function.
"""
def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None:
super().__init__()
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose)
self.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute perceptual loss using MedicalNet 3D networks. The input and target tensors are inputted in the
pre-trained MedicalNet that is used for feature extraction. Then, these extracted features are normalised across
the channels. Finally, we compute the difference between the input and target features and calculate the mean
value from the spatial dimensions to obtain the perceptual loss.
Args:
input: 3D input tensor with shape BCDHW.
target: 3D target tensor with shape BCDHW.
"""
input = medicalnet_intensity_normalisation(input)
target = medicalnet_intensity_normalisation(target)
# Get model outputs
outs_input = self.model.forward(input)
outs_target = self.model.forward(target)
# Normalise through the channels
feats_input = normalize_tensor(outs_input)
feats_target = normalize_tensor(outs_target)
results: torch.Tensor = (feats_input - feats_target) ** 2
results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True)
return results
def spatial_average_3d(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
return x.mean([2, 3, 4], keepdim=keepdim)
def normalize_tensor(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
return x / (norm_factor + eps)
def medicalnet_intensity_normalisation(volume):
"""Based on https://github.com/Tencent/MedicalNet/blob/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/datasets/brains18.py#L133"""
mean = volume.mean()
std = volume.std()
return (volume - mean) / std
class RadImageNetPerceptualSimilarity(nn.Module):
"""
Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et
al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class
uses torch Hub to download the networks from "Warvito/radimagenet-models".
Args:
net: {``"radimagenet_resnet50"``}
Specifies the network architecture to use. Defaults to ``"radimagenet_resnet50"``.
verbose: if false, mute messages from torch Hub load function.
"""
def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None:
super().__init__()
self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose)
self.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at
https://github.com/BMEII-AI/RadImageNet, we make sure that the input and target have 3 channels, reorder it from
'RGB' to 'BGR', and then remove the mean components of each input data channel. The outputs are normalised
across the channels, and we obtain the mean from the spatial dimensions (similar approach to the lpips package).
"""
# If input has just 1 channel, repeat channel to have 3 channels
if input.shape[1] == 1 and target.shape[1] == 1:
input = input.repeat(1, 3, 1, 1)
target = target.repeat(1, 3, 1, 1)
# Change order from 'RGB' to 'BGR'
input = input[:, [2, 1, 0], ...]
target = target[:, [2, 1, 0], ...]
# Subtract mean used during training
input = subtract_mean(input)
target = subtract_mean(target)
# Get model outputs
outs_input = self.model.forward(input)
outs_target = self.model.forward(target)
# Normalise through the channels
feats_input = normalize_tensor(outs_input)
feats_target = normalize_tensor(outs_target)
results: torch.Tensor = (feats_input - feats_target) ** 2
results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)
return results
class TorchvisionModelPerceptualSimilarity(nn.Module):
"""
Component to perform the perceptual evaluation with TorchVision models.
Currently, only ResNet50 is supported. The network structure is based on:
https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html
Args:
net: {``"resnet50"``}
Specifies the network architecture to use. Defaults to ``"resnet50"``.
pretrained: whether to load pretrained weights. Defaults to `True`.
pretrained_path: if `pretrained` is `True`, users can specify a weights file to be loaded
via using this argument. Defaults to `None`.
pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to
extract the expected state dict. Defaults to `None`.
"""
def __init__(
self,
net: str = "resnet50",
pretrained: bool = True,
pretrained_path: str | None = None,
pretrained_state_dict_key: str | None = None,
) -> None:
super().__init__()
supported_networks = ["resnet50"]
if net not in supported_networks:
raise NotImplementedError(
f"'net' {net} is not supported, please select a network from {supported_networks}."
)
if pretrained_path is None:
network = torchvision.models.resnet50(
weights=torchvision.models.ResNet50_Weights.DEFAULT if pretrained else None
)
else:
network = torchvision.models.resnet50(weights=None)
if pretrained is True:
state_dict = torch.load(pretrained_path)
if pretrained_state_dict_key is not None:
state_dict = state_dict[pretrained_state_dict_key]
network.load_state_dict(state_dict)
self.final_layer = "layer4.2.relu_2"
self.model = torchvision.models.feature_extraction.create_feature_extractor(network, [self.final_layer])
self.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at
https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights,
we make sure that the input and target have 3 channels, and then do Z-Score normalization.
The outputs are normalised across the channels, and we obtain the mean from the spatial dimensions (similar
approach to the lpips package).
"""
# If input has just 1 channel, repeat channel to have 3 channels
if input.shape[1] == 1 and target.shape[1] == 1:
input = input.repeat(1, 3, 1, 1)
target = target.repeat(1, 3, 1, 1)
# Input normalization
input = torchvision_zscore_norm(input)
target = torchvision_zscore_norm(target)
# Get model outputs
outs_input = self.model.forward(input)[self.final_layer]
outs_target = self.model.forward(target)[self.final_layer]
# Normalise through the channels
feats_input = normalize_tensor(outs_input)
feats_target = normalize_tensor(outs_target)
results: torch.Tensor = (feats_input - feats_target) ** 2
results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)
return results
def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
return x.mean([2, 3], keepdim=keepdim)
def torchvision_zscore_norm(x: torch.Tensor) -> torch.Tensor:
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
x[:, 0, :, :] = (x[:, 0, :, :] - mean[0]) / std[0]
x[:, 1, :, :] = (x[:, 1, :, :] - mean[1]) / std[1]
x[:, 2, :, :] = (x[:, 2, :, :] - mean[2]) / std[2]
return x
def subtract_mean(x: torch.Tensor) -> torch.Tensor:
mean = [0.406, 0.456, 0.485]
x[:, 0, :, :] -= mean[0]
x[:, 1, :, :] -= mean[1]
x[:, 2, :, :] -= mean[2]
return x