# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from torch.nn.modules.loss import _Loss
from monai.losses import DiceLoss
from monai.transforms import SobelGradients
from monai.utils.enums import HoVerNetBranch
class HoVerNetLoss(_Loss):
    """
    Loss function for HoVerNet pipeline, which is combination of losses across the three branches.
    The NP (nucleus prediction) branch uses Dice + CrossEntropy.
    The HV (Horizontal and Vertical) distance from centroid branch uses MSE + MSE of the gradient.
    The NC (Nuclear Class prediction) branch uses Dice + CrossEntropy
    The result is a weighted sum of these losses.

    Args:
        lambda_hv_mse: Weight factor to apply to the HV regression MSE part of the overall loss
        lambda_hv_mse_grad: Weight factor to apply to the MSE of the HV gradient part of the overall loss
        lambda_np_ce: Weight factor to apply to the nuclei prediction CrossEntropyLoss part
            of the overall loss
        lambda_np_dice: Weight factor to apply to the nuclei prediction DiceLoss part of overall loss
        lambda_nc_ce: Weight factor to apply to the nuclei class prediction CrossEntropyLoss part
            of the overall loss
        lambda_nc_dice: Weight factor to apply to the nuclei class prediction DiceLoss part of the
            overall loss

    """

    def __init__(
        self,
        lambda_hv_mse: float = 2.0,
        lambda_hv_mse_grad: float = 1.0,
        lambda_np_ce: float = 1.0,
        lambda_np_dice: float = 1.0,
        lambda_nc_ce: float = 1.0,
        lambda_nc_dice: float = 1.0,
    ) -> None:
        # Initialise the base module first so that attribute assignment below
        # happens on a fully set-up ``torch.nn.Module``.
        super().__init__()

        self.lambda_hv_mse = lambda_hv_mse
        self.lambda_hv_mse_grad = lambda_hv_mse_grad
        self.lambda_np_ce = lambda_np_ce
        self.lambda_np_dice = lambda_np_dice
        self.lambda_nc_ce = lambda_nc_ce
        self.lambda_nc_dice = lambda_nc_dice

        # The same Dice / CrossEntropy instances are shared by the NP and NC branches.
        self.dice = DiceLoss(softmax=True, smooth_dr=1e-03, smooth_nr=1e-03, reduction="sum", batch=True)
        self.ce = CrossEntropyLoss(reduction="mean")
        # Sobel operators used for the gradient term of the HV branch.
        self.sobel_v = SobelGradients(kernel_size=5, spatial_axes=0)
        self.sobel_h = SobelGradients(kernel_size=5, spatial_axes=1)

    def _compute_sobel(self, image: torch.Tensor) -> torch.Tensor:
        """Compute the Sobel gradients of the horizontal vertical map (HoVerMap).
        More specifically, it will compute horizontal gradient of the input horizontal gradient map (channel=0) and
        vertical gradient of the input vertical gradient map (channel=1).

        Args:
            image: a tensor with the shape of BxCxHxW representing HoVerMap

        Returns:
            The two gradient maps stacked along dimension 1 (horizontal first, then vertical).

        """
        result_h = self.sobel_h(image[:, 0])
        result_v = self.sobel_v(image[:, 1])
        return torch.stack([result_h, result_v], dim=1)

    def _mse_gradient_loss(self, prediction: torch.Tensor, target: torch.Tensor, focus: torch.Tensor) -> torch.Tensor:
        """Compute the MSE loss of the gradients of the horizontal and vertical centroid distance maps

        Args:
            prediction: predicted HoVerMap, passed to ``_compute_sobel``.
            target: ground truth HoVerMap, passed to ``_compute_sobel``.
            focus: mask without a channel dimension; it is duplicated to cover both gradient channels
                and used to weight the squared gradient error.

        Returns:
            Sum of the masked squared gradient errors divided by the sum of the mask.

        """
        pred_grad = self._compute_sobel(prediction)
        true_grad = self._compute_sobel(target)
        loss = pred_grad - true_grad

        # The focus constrains the loss computation to the detected nuclear regions
        # (i.e. background is excluded)
        focus = focus[:, None, ...]
        focus = torch.cat((focus, focus), 1)

        loss = focus * (loss * loss)
        # The small epsilon avoids a division by zero when the mask is empty.
        loss = loss.sum() / (focus.sum() + 1.0e-8)

        return loss

    def forward(self, prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Args:
            prediction: dictionary of predicted outputs for three branches,
                each of which should have the shape of BNHW.
            target: dictionary of ground truths for three branches,
                each of which should have the shape of BNHW.

        Returns:
            The weighted sum of the NP, HV and (when present) NC branch losses.

        Raises:
            ValueError: when the NP or HV branch is missing from ``prediction`` or ``target``.
            ValueError: when the NC branch is present in only one of ``prediction`` and ``target``.

        """
        np_key = HoVerNetBranch.NP.value
        hv_key = HoVerNetBranch.HV.value
        nc_key = HoVerNetBranch.NC.value

        if not (np_key in prediction and hv_key in prediction):
            raise ValueError(
                "nucleus prediction (NP) and horizontal_vertical (HV) branches must be "
                "present for prediction and target parameters"
            )
        if not (np_key in target and hv_key in target):
            raise ValueError(
                "nucleus prediction (NP) and horizontal_vertical (HV) branches must be "
                "present for prediction and target parameters"
            )

        # The NC branch is optional, but it has to be supplied consistently:
        # compare ``prediction`` against ``target`` (not ``target`` against itself).
        has_nc = nc_key in prediction
        if has_nc != (nc_key in target):
            raise ValueError(
                "type_prediction (NC) must be present in both or neither of the prediction and target parameters"
            )

        # Compute the NP branch loss
        dice_loss_np = self.dice(prediction[np_key], target[np_key]) * self.lambda_np_dice
        # convert to target class indices
        argmax_target = target[np_key].argmax(dim=1)
        ce_loss_np = self.ce(prediction[np_key], argmax_target) * self.lambda_np_ce
        loss_np = dice_loss_np + ce_loss_np

        # Compute the HV branch loss
        loss_hv_mse = F.mse_loss(prediction[hv_key], target[hv_key]) * self.lambda_hv_mse

        # Use the nuclei class, one hot encoded, as the mask
        loss_hv_mse_grad = (
            self._mse_gradient_loss(prediction[hv_key], target[hv_key], target[np_key][:, 1])
            * self.lambda_hv_mse_grad
        )
        loss_hv = loss_hv_mse_grad + loss_hv_mse

        # Compute the NC branch loss
        loss_nc = 0
        if has_nc:
            dice_loss_nc = self.dice(prediction[nc_key], target[nc_key]) * self.lambda_nc_dice
            # Convert to target class indices
            argmax_target = target[nc_key].argmax(dim=1)
            ce_loss_nc = self.ce(prediction[nc_key], argmax_target) * self.lambda_nc_ce
            loss_nc = dice_loss_nc + ce_loss_nc

        # Sum the losses from each branch
        loss: torch.Tensor = loss_hv + loss_np + loss_nc

        return loss