# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch.nn as nn

from monai.networks.blocks.convolutions import Convolution
from monai.networks.blocks.upsample import UpSample
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.utils import InterpolateMode, UpsampleMode


def get_conv_layer(
    spatial_dims: int, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, bias: bool = False
):
    """Return a bare convolution layer, without normalization, activation or dropout (``conv_only=True``)."""
    return Convolution(
        spatial_dims, in_channels, out_channels, strides=stride, kernel_size=kernel_size, bias=bias, conv_only=True
    )


def get_upsample_layer(
    spatial_dims: int, in_channels: int, upsample_mode: UpsampleMode | str = "nontrainable", scale_factor: int = 2
):
    """Return an upsampling layer that keeps the channel count and scales each spatial dimension by ``scale_factor``."""
    return UpSample(
        spatial_dims=spatial_dims,
        in_channels=in_channels,
        out_channels=in_channels,
        scale_factor=scale_factor,
        mode=upsample_mode,
        interp_mode=InterpolateMode.LINEAR,
        align_corners=False,
    )
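# Illustrative usage (an assumed example, not part of the upstream module): for a 3D
# tensor of shape (N, C, D, H, W), ``get_conv_layer(3, C, C)`` keeps the spatial size
# (stride 1, same padding), while ``get_upsample_layer(3, C)`` doubles D, H and W with
# non-trainable linear interpolation, e.g.
#
#     >>> import torch
#     >>> conv = get_conv_layer(3, 8, 8)
#     >>> up = get_upsample_layer(3, 8)
#     >>> up(conv(torch.zeros(1, 8, 4, 4, 4))).shape
#     torch.Size([1, 8, 8, 8, 8])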


class ResBlock(nn.Module):
"""
ResBlock employs skip connection and two convolution blocks and is used
in SegResNet based on `3D MRI brain tumor segmentation using autoencoder regularization
<https://arxiv.org/pdf/1810.11654.pdf>`_.
"""
    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        norm: tuple | str,
        kernel_size: int = 3,
        act: tuple | str = ("RELU", {"inplace": True}),
    ) -> None:
"""
Args:
spatial_dims: number of spatial dimensions, could be 1, 2 or 3.
in_channels: number of input channels.
norm: feature normalization type and arguments.
kernel_size: convolution kernel size, the value should be an odd number. Defaults to 3.
act: activation type and arguments. Defaults to ``RELU``.
"""
super().__init__()
        if kernel_size % 2 != 1:
            raise ValueError("kernel_size should be an odd number.")
self.norm1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
self.norm2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
self.act = get_act_layer(act)
self.conv1 = get_conv_layer(spatial_dims, in_channels=in_channels, out_channels=in_channels)
self.conv2 = get_conv_layer(spatial_dims, in_channels=in_channels, out_channels=in_channels)
    def forward(self, x):
        identity = x

        # pre-activation residual branch: norm -> act -> conv, applied twice
        x = self.norm1(x)
        x = self.act(x)
        x = self.conv1(x)

        x = self.norm2(x)
        x = self.act(x)
        x = self.conv2(x)

        # skip connection
        x += identity
        return x
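

if __name__ == "__main__":
    # A minimal, illustrative smoke test (not part of the upstream module; the chosen
    # norm arguments and tensor shape are assumptions): ResBlock preserves both the
    # channel count and the spatial shape of its input.
    import torch

    block = ResBlock(spatial_dims=3, in_channels=8, norm=("GROUP", {"num_groups": 8}))
    inp = torch.randn(1, 8, 16, 16, 16)
    out = block(inp)
    print(out.shape)  # expected: torch.Size([1, 8, 16, 16, 16])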