Source code for monai.networks.blocks.dints_block

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch

from monai.networks.layers.factories import Conv
from monai.networks.layers.utils import get_act_layer, get_norm_layer

# Public names exported by this module: the four block classes defined below.
__all__ = ["FactorizedIncreaseBlock", "FactorizedReduceBlock", "P3DActiConvNormBlock", "ActiConvNormBlock"]


class FactorizedIncreaseBlock(torch.nn.Sequential):
    """
    Up-sample the features by a factor of two.

    The block chains, in order: linear interpolation (``up``), activation
    (``acti``), a bias-free convolution with kernel size 1 (``conv``) and
    normalization (``norm``).
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        spatial_dims: int = 3,
        act_name: tuple | str = "RELU",
        norm_name: tuple | str = ("INSTANCE", {"affine": True}),
    ):
        """
        Args:
            in_channel: number of channels of the input features.
            out_channel: number of channels produced by the convolution.
            spatial_dims: number of spatial dimensions, either 2 or 3.
            act_name: activation layer type and arguments.
            norm_name: feature normalization type and arguments.

        Raises:
            ValueError: when ``spatial_dims`` is neither 2 nor 3.
        """
        super().__init__()
        self._in_channel = in_channel
        self._out_channel = out_channel
        self._spatial_dims = spatial_dims
        if self._spatial_dims not in (2, 3):
            raise ValueError("spatial_dims must be 2 or 3.")

        conv_cls = Conv[Conv.CONV, self._spatial_dims]
        # 2D inputs are interpolated bilinearly, 3D inputs trilinearly.
        interp_mode = "bilinear" if self._spatial_dims != 3 else "trilinear"

        # The tuple is evaluated left to right, so the layers are created and
        # registered in the order in which the sequential container runs them.
        layers = (
            ("up", torch.nn.Upsample(scale_factor=2, mode=interp_mode, align_corners=True)),
            ("acti", get_act_layer(name=act_name)),
            (
                "conv",
                conv_cls(
                    in_channels=self._in_channel,
                    out_channels=self._out_channel,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    groups=1,
                    bias=False,
                    dilation=1,
                ),
            ),
            (
                "norm",
                get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel),
            ),
        )
        for layer_name, layer in layers:
            self.add_module(layer_name, layer)
class FactorizedReduceBlock(torch.nn.Module):
    """
    Down-sample the features by a factor of two using strided convolutions.

    Two bias-free, stride-2 convolutions with kernel size 1 are applied: one
    to the activated input and one to the activated input with its first
    element dropped along every spatial axis. Their outputs are concatenated
    along the channel axis and normalized.
    The length along each spatial dimension must be a multiple of 2.
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        spatial_dims: int = 3,
        act_name: tuple | str = "RELU",
        norm_name: tuple | str = ("INSTANCE", {"affine": True}),
    ):
        """
        Args:
            in_channel: number of channels of the input features.
            out_channel: total number of output channels; split between the
                two convolutions as ``out_channel // 2`` and the remainder.
            spatial_dims: number of spatial dimensions, either 2 or 3.
            act_name: activation layer type and arguments.
            norm_name: feature normalization type and arguments.

        Raises:
            ValueError: when ``spatial_dims`` is neither 2 nor 3.
        """
        super().__init__()
        self._in_channel = in_channel
        self._out_channel = out_channel
        self._spatial_dims = spatial_dims
        if self._spatial_dims not in (2, 3):
            raise ValueError("spatial_dims must be 2 or 3.")

        conv_cls = Conv[Conv.CONV, self._spatial_dims]
        # The two branches together produce exactly ``out_channel`` channels,
        # also when ``out_channel`` is odd.
        first_width = self._out_channel // 2
        second_width = self._out_channel - first_width
        # Settings common to both strided convolutions.
        shared_conv_args = {
            "in_channels": self._in_channel,
            "kernel_size": 1,
            "stride": 2,
            "padding": 0,
            "groups": 1,
            "bias": False,
            "dilation": 1,
        }

        self.act = get_act_layer(name=act_name)
        self.conv_1 = conv_cls(out_channels=first_width, **shared_conv_args)
        self.conv_2 = conv_cls(out_channels=second_width, **shared_conv_args)
        self.norm = get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply activation, the two strided convolutions and normalization.

        The length along each spatial dimension must be a multiple of 2.
        """
        activated = self.act(x)
        # Second branch sees the input offset by one element on every spatial axis.
        if self._spatial_dims == 3:
            shifted = activated[:, :, 1:, 1:, 1:]
        else:
            shifted = activated[:, :, 1:, 1:]
        merged = torch.cat([self.conv_1(activated), self.conv_2(shifted)], dim=1)
        return self.norm(merged)
class P3DActiConvNormBlock(torch.nn.Sequential):
    """
    -- (act) -- (conv) -- (conv_1) -- (norm) --

    A 3D block whose two bias-free convolutions use complementary anisotropic
    kernels: the first spans two of the three axes, the second the remaining one.
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        kernel_size: int,
        padding: int,
        mode: int = 0,
        act_name: tuple | str = "RELU",
        norm_name: tuple | str = ("INSTANCE", {"affine": True}),
    ):
        """
        Args:
            in_channel: number of input channels.
            out_channel: number of output channels.
            kernel_size: kernel size to be expanded to 3D.
            padding: padding size to be expanded to 3D.
            mode: mode for the anisotropic kernels (first kernel, second kernel):

                - 0: ``(k, k, 1)``, ``(1, 1, k)``,
                - 1: ``(k, 1, k)``, ``(1, k, 1)``,
                - 2: ``(1, k, k)``, ``(k, 1, 1)``.

            act_name: activation layer type and arguments.
            norm_name: feature normalization type and arguments.

        Raises:
            ValueError: when ``mode`` (converted with ``int``) is not 0, 1, or 2.
        """
        super().__init__()
        self._in_channel = in_channel
        self._out_channel = out_channel
        self._p3dmode = int(mode)

        conv_cls = Conv[Conv.CONV, 3]

        # For every mode, flag the axes covered by the first convolution; the
        # second convolution covers exactly the complementary axis.
        first_conv_axes = {
            0: (True, True, False),  # (k, k, 1), (1, 1, k)
            1: (True, False, True),  # (k, 1, k), (1, k, 1)
            2: (False, True, True),  # (1, k, k), (k, 1, 1)
        }
        if self._p3dmode not in first_conv_axes:
            raise ValueError("`mode` must be 0, 1, or 2.")
        covered = first_conv_axes[self._p3dmode]

        kernel_size0 = tuple(kernel_size if on else 1 for on in covered)
        kernel_size1 = tuple(1 if on else kernel_size for on in covered)
        padding0 = tuple(padding if on else 0 for on in covered)
        padding1 = tuple(0 if on else padding for on in covered)

        self.add_module("acti", get_act_layer(name=act_name))
        # First convolution keeps the channel count unchanged.
        self.add_module(
            "conv",
            conv_cls(
                in_channels=self._in_channel,
                out_channels=self._in_channel,
                kernel_size=kernel_size0,
                stride=1,
                padding=padding0,
                groups=1,
                bias=False,
                dilation=1,
            ),
        )
        # Second convolution maps to the requested number of output channels.
        self.add_module(
            "conv_1",
            conv_cls(
                in_channels=self._in_channel,
                out_channels=self._out_channel,
                kernel_size=kernel_size1,
                stride=1,
                padding=padding1,
                groups=1,
                bias=False,
                dilation=1,
            ),
        )
        self.add_module("norm", get_norm_layer(name=norm_name, spatial_dims=3, channels=self._out_channel))
class ActiConvNormBlock(torch.nn.Sequential):
    """
    -- (Acti) -- (Conv) -- (Norm) --

    Activation, a single bias-free stride-1 convolution, then normalization.
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        kernel_size: int = 3,
        padding: int = 1,
        spatial_dims: int = 3,
        act_name: tuple | str = "RELU",
        norm_name: tuple | str = ("INSTANCE", {"affine": True}),
    ):
        """
        Args:
            in_channel: number of input channels.
            out_channel: number of output channels.
            kernel_size: kernel size of the convolution.
            padding: padding size of the convolution.
            spatial_dims: number of spatial dimensions.
            act_name: activation layer type and arguments.
            norm_name: feature normalization type and arguments.
        """
        super().__init__()
        self._in_channel = in_channel
        self._out_channel = out_channel
        self._spatial_dims = spatial_dims

        conv_cls = Conv[Conv.CONV, self._spatial_dims]

        # Build the layers in the same order in which they are executed.
        activation = get_act_layer(name=act_name)
        convolution = conv_cls(
            in_channels=self._in_channel,
            out_channels=self._out_channel,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            groups=1,
            bias=False,
            dilation=1,
        )
        normalization = get_norm_layer(
            name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel
        )

        self.add_module("acti", activation)
        self.add_module("conv", convolution)
        self.add_module("norm", normalization)