# Source code for monai.losses.adversarial_loss

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings

import torch
from torch.nn.modules.loss import _Loss

from monai.networks.layers.utils import get_act_layer
from monai.utils import LossReduction
from monai.utils.enums import StrEnum

class AdversarialCriterions(StrEnum):
    """
    Names of the criterions accepted by ``PatchAdversarialLoss``.

    Members are strings, so they compare equal to their lower-case value
    (e.g. ``AdversarialCriterions.BCE == "bce"``).
    """

    # Binary cross-entropy: PatchAdversarialLoss pairs it with a sigmoid activation and ``torch.nn.BCELoss``.
    BCE = "bce"
    # Hinge loss: PatchAdversarialLoss pairs it with a tanh activation and uses -1 as the fake label.
    HINGE = "hinge"
    # Least squares: PatchAdversarialLoss pairs it with ``torch.nn.MSELoss`` and an optional leaky-ReLU activation.
    LEAST_SQUARE = "least_squares"

class PatchAdversarialLoss(_Loss):
    """
    Calculates an adversarial loss on a Patch Discriminator or a Multi-scale Patch Discriminator.

    Warning: due to the possibility of using different criterions, the output of the discriminator
    mustn't be passed to a final activation layer. That is taken care of internally within the loss.

    Args:
        reduction: {``"none"``, ``"mean"``, ``"sum"``}
            Specifies the reduction to apply to the output. Defaults to ``"mean"``.

            - ``"none"``: no reduction will be applied.
            - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
            - ``"sum"``: the output will be summed.

        criterion: which criterion (hinge, least_squares or bce) you want to use on the discriminators
            outputs. Matching is case-insensitive. Depending on the criterion, a different activation
            layer will be used. Make sure you don't run the outputs through an activation layer prior
            to calling the loss.
        no_activation_leastsq: if True, the activation layer in the case of least-squares is removed.

    Raises:
        ValueError: if ``criterion`` is not one of the values of ``AdversarialCriterions``.
    """

    def __init__(
        self,
        reduction: LossReduction | str = LossReduction.MEAN,
        criterion: str = AdversarialCriterions.LEAST_SQUARE,
        no_activation_leastsq: bool = False,
    ) -> None:
        super().__init__(reduction=LossReduction(reduction))

        normalised_criterion = criterion.lower()
        if normalised_criterion not in list(AdversarialCriterions):
            raise ValueError(
                "Unrecognised criterion entered for Adversarial Loss. Must be one in: %s"
                % ", ".join(AdversarialCriterions)
            )
        # Keep the normalised member from here on: the validation above is case-insensitive, so the
        # branch selection below (and the checks in ``forward``) must be too. Otherwise e.g. "BCE"
        # would pass validation but leave ``activation`` / ``loss_fct`` undefined.
        criterion = AdversarialCriterions(normalised_criterion)

        # Depending on the criterion, a different activation layer is used.
        self.real_label = 1.0
        self.fake_label = 0.0
        self.loss_fct: _Loss
        if criterion == AdversarialCriterions.BCE:
            self.activation = get_act_layer("SIGMOID")
            self.loss_fct = torch.nn.BCELoss(reduction=reduction)
        elif criterion == AdversarialCriterions.HINGE:
            # The hinge loss is computed by hand in ``_forward_single``, so no ``loss_fct`` is set.
            self.activation = get_act_layer("TANH")
            self.fake_label = -1.0
        elif criterion == AdversarialCriterions.LEAST_SQUARE:
            if no_activation_leastsq:
                self.activation = None
            else:
                self.activation = get_act_layer(name=("LEAKYRELU", {"negative_slope": 0.05}))
            self.loss_fct = torch.nn.MSELoss(reduction=reduction)

        self.criterion = criterion
        self.reduction = reduction

    def get_target_tensor(self, input: torch.Tensor, target_is_real: bool) -> torch.Tensor:
        """
        Gets the ground truth tensor for the discriminator depending on whether the input is real or fake.

        Args:
            input: input tensor from the discriminator (output of discriminator, or output of one of the
                multi-scale discriminator). This is used to match the shape, dtype and device.
            target_is_real: whether the input is real or wannabe-real (filled with ``self.real_label``)
                or fake (filled with ``self.fake_label``).

        Returns:
            a tensor without gradient, filled with the selected label and expanded (as a view of a
            single element) to the shape of ``input``.
        """
        filling_label = self.real_label if target_is_real else self.fake_label
        # Build the scalar directly with the input's dtype/device instead of indexing ``input[0]``,
        # which fails on 0-dim or empty tensors.
        label_tensor = torch.tensor(filling_label, dtype=input.dtype, device=input.device)
        return label_tensor.expand_as(input)

    def get_zero_tensor(self, input: torch.Tensor) -> torch.Tensor:
        """
        Gets a zero tensor.

        Args:
            input: tensor which shape, dtype and device you want the zeros tensor to correspond to.

        Returns:
            a tensor of zeros without gradient, expanded (as a view of a single element) to the
            shape of ``input``.
        """
        zero_label_tensor = torch.tensor(0, dtype=input.dtype, device=input.device)
        return zero_label_tensor.expand_as(input)

    def forward(
        self, input: torch.Tensor | list, target_is_real: bool, for_discriminator: bool
    ) -> torch.Tensor | list[torch.Tensor]:
        """
        Args:
            input: output of Multi-Scale Patch Discriminator or Patch Discriminator; being a list (or tuple)
                of tensors or a tensor; they shouldn't have gone through an activation layer.
            target_is_real: whether the input corresponds to discriminator output for real or fake images.
            for_discriminator: whether this is being calculated for discriminator or generator loss. In
                the last case, target_is_real is set to True, as the generator wants the input to be
                deemed as real.

        Returns:
            if reduction is ``"none"``, a list with one loss tensor per discriminator output (a
            one-element list when a single tensor was passed). Otherwise, the summed or mean loss
            over the tensor and discriminator/s. Note that for the hinge criterion each
            per-discriminator loss is always a scalar mean.
        """
        if not for_discriminator and not target_is_real:
            target_is_real = True  # With generator, we always want this to be true!
            warnings.warn(
                "Variable target_is_real has been set to False, but for_discriminator is set "
                "to False. To optimise a generator, target_is_real must be set to True."
            )

        if not isinstance(input, (list, tuple)):
            input = [input]

        is_hinge = self.criterion == AdversarialCriterions.HINGE
        loss_list = []
        for disc_out in input:
            # Hinge compares against zero; the other criterions regress towards the real/fake label.
            if is_hinge:
                target = self.get_zero_tensor(disc_out)
            else:
                target = self.get_target_tensor(disc_out, target_is_real)

            if self.activation is not None:
                disc_out = self.activation(disc_out)
            # For fake samples the hinge term is min(-D(x) - 1, 0), hence the sign flip.
            if is_hinge and not target_is_real:
                disc_out = -disc_out
            loss_list.append(self._forward_single(disc_out, target))

        if self.reduction == LossReduction.MEAN:
            return torch.mean(torch.stack(loss_list))
        if self.reduction == LossReduction.SUM:
            return torch.sum(torch.stack(loss_list))
        return loss_list

    def _forward_single(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss for the (already activated) output of a single discriminator.

        Args:
            input: activated discriminator output (sign-flipped by ``forward`` for fake hinge samples).
            target: target tensor with the shape of ``input``; unused by the hinge criterion.

        Returns:
            ``-mean(min(input - 1, 0))`` for the hinge criterion, otherwise the value of
            ``self.loss_fct(input, target)``.
        """
        if self.criterion == AdversarialCriterions.HINGE:
            minval = torch.min(input - 1, self.get_zero_tensor(input))
            return -torch.mean(minval)
        # ``__init__`` only admits BCE and least squares besides hinge, and sets ``loss_fct`` for both.
        return self.loss_fct(input, target)