Source code for monai.losses.focal_loss

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from collections.abc import Sequence
from typing import Optional

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from monai.networks import one_hot
from monai.utils import LossReduction


class FocalLoss(_Loss):
    """
    FocalLoss is an extension of BCEWithLogitsLoss that down-weights loss from
    high confidence correct predictions.

    Reimplementation of the Focal Loss described in:

        - ["Focal Loss for Dense Object Detection"](https://arxiv.org/abs/1708.02002), T. Lin et al., ICCV 2017
        - "AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy",
          Zhu et al., Medical Physics 2018

    Example:
        >>> import torch
        >>> from monai.losses import FocalLoss
        >>> from torch.nn import BCEWithLogitsLoss
        >>> shape = B, N, *DIMS = 2, 3, 5, 7, 11
        >>> input = torch.rand(*shape)
        >>> target = torch.rand(*shape)
        >>> # Demonstrate equivalence to BCE when gamma=0
        >>> fl_g0_criterion = FocalLoss(reduction='none', gamma=0)
        >>> fl_g0_loss = fl_g0_criterion(input, target)
        >>> bce_criterion = BCEWithLogitsLoss(reduction='none')
        >>> bce_loss = bce_criterion(input, target)
        >>> assert torch.allclose(fl_g0_loss, bce_loss)
        >>> # Demonstrate "focus" by setting gamma > 0.
        >>> fl_g2_criterion = FocalLoss(reduction='none', gamma=2)
        >>> fl_g2_loss = fl_g2_criterion(input, target)
        >>> # Mark easy and hard cases
        >>> is_easy = (target > 0.7) & (input > 0.7)
        >>> is_hard = (target > 0.7) & (input < 0.3)
        >>> easy_loss_g0 = fl_g0_loss[is_easy].mean()
        >>> hard_loss_g0 = fl_g0_loss[is_hard].mean()
        >>> easy_loss_g2 = fl_g2_loss[is_easy].mean()
        >>> hard_loss_g2 = fl_g2_loss[is_hard].mean()
        >>> # Gamma > 0 causes the loss function to "focus" on the hard
        >>> # cases.  IE, easy cases are downweighted, so hard cases
        >>> # receive a higher proportion of the loss.
        >>> hard_to_easy_ratio_g2 = hard_loss_g2 / easy_loss_g2
        >>> hard_to_easy_ratio_g0 = hard_loss_g0 / easy_loss_g0
        >>> assert hard_to_easy_ratio_g2 > hard_to_easy_ratio_g0
    """

    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        gamma: float = 2.0,
        alpha: float | None = None,
        weight: Sequence[float] | float | int | torch.Tensor | None = None,
        reduction: LossReduction | str = LossReduction.MEAN,
        use_softmax: bool = False,
    ) -> None:
        """
        Args:
            include_background: if False, channel index 0 (background category) is excluded from the loss
                calculation. If False, `alpha` is invalid when using softmax.
            to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
            alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
                The value should be in [0, 1]. Defaults to None.
            weight: weights to apply to the voxels of each class. If None no weights are applied.
                The input can be a single value (same weight for all classes), a sequence of values (the length
                of the sequence should be the same as the number of classes. If not ``include_background``,
                the number of classes should not include the background category class 0).
                The value/values should be no less than 0. Defaults to None.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

            use_softmax: whether to use softmax to transform the original logits into probabilities.
                If True, softmax is used. If False, sigmoid is used. Defaults to False.

        Example:
            >>> import torch
            >>> from monai.losses import FocalLoss
            >>> pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32)
            >>> grnd = torch.tensor([[0], [1], [0]], dtype=torch.int64)
            >>> fl = FocalLoss(to_onehot_y=True)
            >>> fl(pred, grnd)
        """
        super().__init__(reduction=LossReduction(reduction).value)
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.gamma = gamma
        self.alpha = alpha
        self.weight = weight
        self.use_softmax = use_softmax
        # Registered so the module keeps exposing a `class_weight` buffer; `forward` derives the
        # per-call weights from `self.weight` and does not modify this buffer.
        self.register_buffer("class_weight", torch.ones(1))

    def _resolve_class_weight(self, num_of_classes: int, loss: torch.Tensor) -> torch.Tensor:
        """
        Build the per-class weight tensor from ``self.weight``, shaped to broadcast against ``loss``.

        Args:
            num_of_classes: number of channels of the (possibly background-stripped) target.
            loss: the unreduced loss; supplies the dtype/device and the number of trailing dims.

        Returns:
            A tensor viewed as ``(-1, 1, ..., 1)`` with one trailing singleton per spatial dim of ``loss``.

        Raises:
            ValueError: When the weight length differs from ``num_of_classes``.
            ValueError: When any weight value is less than 0.
        """
        if isinstance(self.weight, (float, int)):
            class_weight = torch.as_tensor([self.weight] * num_of_classes)
        else:
            class_weight = torch.as_tensor(self.weight)
            if class_weight.ndim == 0:
                # a 0-dim tensor is a single value: use the same weight for every class
                class_weight = class_weight.repeat(num_of_classes)
            elif class_weight.shape[0] != num_of_classes:
                raise ValueError(
                    "the length of the `weight` sequence should be the same as the number of classes. "
                    "If `include_background=False`, the weight should not include the background "
                    "category class 0."
                )
        if class_weight.min() < 0:
            raise ValueError("the value/values of the `weight` should be no less than 0.")
        # match dtype/device of the loss, then align the class axis with dim 1 of the loss
        class_weight = class_weight.to(loss)
        return class_weight.view([-1] + [1] * (loss.dim() - 2))

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD], where N is the number of classes.
                The input should be the original logits since it will be transformed by
                a sigmoid/softmax in the forward function.
            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

        Raises:
            ValueError: When input and target (after one hot transform if set)
                have different shapes.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
            ValueError: When ``self.weight`` is a sequence and the length is not equal to the
                number of classes.
            ValueError: When ``self.weight`` is/contains a value that is less than 0.
        """
        n_pred_ch = input.shape[1]

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

        input = input.float()
        target = target.float()

        if self.use_softmax:
            # Use a local copy so the configured `self.alpha` is not lost after the first call.
            alpha = self.alpha
            if not self.include_background and alpha is not None:
                alpha = None
                warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
            loss = softmax_focal_loss(input, target, self.gamma, alpha)
        else:
            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)

        if self.weight is not None:
            # apply the per-class weights to the loss
            loss = self._resolve_class_weight(target.shape[1], loss) * loss

        if self.reduction == LossReduction.SUM.value:
            # Previously there was a mean over the last dimension, which did not
            # return a compatible BCE loss. The mean over the spatial dims is kept for
            # backwards compatible behavior; disable or parameterize if necessary.
            spatial_dims = list(range(2, target.dim()))
            # Skip the mean when there are no spatial dims: an empty `dim` list would
            # otherwise be treated by `Tensor.mean` as "reduce over every dimension".
            if spatial_dims:
                loss = loss.mean(dim=spatial_dims)
            loss = loss.sum()
        elif self.reduction == LossReduction.MEAN.value:
            loss = loss.mean()
        elif self.reduction == LossReduction.NONE.value:
            pass
        else:
            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
        return loss
def softmax_focal_loss(
    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
) -> torch.Tensor:
    """
    Element-wise focal loss on softmax probabilities (no reduction applied).

    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)

    where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and
    s_j is the unnormalized score for class j.

    Args:
        input: logits; the softmax is taken over dim 1.
        target: tensor multiplied element-wise into the loss (same shape as ``input``).
        gamma: exponent of the modulating factor ``(1 - p)``.
        alpha: optional balancing value; channel 0 is scaled by ``1 - alpha`` and
            every other channel by ``alpha``.

    Returns:
        The unreduced loss, with the same shape as ``input``.
    """
    log_probs = torch.log_softmax(input, 1)
    modulating = (1 - log_probs.exp()).pow(gamma)
    focal: torch.Tensor = -modulating * log_probs * target

    if alpha is None:
        return focal

    # (1-alpha) for the background class and alpha for the other classes
    n_foreground = target.shape[1] - 1
    per_class = torch.tensor([1 - alpha] + [alpha] * n_foreground).to(focal)
    # put the class axis on dim 1 and add one singleton per trailing (spatial) dim
    view_shape = [-1] + [1] * len(target.shape[2:])
    return per_class.view(view_shape) * focal


def sigmoid_focal_loss(
    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
) -> torch.Tensor:
    """
    Element-wise focal loss on sigmoid probabilities (no reduction applied).

    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)

    where p = sigmoid(x), pt = p if label is 1 or 1 - p if label is 0

    Args:
        input: logits.
        target: labels, same shape as ``input``.
        gamma: exponent of the modulating factor ``(1 - pt)``.
        alpha: optional balancing value; elements are scaled by
            ``target * alpha + (1 - target) * (1 - alpha)``.

    Returns:
        The unreduced loss, with the same shape as ``input``.
    """
    # Binary cross entropy with logits, written out by hand; equivalent to
    # F.binary_cross_entropy_with_logits(input, target, reduction='none')
    # see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
    negated = -input
    shift = negated.clamp(min=0)
    log_sum = ((-shift).exp() + (negated - shift).exp()).log()
    bce: torch.Tensor = input - input * target + shift + log_sum

    # log of the (1 - pt) term:
    #   sigmoid(-x) = 1 - p  when t == 1
    #   sigmoid(x)  = p      when t == 0
    signed_target = target * 2 - 1
    log_one_minus_pt = F.logsigmoid(negated * signed_target)

    # exp(gamma * log(1 - pt)) == (1 - pt) ** gamma, with a reduced chance of overflow
    focal = (log_one_minus_pt * gamma).exp() * bce

    if alpha is None:
        return focal

    # alpha if t==1; (1-alpha) if t==0
    balance = target * alpha + (1 - target) * (1 - alpha)
    return balance * focal