Source code for monai.networks.blocks.squeeze_and_excitation

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math

import torch
import torch.nn as nn

from monai.networks.blocks import Convolution
from monai.networks.layers.factories import Act, Conv, Norm, Pool, split_args


class ChannelSELayer(nn.Module):
    """
    Re-implementation of the Squeeze-and-Excitation block based on:
    "Hu et al., Squeeze-and-Excitation Networks, https://arxiv.org/abs/1709.01507".
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        r: int = 2,
        acti_type_1: tuple[str, dict] | str = ("relu", {"inplace": True}),
        acti_type_2: tuple[str, dict] | str = "sigmoid",
        add_residual: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: activation type of the hidden squeeze layer. Defaults to ``("relu", {"inplace": True})``.
            acti_type_2: activation type of the output squeeze layer. Defaults to "sigmoid".
            add_residual: whether to add the input back to the excited output (residual connection).
                Defaults to False.

        Raises:
            ValueError: When ``r`` is nonpositive or larger than ``in_channels``.

        See also:

            :py:class:`monai.networks.layers.Act`

        """
        super().__init__()

        self.add_residual = add_residual

        pool_type = Pool[Pool.ADAPTIVEAVG, spatial_dims]
        self.avg_pool = pool_type(1)  # spatial size (1, 1, ...)

        channels = int(in_channels // r)
        if channels <= 0:
            raise ValueError(f"r must be positive and smaller than in_channels, got r={r} in_channels={in_channels}.")

        act_1, act_1_args = split_args(acti_type_1)
        act_2, act_2_args = split_args(acti_type_2)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, channels, bias=True),
            Act[act_1](**act_1_args),
            nn.Linear(channels, in_channels, bias=True),
            Act[act_2](**act_2_args),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: in shape (batch, in_channels, spatial_1[, spatial_2, ...]).
        """
        b, c = x.shape[:2]
        y: torch.Tensor = self.avg_pool(x).view(b, c)
        y = self.fc(y).view([b, c] + [1] * (x.ndim - 2))
        result = x * y

        # Residual connection is moved here instead of providing an override of forward in ResidualSELayer since
        # Torchscript has an issue with using super().
        if self.add_residual:
            result += x

        return result

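# Illustrative usage sketch (not part of the original module): squeeze-and-excitation
# reweights channels but preserves the input shape. The tensor sizes below are arbitrary.
#
#     >>> import torch
#     >>> from monai.networks.blocks.squeeze_and_excitation import ChannelSELayer
#     >>> se = ChannelSELayer(spatial_dims=2, in_channels=8, r=2)
#     >>> x = torch.randn(4, 8, 32, 32)
#     >>> se(x).shape
#     torch.Size([4, 8, 32, 32])
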
class ResidualSELayer(ChannelSELayer):
    """
    A "squeeze-and-excitation"-like layer with a residual connection::

        --+-- SE --o--
          |        |
          +--------+
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        r: int = 2,
        acti_type_1: tuple[str, dict] | str = "leakyrelu",
        acti_type_2: tuple[str, dict] | str = "relu",
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: defaults to "leakyrelu".
            acti_type_2: defaults to "relu".

        See also:

            :py:class:`monai.networks.blocks.ChannelSELayer`

        """
        super().__init__(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            r=r,
            acti_type_1=acti_type_1,
            acti_type_2=acti_type_2,
            add_residual=True,
        )

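# Illustrative sketch (assumed shapes): ResidualSELayer adds the input back to the
# SE-reweighted features, so the output shape again matches the input.
#
#     >>> import torch
#     >>> from monai.networks.blocks.squeeze_and_excitation import ResidualSELayer
#     >>> layer = ResidualSELayer(spatial_dims=3, in_channels=16, r=2)
#     >>> x = torch.randn(2, 16, 8, 8, 8)
#     >>> layer(x).shape
#     torch.Size([2, 16, 8, 8, 8])
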
class SEBlock(nn.Module):
    """
    Residual module enhanced with Squeeze-and-Excitation::

        ----+- conv1 -- conv2 -- conv3 -- SE -o---
            |                                 |
            +---(channel project if needed)---+

    Re-implementation of the SE-Resnet block based on:
    "Hu et al., Squeeze-and-Excitation Networks, https://arxiv.org/abs/1709.01507".
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        n_chns_1: int,
        n_chns_2: int,
        n_chns_3: int,
        conv_param_1: dict | None = None,
        conv_param_2: dict | None = None,
        conv_param_3: dict | None = None,
        project: Convolution | None = None,
        r: int = 2,
        acti_type_1: tuple[str, dict] | str = ("relu", {"inplace": True}),
        acti_type_2: tuple[str, dict] | str = "sigmoid",
        acti_type_final: tuple[str, dict] | str | None = ("relu", {"inplace": True}),
    ):
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            n_chns_1: number of output channels in the 1st convolution.
            n_chns_2: number of output channels in the 2nd convolution.
            n_chns_3: number of output channels in the 3rd convolution.
            conv_param_1: additional parameters to the 1st convolution.
                Defaults to ``{"kernel_size": 1, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}``
            conv_param_2: additional parameters to the 2nd convolution.
                Defaults to ``{"kernel_size": 3, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}``
            conv_param_3: additional parameters to the 3rd convolution.
                Defaults to ``{"kernel_size": 1, "norm": Norm.BATCH, "act": None}``
            project: if the number of input (residual) channels does not match ``n_chns_3``, a projection
                (Conv) layer/block is used to adjust the channel count; in SENet it consists of a Conv layer
                followed by a Norm layer. Defaults to None, which uses an identity mapping when the channels
                already match, or a Conv layer with kernel size 1 when they do not.
            r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: activation type of the hidden squeeze layer. Defaults to "relu".
            acti_type_2: activation type of the output squeeze layer. Defaults to "sigmoid".
            acti_type_final: activation type at the end of the block. Defaults to "relu".

        See also:

            :py:class:`monai.networks.blocks.ChannelSELayer`

        """
        super().__init__()

        if not conv_param_1:
            conv_param_1 = {"kernel_size": 1, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}
        self.conv1 = Convolution(
            spatial_dims=spatial_dims, in_channels=in_channels, out_channels=n_chns_1, **conv_param_1
        )

        if not conv_param_2:
            conv_param_2 = {"kernel_size": 3, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}
        self.conv2 = Convolution(
            spatial_dims=spatial_dims, in_channels=n_chns_1, out_channels=n_chns_2, **conv_param_2
        )

        if not conv_param_3:
            conv_param_3 = {"kernel_size": 1, "norm": Norm.BATCH, "act": None}
        self.conv3 = Convolution(
            spatial_dims=spatial_dims, in_channels=n_chns_2, out_channels=n_chns_3, **conv_param_3
        )

        self.se_layer = ChannelSELayer(
            spatial_dims=spatial_dims, in_channels=n_chns_3, r=r, acti_type_1=acti_type_1, acti_type_2=acti_type_2
        )

        if project is None and in_channels != n_chns_3:
            self.project = Conv[Conv.CONV, spatial_dims](in_channels, n_chns_3, kernel_size=1)
        elif project is None:
            self.project = nn.Identity()
        else:
            self.project = project

        if acti_type_final is not None:
            act_final, act_final_args = split_args(acti_type_final)
            self.act = Act[act_final](**act_final_args)
        else:
            self.act = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: in shape (batch, in_channels, spatial_1[, spatial_2, ...]).
        """
        residual = self.project(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.se_layer(x)
        x += residual
        x = self.act(x)
        return x

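# Illustrative sketch (assumed channel counts): because in_channels (16) differs from
# n_chns_3 (32), SEBlock creates a kernel-size-1 projection for the residual branch;
# with the default strides of 1 the spatial size is preserved.
#
#     >>> import torch
#     >>> from monai.networks.blocks.squeeze_and_excitation import SEBlock
#     >>> block = SEBlock(spatial_dims=2, in_channels=16, n_chns_1=16, n_chns_2=16, n_chns_3=32)
#     >>> x = torch.randn(2, 16, 24, 24)
#     >>> block(x).shape
#     torch.Size([2, 32, 24, 24])
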
class SEBottleneck(SEBlock):
    """
    Bottleneck for SENet154.
    """

    expansion = 4

    def __init__(
        self,
        spatial_dims: int,
        inplanes: int,
        planes: int,
        groups: int,
        reduction: int,
        stride: int = 1,
        downsample: Convolution | None = None,
    ) -> None:
        conv_param_1 = {
            "strides": 1,
            "kernel_size": 1,
            "act": ("relu", {"inplace": True}),
            "norm": Norm.BATCH,
            "bias": False,
        }
        conv_param_2 = {
            "strides": stride,
            "kernel_size": 3,
            "act": ("relu", {"inplace": True}),
            "norm": Norm.BATCH,
            "groups": groups,
            "bias": False,
        }
        conv_param_3 = {"strides": 1, "kernel_size": 1, "act": None, "norm": Norm.BATCH, "bias": False}

        super().__init__(
            spatial_dims=spatial_dims,
            in_channels=inplanes,
            n_chns_1=planes * 2,
            n_chns_2=planes * 4,
            n_chns_3=planes * 4,
            conv_param_1=conv_param_1,
            conv_param_2=conv_param_2,
            conv_param_3=conv_param_3,
            project=downsample,
            r=reduction,
        )

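# Illustrative sketch (assumed hyperparameters, smaller than those typically used in
# SENet154): with inplanes == planes * 4 and stride 1, no downsample projection is
# needed and the output shape equals the input shape.
#
#     >>> import torch
#     >>> from monai.networks.blocks.squeeze_and_excitation import SEBottleneck
#     >>> btn = SEBottleneck(spatial_dims=2, inplanes=64, planes=16, groups=8, reduction=16)
#     >>> x = torch.randn(2, 64, 16, 16)
#     >>> btn(x).shape
#     torch.Size([2, 64, 16, 16])
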
class SEResNetBottleneck(SEBlock):
    """
    ResNet bottleneck with a Squeeze-and-Excitation module. It follows the Caffe implementation
    and uses `strides=stride` in `conv1` rather than in `conv2` (the latter is what the
    torchvision implementation of ResNet does).
    """

    expansion = 4

    def __init__(
        self,
        spatial_dims: int,
        inplanes: int,
        planes: int,
        groups: int,
        reduction: int,
        stride: int = 1,
        downsample: Convolution | None = None,
    ) -> None:
        conv_param_1 = {
            "strides": stride,
            "kernel_size": 1,
            "act": ("relu", {"inplace": True}),
            "norm": Norm.BATCH,
            "bias": False,
        }
        conv_param_2 = {
            "strides": 1,
            "kernel_size": 3,
            "act": ("relu", {"inplace": True}),
            "norm": Norm.BATCH,
            "groups": groups,
            "bias": False,
        }
        conv_param_3 = {"strides": 1, "kernel_size": 1, "act": None, "norm": Norm.BATCH, "bias": False}

        super().__init__(
            spatial_dims=spatial_dims,
            in_channels=inplanes,
            n_chns_1=planes,
            n_chns_2=planes,
            n_chns_3=planes * 4,
            conv_param_1=conv_param_1,
            conv_param_2=conv_param_2,
            conv_param_3=conv_param_3,
            project=downsample,
            r=reduction,
        )

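# Illustrative sketch (assumed hyperparameters): a standard SE-ResNet bottleneck with
# groups=1; inplanes == planes * 4, so the identity projection is used on the residual.
#
#     >>> import torch
#     >>> from monai.networks.blocks.squeeze_and_excitation import SEResNetBottleneck
#     >>> btn = SEResNetBottleneck(spatial_dims=2, inplanes=256, planes=64, groups=1, reduction=16)
#     >>> x = torch.randn(2, 256, 14, 14)
#     >>> btn(x).shape
#     torch.Size([2, 256, 14, 14])
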
class SEResNeXtBottleneck(SEBlock):
    """
    ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
    """

    expansion = 4

    def __init__(
        self,
        spatial_dims: int,
        inplanes: int,
        planes: int,
        groups: int,
        reduction: int,
        stride: int = 1,
        downsample: Convolution | None = None,
        base_width: int = 4,
    ) -> None:
        conv_param_1 = {
            "strides": 1,
            "kernel_size": 1,
            "act": ("relu", {"inplace": True}),
            "norm": Norm.BATCH,
            "bias": False,
        }
        conv_param_2 = {
            "strides": stride,
            "kernel_size": 3,
            "act": ("relu", {"inplace": True}),
            "norm": Norm.BATCH,
            "groups": groups,
            "bias": False,
        }
        conv_param_3 = {"strides": 1, "kernel_size": 1, "act": None, "norm": Norm.BATCH, "bias": False}
        width = math.floor(planes * (base_width / 64)) * groups

        super().__init__(
            spatial_dims=spatial_dims,
            in_channels=inplanes,
            n_chns_1=width,
            n_chns_2=width,
            n_chns_3=planes * 4,
            conv_param_1=conv_param_1,
            conv_param_2=conv_param_2,
            conv_param_3=conv_param_3,
            project=downsample,
            r=reduction,
        )

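# Illustrative sketch (assumed hyperparameters): with base_width=4 (the default) and
# groups=8, the grouped 3x3 convolution operates on width = floor(64 * 4 / 64) * 8 = 32
# channels, while the block still outputs planes * 4 = 256 channels.
#
#     >>> import torch
#     >>> from monai.networks.blocks.squeeze_and_excitation import SEResNeXtBottleneck
#     >>> btn = SEResNeXtBottleneck(spatial_dims=2, inplanes=256, planes=64, groups=8, reduction=16)
#     >>> x = torch.randn(2, 256, 14, 14)
#     >>> btn(x).shape
#     torch.Size([2, 256, 14, 14])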