# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
import torch
import torch.nn as nn
from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, get_conv_layer
class UnetrUpBlock(nn.Module):
    """
    An upsampling module that can be used for UNETR: "Hatamizadeh et al.,
    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"

    The block applies a transposed convolution to its input, concatenates the
    result with a skip tensor along the channel dimension (``dim=1``) and feeds
    the concatenation through a convolutional block.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Sequence[int] | int,
        upsample_kernel_size: Sequence[int] | int,
        norm_name: tuple | str,
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            upsample_kernel_size: convolution kernel size for transposed convolution layers.
                The same value is also used as the stride of the transposed convolution.
            norm_name: feature normalization type and arguments.
            res_block: bool argument to determine if residual block is used.
        """
        super().__init__()

        # The transposed convolution uses one value for both kernel size and stride.
        upsample_stride = upsample_kernel_size
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )

        # Both block types are built with identical arguments; only the class differs.
        block_type = UnetResBlock if res_block else UnetBasicBlock
        # The block consumes the upsampled features concatenated with the skip
        # features, hence ``out_channels + out_channels`` input channels.
        self.conv_block = block_type(
            spatial_dims,
            out_channels + out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            norm_name=norm_name,
        )

    def forward(self, inp: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """
        Upsample ``inp``, concatenate it with ``skip`` and apply the convolutional block.

        Args:
            inp: input passed to the transposed convolution.
            skip: skip-connection features concatenated after the upsampled
                input along ``dim=1``.

        Returns:
            The output of the convolutional block.
        """
        # The number of channels of `skip` should be equal to `out_channels`,
        # because the convolutional block expects `2 * out_channels` inputs.
        out = self.transp_conv(inp)
        out = torch.cat((out, skip), dim=1)
        out = self.conv_block(out)
        return out
class UnetrPrUpBlock(nn.Module):
    """
    A projection upsampling module that can be used for UNETR: "Hatamizadeh et al.,
    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"

    The module is an initial transposed convolution (``in_channels`` to
    ``out_channels``) followed by ``num_layer`` further stages. Each stage is a
    transposed convolution (``out_channels`` to ``out_channels``), optionally
    followed by a convolutional block when ``conv_block`` is true.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_layer: int,
        kernel_size: Sequence[int] | int,
        stride: Sequence[int] | int,
        upsample_kernel_size: Sequence[int] | int,
        norm_name: tuple | str,
        conv_block: bool = False,
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            num_layer: number of upsampling blocks created in addition to the
                initial transposed convolution.
            kernel_size: convolution kernel size. Only used when ``conv_block`` is true.
            stride: convolution stride. Only used when ``conv_block`` is true.
            upsample_kernel_size: convolution kernel size for transposed convolution layers.
                The same value is also used as their stride.
            norm_name: feature normalization type and arguments.
                Only used when ``conv_block`` is true.
            conv_block: bool argument to determine if convolutional block is used.
            res_block: bool argument to determine if residual block is used.
                Ignored when ``conv_block`` is false.
        """
        super().__init__()

        self.transp_conv_init = self._make_transp_conv(
            spatial_dims, in_channels, out_channels, upsample_kernel_size
        )

        if conv_block:
            # Both block types are built with identical arguments; only the class differs.
            block_type = UnetResBlock if res_block else UnetBasicBlock
            # Each stage is a Sequential of (transposed conv, conv block); the
            # transposed conv is created first, matching the submodule order.
            self.blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        self._make_transp_conv(
                            spatial_dims, out_channels, out_channels, upsample_kernel_size
                        ),
                        block_type(
                            spatial_dims=spatial_dims,
                            in_channels=out_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            norm_name=norm_name,
                        ),
                    )
                    for _ in range(num_layer)
                ]
            )
        else:
            # Without convolutional blocks each stage is a bare transposed convolution.
            self.blocks = nn.ModuleList(
                [
                    self._make_transp_conv(
                        spatial_dims, out_channels, out_channels, upsample_kernel_size
                    )
                    for _ in range(num_layer)
                ]
            )

    @staticmethod
    def _make_transp_conv(
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        upsample_kernel_size: Sequence[int] | int,
    ):
        """Build a conv-only transposed convolution whose stride equals its kernel size."""
        upsample_stride = upsample_kernel_size
        return get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the initial transposed convolution, then every stage in order.

        Args:
            x: input passed to the initial transposed convolution.

        Returns:
            The output of the last stage, or of the initial transposed
            convolution when there are no stages.
        """
        x = self.transp_conv_init(x)
        for blk in self.blocks:
            x = blk(x)
        return x
class UnetrBasicBlock(nn.Module):
    """
    A CNN module that can be used for UNETR, based on: "Hatamizadeh et al.,
    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"

    The module wraps a single ``UnetResBlock`` or ``UnetBasicBlock``, selected
    by ``res_block``, and forwards its input through it.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Sequence[int] | int,
        stride: Sequence[int] | int,
        norm_name: tuple | str,
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            stride: convolution stride.
            norm_name: feature normalization type and arguments.
            res_block: bool argument to determine if residual block is used.
        """
        super().__init__()

        # Both block types are built with identical arguments; only the class differs.
        block_type = UnetResBlock if res_block else UnetBasicBlock
        self.layer = block_type(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            norm_name=norm_name,
        )

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """
        Apply the wrapped block to ``inp``.

        Args:
            inp: input passed to the wrapped block.

        Returns:
            The output of the wrapped block.
        """
        return self.layer(inp)