# Source code for monai.losses.spatial_mask

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
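# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.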
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import inspect
import warnings
from collections.abc import Callable
from typing import Any, Optional

import torch
from torch.nn.modules.loss import _Loss

# Public names exported by a wildcard import of this module.
__all__ = [
    "MaskedLoss",
]

class MaskedLoss(_Loss):
    """
    This is a wrapper class for the loss functions. It allows for additional
    weighting masks to be applied to both input and target.

    The mask is applied by elementwise multiplication of ``input`` and ``target``
    before both are handed to the wrapped loss. Masked-out elements are therefore
    zeroed rather than removed, so they still count toward any averaging that the
    wrapped loss performs.

    See Also:
        - :py:class:`monai.losses.MaskedDiceLoss`
    """

    def __init__(
        self, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | _Loss, *loss_args: Any, **loss_kwargs: Any
    ) -> None:
        """
        Args:
            loss: loss function to be wrapped, this could be a loss class, an instance of a loss class,
                or any callable taking ``(input, target)``.
            loss_args: arguments to the loss function's constructor if `loss` is a class.
            loss_kwargs: keyword arguments to the loss function's constructor if `loss` is a class.

        Raises:
            ValueError: When the resulting loss object is not callable.
        """
        super().__init__()
        # A class is instantiated with the extra arguments; anything else (an instance
        # or a plain function) is used as-is and the extra arguments are ignored.
        self.loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = (
            loss(*loss_args, **loss_kwargs) if inspect.isclass(loss) else loss
        )
        if not callable(self.loss):
            raise ValueError("The loss function is not callable.")

    def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD].
            target: the shape should be BNH[WD].
            mask: the shape should be B1H[WD] or 11H[WD].

        Returns:
            The wrapped loss evaluated on ``input * mask`` and ``target * mask``, or on the
            unmasked ``input`` and ``target`` when ``mask`` is None (a warning is emitted).

        Raises:
            ValueError: When the mask batch size is neither one nor equal to the input batch size.
            ValueError: When ``target`` has more than one dimension and ``mask`` has no channel
                dimension or more than one channel.
        """
        if mask is None:
            warnings.warn("No mask value specified for the MaskedLoss.")
            return self.loss(input, target)

        # A rank mismatch only warns; the multiplication below relies on broadcasting.
        if input.dim() != mask.dim():
            warnings.warn(f"Dim of input ({input.shape}) is different from mask ({mask.shape}).")
        if input.shape[0] != mask.shape[0] and mask.shape[0] != 1:
            raise ValueError(f"Batch size of mask ({mask.shape}) must be one or equal to input ({input.shape}).")
        if target.dim() > 1:
            # Check the rank before reading ``mask.shape[1]`` so that a mask without a
            # channel axis fails with a clear ValueError instead of an IndexError.
            if mask.dim() < 2:
                raise ValueError(f"Mask ({mask.shape}) must have a channel dimension, e.g. B1H[WD].")
            if mask.shape[1] != 1:
                raise ValueError(f"Mask ({mask.shape}) must have only one channel.")
            if input.shape[2:] != mask.shape[2:]:
                warnings.warn(f"Spatial size of input ({input.shape}) is different from mask ({mask.shape}).")
        return self.loss(input * mask, target * mask)