# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence

import torch
import torch.nn as nn

from monai.networks.blocks import ADN, Convolution
from monai.networks.layers.simplelayers import ChannelPad
from monai.utils import ChannelMatching

__all__ = ["HighResBlock", "HighResNet"]

DEFAULT_LAYER_PARAMS_3D = (
# initial conv layer
{"name": "conv_0", "n_features": 16, "kernel_size": 3},
# residual blocks
{"name": "res_1", "n_features": 16, "kernels": (3, 3), "repeat": 3},
{"name": "res_2", "n_features": 32, "kernels": (3, 3), "repeat": 3},
{"name": "res_3", "n_features": 64, "kernels": (3, 3), "repeat": 3},
# final conv layers
{"name": "conv_1", "n_features": 80, "kernel_size": 1},
{"name": "conv_2", "kernel_size": 1},
)
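
# A custom ``layer_params`` sequence may follow the same schema: the first entry
# configures the initial conv, the last two configure the final convs, and every
# entry in between describes a group of residual blocks. A hypothetical lighter
# variant (illustrative sketch only, not a preset shipped with this module):
#
#     SMALL_LAYER_PARAMS_3D = (
#         {"name": "conv_0", "n_features": 8, "kernel_size": 3},
#         {"name": "res_1", "n_features": 8, "kernels": (3, 3), "repeat": 1},
#         {"name": "conv_1", "n_features": 16, "kernel_size": 1},
#         {"name": "conv_2", "kernel_size": 1},
#     )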


class HighResBlock(nn.Module):
    def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
kernels: Sequence[int] = (3, 3),
dilation: Sequence[int] | int = 1,
norm_type: tuple | str = ("batch", {"affine": True}),
acti_type: tuple | str = ("relu", {"inplace": True}),
bias: bool = False,
channel_matching: ChannelMatching | str = ChannelMatching.PAD,
) -> None:
"""
Args:
spatial_dims: number of spatial dimensions of the input image.
in_channels: number of input channels.
out_channels: number of output channels.
kernels: each integer k in `kernels` corresponds to a convolution layer with kernel size k.
dilation: spacing between kernel elements.
norm_type: feature normalization type and arguments.
Defaults to ``("batch", {"affine": True})``.
            acti_type: {``"relu"``, ``"prelu"``, ``"relu6"``}
                Non-linear activation using ReLU, PReLU, or ReLU6. Defaults to ``"relu"``.
bias: whether to have a bias term in convolution blocks. Defaults to False.
According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
if a conv layer is directly followed by a batch norm layer, bias should be False.
            channel_matching: {``"pad"``, ``"project"``}
                Specifies how to handle a channel mismatch between the residual and the conv branches.
                Defaults to ``"pad"``.

                - ``"pad"``: pad the residual branch with zeros.
                - ``"project"``: project the residual branch with a trainable convolution with kernel size one.

        Raises:
            ValueError: When ``channel_matching=pad`` and ``in_channels > out_channels``,
                since zero padding cannot reduce the channel count.

"""
super().__init__()
self.chn_pad = ChannelPad(
spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, mode=channel_matching
)
layers = nn.ModuleList()
_in_chns, _out_chns = in_channels, out_channels
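        # pre-activation residual branch: norm + activation (ADN ordering "NA")
        # precede each convolution, as in the original HighRes3DNet blocks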
for kernel_size in kernels:
layers.append(
ADN(ordering="NA", in_channels=_in_chns, act=acti_type, norm=norm_type, norm_dim=spatial_dims)
)
layers.append(
Convolution(
spatial_dims=spatial_dims,
in_channels=_in_chns,
out_channels=_out_chns,
kernel_size=kernel_size,
dilation=dilation,
bias=bias,
conv_only=True,
)
)
_in_chns = _out_chns
self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
x_conv: torch.Tensor = self.layers(x)
return x_conv + torch.as_tensor(self.chn_pad(x))
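

# Usage sketch for ``HighResBlock`` (illustrative only; assumes a 3D input volume):
#
#     block = HighResBlock(spatial_dims=3, in_channels=16, out_channels=32)
#     y = block(torch.randn(1, 16, 24, 24, 24))
#     # the residual branch is zero-padded from 16 to 32 channels,
#     # so y has shape (1, 32, 24, 24, 24)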


class HighResNet(nn.Module):
"""
Reimplementation of highres3dnet based on
Li et al., "On the compactness, efficiency, and representation of 3D
convolutional networks: Brain parcellation as a pretext task", IPMI '17
Adapted from:
https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/network/highres3dnet.py
https://github.com/fepegar/highresnet
Args:
spatial_dims: number of spatial dimensions of the input image.
in_channels: number of input channels.
out_channels: number of output channels.
norm_type: feature normalization type and arguments.
Defaults to ``("batch", {"affine": True})``.
acti_type: activation type and arguments.
Defaults to ``("relu", {"inplace": True})``.
        dropout_prob: probability of the feature map to be zeroed
            (applied within each of the two final conv layers).
bias: whether to have a bias term in convolution blocks. Defaults to False.
According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
if a conv layer is directly followed by a batch norm layer, bias should be False.
        layer_params: specifies the key parameters of each layer/block.
        channel_matching: {``"pad"``, ``"project"``}
            Specifies how to handle a channel mismatch between the residual and the conv branches.
            Defaults to ``"pad"``.

            - ``"pad"``: pad the residual branch with zeros.
            - ``"project"``: project the residual branch with a trainable convolution with kernel size one.
    """

def __init__(
self,
spatial_dims: int = 3,
in_channels: int = 1,
out_channels: int = 1,
norm_type: str | tuple = ("batch", {"affine": True}),
acti_type: str | tuple = ("relu", {"inplace": True}),
dropout_prob: tuple | str | float | None = 0.0,
bias: bool = False,
layer_params: Sequence[dict] = DEFAULT_LAYER_PARAMS_3D,
channel_matching: ChannelMatching | str = ChannelMatching.PAD,
) -> None:
super().__init__()
blocks = nn.ModuleList()
# initial conv layer
params = layer_params[0]
_in_chns, _out_chns = in_channels, params["n_features"]
blocks.append(
Convolution(
spatial_dims=spatial_dims,
in_channels=_in_chns,
out_channels=_out_chns,
kernel_size=params["kernel_size"],
adn_ordering="NA",
act=acti_type,
norm=norm_type,
bias=bias,
)
)

        # residual blocks
        for idx, params in enumerate(layer_params[1:-2]):  # every entry except the first and the last two conv layers
_in_chns, _out_chns = _out_chns, params["n_features"]
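            # dilation doubles per residual group (1, 2, 4 with the default params),
            # growing the receptive field without downsampling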
_dilation = 2**idx
for _ in range(params["repeat"]):
blocks.append(
HighResBlock(
spatial_dims=spatial_dims,
in_channels=_in_chns,
out_channels=_out_chns,
kernels=params["kernels"],
dilation=_dilation,
norm_type=norm_type,
acti_type=acti_type,
bias=bias,
channel_matching=channel_matching,
)
)
_in_chns = _out_chns

        # final conv layers
params = layer_params[-2]
_in_chns, _out_chns = _out_chns, params["n_features"]
blocks.append(
Convolution(
spatial_dims=spatial_dims,
in_channels=_in_chns,
out_channels=_out_chns,
kernel_size=params["kernel_size"],
adn_ordering="NAD",
act=acti_type,
norm=norm_type,
bias=bias,
dropout=dropout_prob,
)
)

        params = layer_params[-1]
_in_chns = _out_chns
blocks.append(
Convolution(
spatial_dims=spatial_dims,
in_channels=_in_chns,
out_channels=out_channels,
kernel_size=params["kernel_size"],
adn_ordering="NAD",
act=acti_type,
norm=norm_type,
bias=bias,
dropout=dropout_prob,
)
)
self.blocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.as_tensor(self.blocks(x))
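

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the library API):
    # a 3-class segmentation network over a single-channel 3D volume.
    net = HighResNet(spatial_dims=3, in_channels=1, out_channels=3)
    with torch.no_grad():
        out = net(torch.randn(1, 1, 32, 32, 32))
    print(out.shape)  # expected: torch.Size([1, 3, 32, 32, 32])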