# Source code for monai.networks.blocks.segresnet_block

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch.nn as nn

from monai.networks.blocks.convolutions import Convolution
from monai.networks.blocks.upsample import UpSample
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.utils import InterpolateMode, UpsampleMode


def get_conv_layer(
    spatial_dims: int, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, bias: bool = False
):
    """
    Build a plain convolution layer via ``Convolution``.

    ``conv_only=True`` is passed so that only the convolution itself is requested.
    See the ``Convolution`` docs for exactly what that excludes.

    Args:
        spatial_dims: number of spatial dimensions of the convolution.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size. Defaults to 3.
        stride: convolution stride, forwarded to ``Convolution`` as ``strides``. Defaults to 1.
        bias: whether the convolution uses a bias term. Defaults to False.

    Returns:
        The constructed ``Convolution`` module.
    """
    layer_options = {
        "strides": stride,
        "kernel_size": kernel_size,
        "bias": bias,
        "conv_only": True,
    }
    return Convolution(spatial_dims, in_channels, out_channels, **layer_options)


def get_upsample_layer(
    spatial_dims: int, in_channels: int, upsample_mode: UpsampleMode | str = "nontrainable", scale_factor: int = 2
):
    """
    Build an ``UpSample`` layer whose output channel count equals its input channel count.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of channels, used for both ``in_channels`` and ``out_channels``.
        upsample_mode: upsampling mode, forwarded to ``UpSample`` as ``mode``.
            Defaults to ``"nontrainable"``.
        scale_factor: spatial scaling factor. Defaults to 2.

    Returns:
        The constructed ``UpSample`` module.
    """
    # Request the same channel count on both sides of the upsampler.
    channels = in_channels
    # ``interp_mode`` and ``align_corners`` are fixed here. They presumably only take effect
    # for interpolation-based modes; confirm against the ``UpSample`` docs.
    return UpSample(
        mode=upsample_mode,
        scale_factor=scale_factor,
        spatial_dims=spatial_dims,
        in_channels=channels,
        out_channels=channels,
        align_corners=False,
        interp_mode=InterpolateMode.LINEAR,
    )


class ResBlock(nn.Module):
    """
    ResBlock employs skip connection and two convolution blocks and is used
    in SegResNet based on `3D MRI brain tumor segmentation using autoencoder regularization
    <https://arxiv.org/pdf/1810.11654.pdf>`_.

    The forward pass computes ``x + conv2(act(norm2(conv1(act(norm1(x))))))``.
    Both convolutions map ``in_channels`` to ``in_channels``, so the residual addition is
    channel-compatible with the input.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        norm: tuple | str,
        kernel_size: int = 3,
        act: tuple | str = ("RELU", {"inplace": True}),
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2 or 3.
            in_channels: number of input channels; both convolutions also output this many channels.
            norm: feature normalization type and arguments.
            kernel_size: convolution kernel size, the value should be an odd number. Defaults to 3.
            act: activation type and arguments. Defaults to ``RELU``.

        Raises:
            AssertionError: when ``kernel_size`` is not an odd number.
        """
        super().__init__()

        # An odd kernel is required. Presumably this lets the convolutions preserve the
        # spatial size so the residual add lines up; confirm against ``Convolution`` padding.
        if kernel_size % 2 != 1:
            raise AssertionError("kernel_size should be an odd number.")

        self.norm1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
        self.norm2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
        # One activation module is reused for both stages. If ``act`` has learnable
        # parameters, they are therefore shared between the two stages.
        self.act = get_act_layer(act)
        self.conv1 = get_conv_layer(
            spatial_dims, in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size
        )
        self.conv2 = get_conv_layer(
            spatial_dims, in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size
        )

    def forward(self, x):
        """
        Args:
            x: input tensor with ``in_channels`` channels.

        Returns:
            ``x`` plus the result of applying (norm -> act -> conv) twice.
        """
        identity = x

        x = self.norm1(x)
        x = self.act(x)
        x = self.conv1(x)

        x = self.norm2(x)
        x = self.act(x)
        x = self.conv2(x)

        # Residual connection. ``+=`` writes into conv2's output tensor, which is assumed to be
        # freshly allocated by the convolution, so the caller's input is not modified.
        x += identity
        return x