# Source code for monai.networks.blocks.regunet_block

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence

import torch
from torch import nn
from torch.nn import functional as F

from monai.networks.blocks import Convolution
from monai.networks.layers import Conv, Norm, Pool, same_padding

def get_conv_block(
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: Sequence[int] | int = 3,
    strides: int = 1,
    padding: tuple[int, ...] | int | None = None,
    act: tuple | str | None = "RELU",
    norm: tuple | str | None = "BATCH",
    initializer: str | None = "kaiming_uniform",
) -> nn.Module:
    """
    Build a bias-free ``Convolution`` block with optional normalization/activation
    and initialize the weights of every convolution layer inside it.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        strides: convolution stride.
        padding: explicit padding; when ``None`` it is computed with ``same_padding(kernel_size)``.
        act: activation passed to ``Convolution``; ``None`` disables it.
        norm: normalization passed to ``Convolution``; ``None`` disables it.
        initializer: weight initializer for the convolution layers, ``"kaiming_uniform"`` or
            ``"zeros"``; ``None`` keeps PyTorch's default initialization.

    Returns:
        the constructed convolution block.

    Raises:
        ValueError: when ``initializer`` is not one of the supported values.
    """
    # Validate up front so an unsupported initializer fails before any module is built.
    if initializer not in (None, "kaiming_uniform", "zeros"):
        raise ValueError(
            f"initializer {initializer} is not supported, currently supporting kaiming_uniform and zeros"
        )
    if padding is None:
        padding = same_padding(kernel_size)
    conv_block: nn.Module = Convolution(
        spatial_dims,
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        strides=strides,
        act=act,
        norm=norm,
        bias=False,
        conv_only=False,
        padding=padding,
    )
    if initializer is None:
        return conv_block
    conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]
    # modules() walks nested submodules, so the conv wrapped inside Convolution is reached.
    for m in conv_block.modules():
        if not isinstance(m, conv_type):
            continue
        if initializer == "kaiming_uniform":
            nn.init.kaiming_uniform_(m.weight)
        else:  # "zeros" — the only other accepted value after validation above
            nn.init.zeros_(m.weight)
    return conv_block


def get_conv_layer(
    spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Sequence[int] | int = 3
) -> nn.Module:
    """
    Build a single bias-free convolution layer with no normalization or activation.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size; padding is computed with ``same_padding(kernel_size)``.

    Returns:
        the convolution layer (``Convolution`` with ``conv_only=True``).
    """
    padding = same_padding(kernel_size)
    mod: nn.Module = Convolution(
        spatial_dims, in_channels, out_channels, kernel_size=kernel_size, bias=False, conv_only=True, padding=padding
    )
    return mod


class RegistrationResidualConvBlock(nn.Module):
    """
    A block with skip links and conv - norm - activation layers.
    Only changes the number of channels, the spatial size is kept same.

    The block input is added to the output of the last conv/norm pair right before its
    activation. When ``in_channels != out_channels`` the input is first projected to
    ``out_channels`` with a bias-free 1x1 convolution so the addition is well defined.
    """

    def __init__(
        self, spatial_dims: int, in_channels: int, out_channels: int, num_layers: int = 2, kernel_size: int = 3
    ):
        """
        Args:
            spatial_dims: number of spatial dimensions
            in_channels: number of input channels
            out_channels: number of output channels
            num_layers: number of layers inside the block
            kernel_size: kernel_size
        """
        super().__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList(
            [
                get_conv_layer(
                    spatial_dims=spatial_dims,
                    in_channels=in_channels if i == 0 else out_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                )
                for i in range(num_layers)
            ]
        )
        self.norms = nn.ModuleList([Norm[Norm.BATCH, spatial_dims](out_channels) for _ in range(num_layers)])
        self.acts = nn.ModuleList([nn.ReLU() for _ in range(num_layers)])
        # The skip tensor is summed with an ``out_channels``-channel tensor, so an identity skip only
        # works when the channel counts match. nn.Identity has no parameters, which keeps the
        # state dict of the equal-channel case unchanged.
        self.skip_proj: nn.Module = (
            nn.Identity()
            if in_channels == out_channels
            else Conv[Conv.CONV, spatial_dims](in_channels, out_channels, kernel_size=1, bias=False)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])

        Returns:
            Tensor in shape (batch, ``out_channels``, insize_1, insize_2, [insize_3]),
            with the same spatial size as ``x``
        """
        skip = self.skip_proj(x)
        for i, (conv, norm, act) in enumerate(zip(self.layers, self.norms, self.acts)):
            x = norm(conv(x))
            if i == self.num_layers - 1:  # last layer: add the residual before the activation
                x = x + skip
            x = act(x)
        return x


class RegistrationDownSampleBlock(nn.Module):
    """
    A down-sample module used in RegUNet to halve the spatial size.
    The number of channels is kept the same.

    Adapted from: DeepReg (https://github.com/DeepRegNet/DeepReg)
    """

    def __init__(self, spatial_dims: int, channels: int, pooling: bool) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            channels: number of input and output channels.
            pooling: use MaxPool if True, strided conv if False
        """
        super().__init__()
        if pooling:
            self.layer = Pool[Pool.MAX, spatial_dims](kernel_size=2)
        else:
            # kernel 2, stride 2, no padding: the strided-conv counterpart of the 2x max-pool above
            self.layer = get_conv_block(
                spatial_dims=spatial_dims,
                in_channels=channels,
                out_channels=channels,
                kernel_size=2,
                strides=2,
                padding=0,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Halves the spatial dimensions and keeps the same channel.
        output in shape (batch, ``channels``, insize_1 / 2, insize_2 / 2, [insize_3 / 2]),

        Args:
            x: Tensor in shape (batch, ``channels``, insize_1, insize_2, [insize_3])

        Raises:
            ValueError: when input spatial dimensions are not even.
        """
        # Both the pool and the strided conv would silently floor odd sizes; reject them instead.
        if any(size % 2 != 0 for size in x.shape[2:]):
            raise ValueError(f"expecting x spatial dimensions be even, got x of shape {x.shape}")
        out: torch.Tensor = self.layer(x)
        return out


def get_deconv_block(spatial_dims: int, in_channels: int, out_channels: int) -> nn.Module:
    """
    Build a stride-2 transposed-convolution block with batch norm and ReLU, used for up-sampling.

    NOTE(review): with ``padding=1`` / ``output_padding=1`` this presumably doubles the spatial
    size given ``Convolution``'s default kernel size — confirm against the ``Convolution`` defaults.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.

    Returns:
        the transposed-convolution block.
    """
    mod: nn.Module = Convolution(
        spatial_dims=spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        strides=2,
        act="RELU",
        norm="BATCH",
        bias=False,
        is_transposed=True,
        padding=1,
        output_padding=1,
    )
    return mod


class RegistrationExtractionBlock(nn.Module):
    """
    The Extraction Block used in RegUNet.
    Extracts feature from each ``extract_levels`` and takes the average.
    """

    def __init__(
        self,
        spatial_dims: int,
        extract_levels: tuple[int, ...],
        num_channels: tuple[int, ...] | list[int],
        out_channels: int,
        kernel_initializer: str | None = "kaiming_uniform",
        activation: str | None = None,
        mode: str = "nearest",
        align_corners: bool | None = None,
    ):
        """
        Args:
            spatial_dims: number of spatial dimensions
            extract_levels: spatial levels to extract feature from, 0 refers to the input scale
            num_channels: number of channels at each scale level,
                List or Tuple of length equals to `depth` of the RegNet
            out_channels: number of output channels
            kernel_initializer: kernel initializer
            activation: kernel activation function
            mode: feature map interpolation mode, default to "nearest".
            align_corners: whether to align corners for feature map interpolation.

        Raises:
            ValueError: when ``extract_levels`` is empty.
        """
        super().__init__()
        if not extract_levels:
            raise ValueError("extract_levels must contain at least one level.")
        self.extract_levels = extract_levels
        self.max_level = max(extract_levels)
        # One conv (no norm) per extracted level, mapping that level's channels to out_channels.
        self.layers = nn.ModuleList(
            [
                get_conv_block(
                    spatial_dims=spatial_dims,
                    in_channels=num_channels[d],
                    out_channels=out_channels,
                    norm=None,
                    act=activation,
                    initializer=kernel_initializer,
                )
                for d in extract_levels
            ]
        )
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x: list[torch.Tensor], image_size: list[int]) -> torch.Tensor:
        """
        Args:
            x: Decoded feature at different spatial levels, sorted from deep to shallow
            image_size: output image size

        Returns:
            Tensor of shape (batch, `out_channels`, size1, size2, size3),
            where (size1, size2, size3) = ``image_size``
        """
        # x is ordered deep-to-shallow, so level ``l`` sits at index ``max_level - l``
        # (assumes x holds at least ``max_level + 1`` entries — TODO confirm with callers).
        feature_list = [
            F.interpolate(
                layer(x[self.max_level - level]), size=image_size, mode=self.mode, align_corners=self.align_corners
            )
            for layer, level in zip(self.layers, self.extract_levels)
        ]
        out: torch.Tensor = torch.mean(torch.stack(feature_list, dim=0), dim=0)
        return out