# Source code for monai.losses.cldice

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss


def soft_erode(img: torch.Tensor) -> torch.Tensor:
    """
    Perform soft erosion on the input image

    Erosion is computed as a min-pool, implemented as a negated max-pool of
    the negated input. A 3-element window is applied separately along each
    spatial axis (stride 1, padding 1, so the spatial size is preserved) and
    the per-axis results are combined with an element-wise minimum.

    Args:
        img: the shape should be BCH(WD)

    Returns:
        the eroded image, same shape as ``img``

    Raises:
        ValueError: if ``img`` is neither 4D nor 5D. Previously such input
            fell through and the function silently returned ``None``.

    Adapted from:
        https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L6
    """
    if img.ndim == 4:
        p1 = -(F.max_pool2d(-img, (3, 1), (1, 1), (1, 0)))
        p2 = -(F.max_pool2d(-img, (1, 3), (1, 1), (0, 1)))
        return torch.min(p1, p2)
    if img.ndim == 5:
        p1 = -(F.max_pool3d(-img, (3, 1, 1), (1, 1, 1), (1, 0, 0)))
        p2 = -(F.max_pool3d(-img, (1, 3, 1), (1, 1, 1), (0, 1, 0)))
        p3 = -(F.max_pool3d(-img, (1, 1, 3), (1, 1, 1), (0, 0, 1)))
        return torch.min(torch.min(p1, p2), p3)
    # Fail loudly instead of implicitly returning None for unsupported ranks.
    raise ValueError(f"soft_erode expects a 4D (BCHW) or 5D (BCHWD) tensor, got {img.ndim}D input.")


def soft_dilate(img: torch.Tensor) -> torch.Tensor:
    """
    Perform soft dilation on the input image

    Dilation is a max-pool with a 3-wide window over every spatial axis
    (stride 1, padding 1, so the spatial size is preserved).

    Args:
        img: the shape should be BCH(WD)

    Returns:
        the dilated image, same shape as ``img``

    Raises:
        ValueError: if ``img`` is neither 4D nor 5D. Previously such input
            fell through and the function silently returned ``None``.

    Adapted from:
        https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L18
    """
    if img.ndim == 4:
        return F.max_pool2d(img, (3, 3), (1, 1), (1, 1))
    if img.ndim == 5:
        return F.max_pool3d(img, (3, 3, 3), (1, 1, 1), (1, 1, 1))
    # Fail loudly instead of implicitly returning None for unsupported ranks.
    raise ValueError(f"soft_dilate expects a 4D (BCHW) or 5D (BCHWD) tensor, got {img.ndim}D input.")


def soft_open(img: torch.Tensor) -> torch.Tensor:
    """
    Apply a soft morphological opening: soft erosion followed by soft dilation.

    Args:
        img: the shape should be BCH(WD)

    Returns:
        the result of ``soft_dilate(soft_erode(img))``

    Adapted from:
        https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L25
    """
    # Erode first, then dilate the eroded result.
    return soft_dilate(soft_erode(img))


def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor:
    """
    Compute a soft skeleton of the input image by iterated erosion.

    The skeleton starts as the positive part of the image minus its soft
    opening. On each iteration the image is eroded once more, the same
    residual is computed on the eroded image, and it is merged into the
    running skeleton.

    Adapted from:
       https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L29

    Args:
        img: the shape should be BCH(WD)
        iter_: number of iterations for skeletonization

    Returns:
        skeletonized image
    """
    current = img
    skeleton = F.relu(current - soft_open(current))
    for _step in range(iter_):
        current = soft_erode(current)
        residual = F.relu(current - soft_open(current))
        # Add only the part of the residual not already covered by the skeleton.
        skeleton = skeleton + F.relu(residual - skeleton * residual)
    return skeleton


def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """
    Compute the soft Dice loss over all channels except index 0 of dim 1.

    Adapted from:
        https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L22

    Args:
        y_true: the shape should be BCH(WD)
        y_pred: the shape should be BCH(WD)
        smooth: constant added to both numerator and denominator

    Returns:
        dice loss, i.e. ``1 - dice coefficient``, as a scalar tensor
    """
    # Channel 0 is dropped from every sum (presumably background -- confirm with callers).
    overlap = torch.sum((y_true * y_pred)[:, 1:, ...])
    true_total = torch.sum(y_true[:, 1:, ...])
    pred_total = torch.sum(y_pred[:, 1:, ...])
    numerator = 2.0 * overlap + smooth
    denominator = true_total + pred_total + smooth
    loss: torch.Tensor = 1.0 - numerator / denominator
    return loss


class SoftclDiceLoss(_Loss):
    """
    Compute the Soft clDice loss defined in:

        Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function
        for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311)

    Adapted from:
        https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L7
    """

    def __init__(self, iter_: int = 3, smooth: float = 1.0) -> None:
        """
        Args:
            iter_: Number of iterations for skeletonization
            smooth: Smoothing parameter
        """
        super().__init__()
        self.iter = iter_
        self.smooth = smooth

    def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
        """
        Compute the soft clDice loss.

        Note the argument order: ground truth first, prediction second.

        Args:
            y_true: the shape should be BCH(WD)
            y_pred: the shape should be BCH(WD)

        Returns:
            scalar tensor ``1 - 2 * tprec * tsens / (tprec + tsens)``
        """
        skel_pred = soft_skel(y_pred, self.iter)
        skel_true = soft_skel(y_true, self.iter)
        # Every sum skips index 0 of dim 1 (presumably the background channel -- confirm).
        # tprec: smoothed fraction of the predicted skeleton that lies inside y_true.
        tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (
            torch.sum(skel_pred[:, 1:, ...]) + self.smooth
        )
        # tsens: smoothed fraction of the true skeleton that lies inside y_pred.
        tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
            torch.sum(skel_true[:, 1:, ...]) + self.smooth
        )
        # One minus the harmonic mean of tprec and tsens.
        cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
        return cl_dice
class SoftDiceclDiceLoss(_Loss):
    """
    Compute a weighted combination of the soft Dice loss and the Soft clDice
    loss defined in:

        Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function
        for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311)

    The result is ``(1 - alpha) * dice + alpha * cl_dice``.

    Adapted from:
        https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L38
    """

    def __init__(self, iter_: int = 3, alpha: float = 0.5, smooth: float = 1.0) -> None:
        """
        Args:
            iter_: Number of iterations for skeletonization
            smooth: Smoothing parameter
            alpha: Weighing factor for cldice
        """
        super().__init__()
        self.iter = iter_
        self.smooth = smooth
        self.alpha = alpha

    def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
        """
        Compute the combined soft Dice / soft clDice loss.

        Note the argument order: ground truth first, prediction second.

        Args:
            y_true: the shape should be BCH(WD)
            y_pred: the shape should be BCH(WD)

        Returns:
            scalar tensor ``(1 - alpha) * dice + alpha * cl_dice``
        """
        dice = soft_dice(y_true, y_pred, self.smooth)
        skel_pred = soft_skel(y_pred, self.iter)
        skel_true = soft_skel(y_true, self.iter)
        # Every sum skips index 0 of dim 1 (presumably the background channel -- confirm).
        # tprec: smoothed fraction of the predicted skeleton that lies inside y_true.
        tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (
            torch.sum(skel_pred[:, 1:, ...]) + self.smooth
        )
        # tsens: smoothed fraction of the true skeleton that lies inside y_pred.
        tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
            torch.sum(skel_true[:, 1:, ...]) + self.smooth
        )
        # One minus the harmonic mean of tprec and tsens.
        cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
        total_loss: torch.Tensor = (1.0 - self.alpha) * dice + self.alpha * cl_dice
        return total_loss