# Source code for monai.networks.blocks.crf

# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.nn.functional import conv1d, conv2d, conv3d, pad, softmax

from monai.networks.layers.filtering import PHLFilter

__all__ = ["CRF"]

class CRF(torch.nn.Module):
    """
    Conditional Random Field: combines message passing with a class
    compatibility convolution into an iterative (mean-field) process designed
    to successively minimise the energy of the class labeling.

    In this implementation, the message passing step is a weighted
    combination of a gaussian filter and a bilateral filter, both evaluated
    with ``PHLFilter``. The bilateral term is included to respect existing
    structure within the reference tensor.
    """

    def __init__(
        self,
        bilateral_weight: float = 1.0,
        gaussian_weight: float = 1.0,
        bilateral_spatial_sigma: float = 5.0,
        bilateral_color_sigma: float = 0.5,
        gaussian_spatial_sigma: float = 5.0,
        update_factor: float = 3.0,
        compatibility_kernel_range: int = 1,
        iterations: int = 5,
    ):
        """
        Args:
            bilateral_weight: the weighting of the bilateral term in the message passing step.
            gaussian_weight: the weighting of the gaussian term in the message passing step.
            bilateral_spatial_sigma: standard deviation in spatial coordinates for the bilateral term.
            bilateral_color_sigma: standard deviation in color space for the bilateral term.
            gaussian_spatial_sigma: standard deviation in spatial coordinates for the gaussian term.
            update_factor: determines the magnitude of each update.
            compatibility_kernel_range: the range of the kernel used in the compatibility convolution;
                the kernel spans ``2 * compatibility_kernel_range + 1`` voxels along each spatial axis.
            iterations: the number of iterations.
        """
        super().__init__()
        self.bilateral_weight = bilateral_weight
        self.gaussian_weight = gaussian_weight
        self.bilateral_spatial_sigma = bilateral_spatial_sigma
        self.bilateral_color_sigma = bilateral_color_sigma
        self.gaussian_spatial_sigma = gaussian_spatial_sigma
        self.update_factor = update_factor
        self.compatibility_kernel_range = compatibility_kernel_range
        self.iterations = iterations

    def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor) -> torch.Tensor:
        """
        Run ``self.iterations`` mean-field updates and return the refined class probabilities.

        Args:
            input_tensor: tensor containing initial class logits, laid out as
                ``(batch, classes, *spatial)`` with 1, 2 or 3 spatial dimensions.
            reference_tensor: the reference tensor used to guide the message passing, laid out as
                ``(batch, channels, *spatial)``. Its batch and spatial sizes are assumed to match
                ``input_tensor`` (TODO confirm against ``PHLFilter`` requirements).

        Returns:
            output (torch.Tensor): class probabilities (softmax over dim 1) with the same shape
            as ``input_tensor``.

        Raises:
            ValueError: if ``input_tensor`` does not have 1, 2 or 3 spatial dimensions.
        """
        # useful values
        spatial_dim = input_tensor.dim() - 2
        if not 1 <= spatial_dim <= 3:
            # Without this guard, spatial_dim == 0 would index the conv list at -1 and
            # silently pick conv3d; larger values would fail with an opaque IndexError.
            raise ValueError(f"input_tensor must have 1, 2 or 3 spatial dimensions, got {spatial_dim}.")
        class_count = input_tensor.size(1)
        padding = self.compatibility_kernel_range

        # constructing spatial feature tensor
        spatial_features = _create_coordinate_tensor(reference_tensor)

        # constructing final feature tensors for bilateral and gaussian kernel;
        # the bilateral features are the scaled coordinates followed by the scaled reference channels
        bilateral_features = torch.cat(
            [spatial_features / self.bilateral_spatial_sigma, reference_tensor / self.bilateral_color_sigma], dim=1
        )
        gaussian_features = spatial_features / self.gaussian_spatial_sigma

        # compatibility matrix (potts model (1 - diag) for now). It is cast to the input's
        # device *and* dtype so the convolution below also works for non-float32 inputs.
        compatibility_matrix = _potts_model_weights(class_count).to(
            device=input_tensor.device, dtype=input_tensor.dtype
        )

        # expanding matrix to a (classes, classes, k, ..., k) kernel with k = 2 * padding + 1
        compatibility_kernel = _expand_matrix_to_kernel(
            compatibility_matrix, spatial_dim, self.compatibility_kernel_range
        )

        # choosing convolution function matching the number of spatial dims
        conv = (conv1d, conv2d, conv3d)[spatial_dim - 1]

        # replicate-pad `padding` voxels on both sides of every spatial axis so the
        # unpadded convolution with the (2 * padding + 1)-wide kernel preserves spatial size
        pad_widths = 2 * spatial_dim * [padding]

        # setting up output tensor
        output_tensor = softmax(input_tensor, dim=1)

        # mean field loop
        for _ in range(self.iterations):
            # message passing step for both kernels
            bilateral_output = PHLFilter.apply(output_tensor, bilateral_features)
            gaussian_output = PHLFilter.apply(output_tensor, gaussian_features)

            # combining filter outputs
            combined_output = self.bilateral_weight * bilateral_output + self.gaussian_weight * gaussian_output

            # compatibility convolution
            combined_output = pad(combined_output, pad_widths, mode="replicate")
            compatibility_update = conv(combined_output, compatibility_kernel)

            # update (always relative to the original logits) and normalize
            output_tensor = softmax(input_tensor - self.update_factor * compatibility_update, dim=1)

        return output_tensor
# helper methods
def _create_coordinate_tensor(tensor):
    """
    Build a tensor holding the integer voxel coordinates of every spatial position of ``tensor``.

    Args:
        tensor: tensor laid out as ``(batch, channels, *spatial)``; only its batch size,
            spatial sizes, device and dtype are used.

    Returns:
        tensor of shape ``(batch, spatial_dim, *spatial)`` whose channel ``d`` holds the index
        along spatial axis ``d``, on the same device and with the same dtype as ``tensor``.
    """
    axes = [torch.arange(tensor.size(i)) for i in range(2, tensor.dim())]
    try:
        # Explicit matrix ("ij") indexing keeps channel d aligned with spatial axis d and
        # avoids the missing-`indexing` warning emitted by torch >= 1.10.
        grids = torch.meshgrid(*axes, indexing="ij")
    except TypeError:
        # Older torch has no `indexing` keyword; "ij" is its only behaviour.
        grids = torch.meshgrid(*axes)
    coords = torch.stack(grids).to(device=tensor.device, dtype=tensor.dtype)
    # one identical copy of the coordinate grid per batch element
    return torch.stack(tensor.size(0) * [coords], dim=0)


def _potts_model_weights(class_count):
    """
    Potts-model compatibility matrix: 0 on the diagonal, 1 everywhere else.

    Returns:
        tensor (default dtype, CPU) of shape ``(class_count, class_count, 1)``.
    """
    return (1 - torch.diag(torch.ones(class_count))).unsqueeze(-1)


def _expand_matrix_to_kernel(matrix, spatial_dim, kernel_range):
    """
    Broadcast a class-compatibility matrix into a convolution kernel.

    Args:
        matrix: tensor whose first two dims are ``(C_out, C_in)`` and which has no other
            non-singleton dims (e.g. the output of ``_potts_model_weights``).
        spatial_dim: number of spatial dimensions of the kernel.
        kernel_range: half-width of the kernel; each spatial axis has ``2 * kernel_range + 1`` taps.

    Returns:
        an ``expand`` view of shape ``(C_out, C_in) + spatial_dim * (2 * kernel_range + 1,)``
        in which every spatial tap carries the same matrix values.
    """
    reshape_arg = (matrix.size(0), matrix.size(1)) + spatial_dim * (1,)
    expand_arg = (-1, -1) + spatial_dim * (1 + 2 * kernel_range,)
    return matrix.reshape(reshape_arg).expand(expand_arg)