# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math

import torch
import torch.nn as nn

from monai.networks.blocks import Convolution
from monai.networks.layers.factories import Act, Conv, Norm, Pool, split_args


class ChannelSELayer(nn.Module):
"""
Re-implementation of the Squeeze-and-Excitation block based on:
"Hu et al., Squeeze-and-Excitation Networks, https://arxiv.org/abs/1709.01507".
"""

    def __init__(
self,
spatial_dims: int,
in_channels: int,
r: int = 2,
acti_type_1: tuple[str, dict] | str = ("relu", {"inplace": True}),
acti_type_2: tuple[str, dict] | str = "sigmoid",
add_residual: bool = False,
) -> None:
"""
Args:
spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
in_channels: number of input channels.
r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: activation type of the hidden squeeze layer. Defaults to ``("relu", {"inplace": True})``.
            acti_type_2: activation type of the output squeeze layer. Defaults to ``"sigmoid"``.
            add_residual: whether to add the input back to the output as a residual connection. Defaults to ``False``.

        Raises:
            ValueError: When ``r`` is nonpositive or larger than ``in_channels``.

        See also:
            :py:class:`monai.networks.layers.Act`
"""
super().__init__()
self.add_residual = add_residual
pool_type = Pool[Pool.ADAPTIVEAVG, spatial_dims]
self.avg_pool = pool_type(1) # spatial size (1, 1, ...)
channels = int(in_channels // r)
if channels <= 0:
raise ValueError(f"r must be positive and smaller than in_channels, got r={r} in_channels={in_channels}.")
act_1, act_1_args = split_args(acti_type_1)
act_2, act_2_args = split_args(acti_type_2)
self.fc = nn.Sequential(
nn.Linear(in_channels, channels, bias=True),
Act[act_1](**act_1_args),
nn.Linear(channels, in_channels, bias=True),
Act[act_2](**act_2_args),
)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: in shape (batch, in_channels, spatial_1[, spatial_2, ...]).
"""
b, c = x.shape[:2]
y: torch.Tensor = self.avg_pool(x).view(b, c)
y = self.fc(y).view([b, c] + [1] * (x.ndim - 2))
result = x * y
        # The residual connection is applied here, rather than by overriding forward() in
        # ResidualSELayer, because TorchScript has an issue with calling super().
if self.add_residual:
result += x
return result


class ResidualSELayer(ChannelSELayer):
"""
A "squeeze-and-excitation"-like layer with a residual connection::
--+-- SE --o--
| |
+--------+
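
    A minimal usage sketch (values are illustrative):

    Example::

        import torch

        se = ResidualSELayer(spatial_dims=3, in_channels=4)
        x = torch.randn(2, 4, 8, 8, 8)  # (batch, channels, D, H, W)
        out = se(x)  # x * scale + x, same shape as x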
"""

    def __init__(
self,
spatial_dims: int,
in_channels: int,
r: int = 2,
acti_type_1: tuple[str, dict] | str = "leakyrelu",
acti_type_2: tuple[str, dict] | str = "relu",
) -> None:
"""
Args:
spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
in_channels: number of input channels.
r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: activation type of the hidden squeeze layer. Defaults to ``"leakyrelu"``.
            acti_type_2: activation type of the output squeeze layer. Defaults to ``"relu"``.

        See also:
:py:class:`monai.networks.blocks.ChannelSELayer`
"""
super().__init__(
spatial_dims=spatial_dims,
in_channels=in_channels,
r=r,
acti_type_1=acti_type_1,
acti_type_2=acti_type_2,
add_residual=True,
)


class SEBlock(nn.Module):
"""
    Residual module enhanced with Squeeze-and-Excitation::

        ----+- conv1 -- conv2 -- conv3 -- SE -o---
            |                                 |
            +---(channel project if needed)---+

    Re-implementation of the SE-ResNet block based on:
    "Hu et al., Squeeze-and-Excitation Networks, https://arxiv.org/abs/1709.01507".
"""

    def __init__(
self,
spatial_dims: int,
in_channels: int,
n_chns_1: int,
n_chns_2: int,
n_chns_3: int,
conv_param_1: dict | None = None,
conv_param_2: dict | None = None,
conv_param_3: dict | None = None,
project: Convolution | None = None,
r: int = 2,
acti_type_1: tuple[str, dict] | str = ("relu", {"inplace": True}),
acti_type_2: tuple[str, dict] | str = "sigmoid",
acti_type_final: tuple[str, dict] | str | None = ("relu", {"inplace": True}),
):
"""
Args:
spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
in_channels: number of input channels.
n_chns_1: number of output channels in the 1st convolution.
n_chns_2: number of output channels in the 2nd convolution.
n_chns_3: number of output channels in the 3rd convolution.
conv_param_1: additional parameters to the 1st convolution.
Defaults to ``{"kernel_size": 1, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}``
conv_param_2: additional parameters to the 2nd convolution.
Defaults to ``{"kernel_size": 3, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}``
conv_param_3: additional parameters to the 3rd convolution.
Defaults to ``{"kernel_size": 1, "norm": Norm.BATCH, "act": None}``
            project: a projection (convolution) layer/block used to adjust the number of channels
                when the residual and the output channel counts do not match. In SENet, this
                consists of a convolution layer followed by a normalization layer. Defaults to
                None, in which case an identity mapping is used when the channel counts match,
                or a convolution layer with kernel size 1 otherwise.
r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: activation type of the hidden squeeze layer. Defaults to ``("relu", {"inplace": True})``.
            acti_type_2: activation type of the output squeeze layer. Defaults to ``"sigmoid"``.
            acti_type_final: activation type at the end of the block. Defaults to ``("relu", {"inplace": True})``.

        See also:
:py:class:`monai.networks.blocks.ChannelSELayer`
"""
super().__init__()
if not conv_param_1:
conv_param_1 = {"kernel_size": 1, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}
self.conv1 = Convolution(
spatial_dims=spatial_dims, in_channels=in_channels, out_channels=n_chns_1, **conv_param_1
)
if not conv_param_2:
conv_param_2 = {"kernel_size": 3, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}
self.conv2 = Convolution(spatial_dims=spatial_dims, in_channels=n_chns_1, out_channels=n_chns_2, **conv_param_2)
if not conv_param_3:
conv_param_3 = {"kernel_size": 1, "norm": Norm.BATCH, "act": None}
self.conv3 = Convolution(spatial_dims=spatial_dims, in_channels=n_chns_2, out_channels=n_chns_3, **conv_param_3)
self.se_layer = ChannelSELayer(
spatial_dims=spatial_dims, in_channels=n_chns_3, r=r, acti_type_1=acti_type_1, acti_type_2=acti_type_2
)
if project is None and in_channels != n_chns_3:
self.project = Conv[Conv.CONV, spatial_dims](in_channels, n_chns_3, kernel_size=1)
elif project is None:
self.project = nn.Identity()
else:
self.project = project
if acti_type_final is not None:
act_final, act_final_args = split_args(acti_type_final)
self.act = Act[act_final](**act_final_args)
else:
self.act = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: in shape (batch, in_channels, spatial_1[, spatial_2, ...]).
"""
residual = self.project(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.se_layer(x)
x += residual
x = self.act(x)
return x


class SEBottleneck(SEBlock):
"""
Bottleneck for SENet154.
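
    A minimal construction sketch (hyperparameters are illustrative; ``inplanes`` equals
    ``planes * 4`` so no downsample projection is needed):

    Example::

        import torch

        bottleneck = SEBottleneck(spatial_dims=2, inplanes=128, planes=32, groups=64, reduction=16)
        x = torch.randn(2, 128, 16, 16)
        out = bottleneck(x)  # shape: (2, 128, 16, 16)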
"""

    expansion = 4

    def __init__(
self,
spatial_dims: int,
inplanes: int,
planes: int,
groups: int,
reduction: int,
stride: int = 1,
downsample: Convolution | None = None,
) -> None:
conv_param_1 = {
"strides": 1,
"kernel_size": 1,
"act": ("relu", {"inplace": True}),
"norm": Norm.BATCH,
"bias": False,
}
conv_param_2 = {
"strides": stride,
"kernel_size": 3,
"act": ("relu", {"inplace": True}),
"norm": Norm.BATCH,
"groups": groups,
"bias": False,
}
conv_param_3 = {"strides": 1, "kernel_size": 1, "act": None, "norm": Norm.BATCH, "bias": False}
super().__init__(
spatial_dims=spatial_dims,
in_channels=inplanes,
n_chns_1=planes * 2,
n_chns_2=planes * 4,
n_chns_3=planes * 4,
conv_param_1=conv_param_1,
conv_param_2=conv_param_2,
conv_param_3=conv_param_3,
project=downsample,
r=reduction,
)


class SEResNetBottleneck(SEBlock):
"""
    ResNet bottleneck with a Squeeze-and-Excitation module. It follows the Caffe
    implementation, applying ``strides=stride`` in ``conv1`` rather than in ``conv2``
    (the latter is what the torchvision implementation of ResNet does).
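
    A minimal construction sketch (hyperparameters are illustrative):

    Example::

        import torch

        bottleneck = SEResNetBottleneck(spatial_dims=2, inplanes=256, planes=64, groups=1, reduction=16)
        x = torch.randn(2, 256, 8, 8)
        out = bottleneck(x)  # shape: (2, 256, 8, 8)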
"""

    expansion = 4

    def __init__(
self,
spatial_dims: int,
inplanes: int,
planes: int,
groups: int,
reduction: int,
stride: int = 1,
downsample: Convolution | None = None,
) -> None:
conv_param_1 = {
"strides": stride,
"kernel_size": 1,
"act": ("relu", {"inplace": True}),
"norm": Norm.BATCH,
"bias": False,
}
conv_param_2 = {
"strides": 1,
"kernel_size": 3,
"act": ("relu", {"inplace": True}),
"norm": Norm.BATCH,
"groups": groups,
"bias": False,
}
conv_param_3 = {"strides": 1, "kernel_size": 1, "act": None, "norm": Norm.BATCH, "bias": False}
super().__init__(
spatial_dims=spatial_dims,
in_channels=inplanes,
n_chns_1=planes,
n_chns_2=planes,
n_chns_3=planes * 4,
conv_param_1=conv_param_1,
conv_param_2=conv_param_2,
conv_param_3=conv_param_3,
project=downsample,
r=reduction,
)


class SEResNeXtBottleneck(SEBlock):
"""
ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
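
    A minimal construction sketch (the group and width settings follow the common
    ResNeXt ``32x4d`` configuration):

    Example::

        import torch

        bottleneck = SEResNeXtBottleneck(
            spatial_dims=2, inplanes=256, planes=64, groups=32, reduction=16, base_width=4
        )
        x = torch.randn(2, 256, 8, 8)
        out = bottleneck(x)  # shape: (2, 256, 8, 8)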
"""

    expansion = 4

    def __init__(
self,
spatial_dims: int,
inplanes: int,
planes: int,
groups: int,
reduction: int,
stride: int = 1,
downsample: Convolution | None = None,
base_width: int = 4,
) -> None:
conv_param_1 = {
"strides": 1,
"kernel_size": 1,
"act": ("relu", {"inplace": True}),
"norm": Norm.BATCH,
"bias": False,
}
conv_param_2 = {
"strides": stride,
"kernel_size": 3,
"act": ("relu", {"inplace": True}),
"norm": Norm.BATCH,
"groups": groups,
"bias": False,
}
conv_param_3 = {"strides": 1, "kernel_size": 1, "act": None, "norm": Norm.BATCH, "bias": False}
width = math.floor(planes * (base_width / 64)) * groups
super().__init__(
spatial_dims=spatial_dims,
in_channels=inplanes,
n_chns_1=width,
n_chns_2=width,
n_chns_3=planes * 4,
conv_param_1=conv_param_1,
conv_param_2=conv_param_2,
conv_param_3=conv_param_3,
project=downsample,
r=reduction,
)