# Source code for monai.losses.deform

# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
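# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.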
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import torch
from torch.nn.modules.loss import _Loss

from monai.utils import LossReduction

def spatial_gradient(x: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Calculate gradients on single dimension of a tensor using central finite difference.

    Two shifted views of the tensor are taken along ``dim`` and subtracted to get the
    approximate gradient ``dx[i] = (x[i+1] - x[i-1]) / 2``. The first two dimensions
    are kept whole, while every other dimension is cropped by one element on each
    side so that gradients taken along different dimensions share the same shape.

    Adapted from DeepReg.

    Args:
        x: the shape should be BCH(WD).
        dim: dimension to calculate gradient along.

    Returns:
        gradient_dx: the shape should be BCH(WD); every dimension from index 2
        onwards (and ``dim`` itself) is two elements shorter than in ``x``.
    """
    border_crop = slice(1, -1)
    keep_all = slice(None)

    # Leading two dimensions are untouched; the remaining ones lose their borders.
    ahead = [keep_all, keep_all] + [border_crop] * max(x.ndim - 2, 0)
    behind = list(ahead)

    # Along ``dim`` the two views are offset by two elements: x[i+1] versus x[i-1].
    ahead[dim] = slice(2, None)
    behind[dim] = slice(None, -2)

    # Index with tuples rather than lists: a list of slices is a non-tuple
    # sequence index, which is a deprecated form of multidimensional indexing.
    return (x[tuple(ahead)] - x[tuple(behind)]) / 2.0

class BendingEnergyLoss(_Loss):
    """
    Calculate the bending energy based on second-order differentiation of pred using central finite difference.

    Adapted from DeepReg.
    """

    def __init__(self, reduction: Union[LossReduction, str] = LossReduction.MEAN) -> None:
        """
        Args:
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.
        """
        super().__init__(reduction=LossReduction(reduction).value)

    def forward(self, pred: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: the shape should be BCH(WD)

        Returns:
            The bending energy, reduced according to ``self.reduction``. With ``"none"``
            the unreduced per-element energy tensor is returned.

        Raises:
            ValueError: When ``pred`` is not 3-d, 4-d or 5-d.
            ValueError: When any spatial dimension of ``pred`` is not larger than 4.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
        """
        if pred.ndim not in [3, 4, 5]:
            raise ValueError(f"expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}")
        # Two rounds of central differences each trim two elements per spatial
        # dimension, so every spatial dimension needs more than 4 elements.
        for i in range(pred.ndim - 2):
            if pred.shape[-i - 1] <= 4:
                # f-string so the offending shape is actually reported to the caller.
                raise ValueError(f"all spatial dimensions must > 4, got pred of shape {pred.shape}")

        # first order gradient, one tensor per spatial dimension
        first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)]

        energy = torch.tensor(0)
        for dim_1, g in enumerate(first_order_gradient, start=2):
            # pure second derivative along dim_1
            energy = spatial_gradient(g, dim_1) ** 2 + energy
            # mixed second derivatives; each unordered pair is visited once, hence the factor 2
            for dim_2 in range(dim_1 + 1, pred.ndim):
                energy = 2 * spatial_gradient(g, dim_2) ** 2 + energy

        if self.reduction == LossReduction.MEAN.value:
            energy = torch.mean(energy)  # the batch and channel average
        elif self.reduction == LossReduction.SUM.value:
            energy = torch.sum(energy)  # sum over the batch and channel dims
        elif self.reduction != LossReduction.NONE.value:
            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

        return energy