# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
import torch.nn.functional as F
from monai.networks.layers.convutils import same_padding
from monai.networks.layers.factories import Conv, Dropout, Norm
SUPPORTED_NORM = {
'batch': lambda spatial_dims: Norm[Norm.BATCH, spatial_dims],
'instance': lambda spatial_dims: Norm[Norm.INSTANCE, spatial_dims],
}
SUPPORTED_ACTI = {'relu': nn.ReLU, 'prelu': nn.PReLU, 'relu6': nn.ReLU6}
DEFAULT_LAYER_PARAMS_3D = (
# initial conv layer
{'name': 'conv_0', 'n_features': 16, 'kernel_size': 3},
# residual blocks
{'name': 'res_1', 'n_features': 16, 'kernels': (3, 3), 'repeat': 3},
{'name': 'res_2', 'n_features': 32, 'kernels': (3, 3), 'repeat': 3},
{'name': 'res_3', 'n_features': 64, 'kernels': (3, 3), 'repeat': 3},
# final conv layers
{'name': 'conv_1', 'n_features': 80, 'kernel_size': 1},
{'name': 'conv_2', 'kernel_size': 1},
)
[docs]class ConvNormActi(nn.Module):
def __init__(self,
spatial_dims,
in_channels,
out_channels,
kernel_size,
norm_type=None,
acti_type=None,
dropout_prob=None):
super(ConvNormActi, self).__init__()
layers = nn.ModuleList()
conv_type = Conv[Conv.CONV, spatial_dims]
padding_size = same_padding(kernel_size)
conv = conv_type(in_channels, out_channels, kernel_size, padding=padding_size)
layers.append(conv)
if norm_type is not None:
layers.append(SUPPORTED_NORM[norm_type](spatial_dims)(out_channels))
if acti_type is not None:
layers.append(SUPPORTED_ACTI[acti_type](inplace=True))
if dropout_prob is not None:
dropout_type = Dropout[Dropout.DROPOUT, spatial_dims]
layers.append(dropout_type(p=dropout_prob))
self.layers = nn.Sequential(*layers)
[docs] def forward(self, x):
return self.layers(x)
[docs]class HighResBlock(nn.Module):
def __init__(self,
spatial_dims,
in_channels,
out_channels,
kernels=(3, 3),
dilation=1,
norm_type='instance',
acti_type='relu',
channel_matching='pad'):
"""
Args:
kernels (list of int): each integer k in `kernels` corresponds to a convolution layer with kernel size k.
channel_matching ('pad'|'project'): handling residual branch and conv branch channel mismatches
with either zero padding ('pad') or a trainable conv with kernel size 1 ('project').
"""
super(HighResBlock, self).__init__()
conv_type = Conv[Conv.CONV, spatial_dims]
self.project, self.pad = None, None
if in_channels != out_channels:
if channel_matching not in ('pad', 'project'):
raise ValueError('channel matching must be pad or project, got {}.'.format(channel_matching))
if channel_matching == 'project':
self.project = conv_type(in_channels, out_channels, kernel_size=1)
if channel_matching == 'pad':
if in_channels > out_channels:
raise ValueError('in_channels > out_channels is incompatible with `channel_matching=pad`.')
pad_1 = (out_channels - in_channels) // 2
pad_2 = out_channels - in_channels - pad_1
pad = [0, 0] * spatial_dims + [pad_1, pad_2] + [0, 0]
self.pad = lambda input: F.pad(input, pad)
layers = nn.ModuleList()
_in_chns, _out_chns = in_channels, out_channels
for kernel_size in kernels:
layers.append(SUPPORTED_NORM[norm_type](spatial_dims)(_in_chns))
layers.append(SUPPORTED_ACTI[acti_type](inplace=True))
layers.append(
conv_type(_in_chns,
_out_chns,
kernel_size,
padding=same_padding(kernel_size, dilation),
dilation=dilation))
_in_chns = _out_chns
self.layers = nn.Sequential(*layers)
[docs] def forward(self, x):
x_conv = self.layers(x)
if self.project is not None:
return x_conv + self.project(x)
if self.pad is not None:
return x_conv + self.pad(x)
return x_conv + x
[docs]class HighResNet(nn.Module):
"""
Reimplementation of highres3dnet based on
Li et al., "On the compactness, efficiency, and representation of 3D
convolutional networks: Brain parcellation as a pretext task", IPMI '17
Adapted from:
https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/network/highres3dnet.py
https://github.com/fepegar/highresnet
Args:
spatial_dims (int): number of spatial dimensions of the input image.
in_channels (int): number of input channels.
out_channels (int): number of output channels.
norm_type ('batch'|'instance'): feature normalisation with batchnorm or instancenorm.
acti_type ('relu'|'prelu'|'relu6'): non-linear activation using ReLU or PReLU.
dropout_prob (float): probability of the feature map to be zeroed
(only applies to the penultimate conv layer).
layer_params (a list of dictionaries): specifying key parameters of each layer/block.
"""
def __init__(self,
spatial_dims=3,
in_channels=1,
out_channels=1,
norm_type='batch',
acti_type='relu',
dropout_prob=None,
layer_params=DEFAULT_LAYER_PARAMS_3D):
super(HighResNet, self).__init__()
blocks = nn.ModuleList()
# intial conv layer
params = layer_params[0]
_in_chns, _out_chns = in_channels, params['n_features']
blocks.append(
ConvNormActi(spatial_dims,
_in_chns,
_out_chns,
kernel_size=params['kernel_size'],
norm_type=norm_type,
acti_type=acti_type,
dropout_prob=None))
# residual blocks
for (idx, params) in enumerate(layer_params[1:-2]): # res blocks except the 1st and last two conv layers.
_in_chns, _out_chns = _out_chns, params['n_features']
_dilation = 2**idx
for _ in range(params['repeat']):
blocks.append(
HighResBlock(spatial_dims,
_in_chns,
_out_chns,
params['kernels'],
dilation=_dilation,
norm_type=norm_type,
acti_type=acti_type))
_in_chns = _out_chns
# final conv layers
params = layer_params[-2]
_in_chns, _out_chns = _out_chns, params['n_features']
blocks.append(
ConvNormActi(spatial_dims,
_in_chns,
_out_chns,
kernel_size=params['kernel_size'],
norm_type=norm_type,
acti_type=acti_type,
dropout_prob=dropout_prob))
params = layer_params[-1]
_in_chns = _out_chns
blocks.append(
ConvNormActi(spatial_dims,
_in_chns,
out_channels,
kernel_size=params['kernel_size'],
norm_type=norm_type,
acti_type=None,
dropout_prob=None))
self.blocks = nn.Sequential(*blocks)
[docs] def forward(self, x):
return self.blocks(x)