Source code for monai.losses.ssim_loss

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.nn.modules.loss import _Loss

from monai.metrics.regression import SSIMMetric


class SSIMLoss(_Loss):
    """
    Build a Pytorch version of the SSIM loss function based on the original formula of SSIM

    Modified and adopted from:
        https://github.com/facebookresearch/fastMRI/blob/main/banding_removal/fastmri/ssim_loss_mixin.py

    For more info, visit
        https://vicuesoft.com/glossary/term/ssim-ms-ssim/

    SSIM reference paper:
        Wang, Zhou, et al. "Image quality assessment: from error visibility to structural
        similarity." IEEE transactions on image processing 13.4 (2004): 600-612.
    """

    def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03, spatial_dims: int = 2):
        """
        Args:
            win_size: gaussian weighting window size
            k1: stability constant used in the luminance denominator
            k2: stability constant used in the contrast denominator
            spatial_dims: if 2, input shape is expected to be (B,C,H,W). if 3, it is expected to be (B,C,H,W,D)
        """
        super().__init__()
        self.win_size = win_size
        self.k1, self.k2 = k1, k2
        self.spatial_dims = spatial_dims

    def forward(self, x: torch.Tensor, y: torch.Tensor, data_range: torch.Tensor) -> torch.Tensor:
        """
        Compute ``1 - SSIM`` between ``x`` and ``y``, averaged over the batch.

        Args:
            x: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D and pseudo-3D data,
                and (B,C,W,H,D) for 3D data,
            y: second sample (e.g., the reconstructed image). It has similar shape as x.
            data_range: dynamic range of the data. The same value is forwarded unchanged to the
                metric for every sample of the batch.

        Returns:
            1-ssim_value (recall this is meant to be a loss function), as a scalar tensor.

        Raises:
            ValueError: when the batch dimension of ``x`` is empty.

        Example:
            .. code-block:: python

                import torch

                # 2D data
                x = torch.ones([1,1,10,10])/2
                y = torch.ones([1,1,10,10])/2
                data_range = x.max().unsqueeze(0)
                # the following line should print 1.0 (or 0.9999)
                print(1-SSIMLoss(spatial_dims=2)(x,y,data_range))

                # pseudo-3D data
                x = torch.ones([1,5,10,10])/2  # 5 could represent number of slices
                y = torch.ones([1,5,10,10])/2
                data_range = x.max().unsqueeze(0)
                # the following line should print 1.0 (or 0.9999)
                print(1-SSIMLoss(spatial_dims=2)(x,y,data_range))

                # 3D data
                x = torch.ones([1,1,10,10,10])/2
                y = torch.ones([1,1,10,10,10])/2
                data_range = x.max().unsqueeze(0)
                # the following line should print 1.0 (or 0.9999)
                print(1-SSIMLoss(spatial_dims=3)(x,y,data_range))
        """
        batch_size = x.shape[0]
        if batch_size < 1:
            raise ValueError("Batch size is not nonnegative integer value")

        # The metric is evaluated one sample at a time (slices of batch size 1) and the
        # results are gathered in a list. The previous implementation grew the result
        # with ``torch.cat((ssim_value.view(1), ...))``, which raised as soon as the
        # accumulator held more than one element, i.e. for every batch size >= 3.
        # Each result is flattened so it can be concatenated regardless of how many
        # values the metric returns per sample.
        # NOTE(review): presumably one value per sample (or per channel) - confirm against SSIMMetric.
        per_sample = [
            SSIMMetric(data_range, self.win_size, self.k1, self.k2, self.spatial_dims)
            ._compute_tensor(x[i : i + 1], y[i : i + 1])
            .reshape(-1)
            for i in range(batch_size)
        ]
        ssim_value: torch.Tensor = torch.cat(per_sample, dim=0)

        loss: torch.Tensor = 1 - ssim_value.mean()
        return loss