# Source code for monai.networks.blocks.aspp

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence

import torch
import torch.nn as nn

from monai.networks.blocks.convolutions import Convolution
from monai.networks.layers import same_padding
from monai.networks.layers.factories import Conv


class SimpleASPP(nn.Module):
    """
    A simplified version of the atrous spatial pyramid pooling (ASPP) module.

    The input is passed through several parallel (dilated) convolutions, the
    branch outputs are concatenated along the channel dimension, and a final
    kernel-size-one convolution block is applied to the concatenated result.

    Chen et al., Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation.
    https://arxiv.org/abs/1802.02611

    Wang et al., A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions
    from CT Images. https://ieeexplore.ieee.org/document/9109297
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        conv_out_channels: int,
        kernel_sizes: Sequence[int] = (1, 3, 3, 3),
        dilations: Sequence[int] = (1, 2, 4, 6),
        norm_type: tuple | str | None = "BATCH",
        acti_type: tuple | str | None = "LEAKYRELU",
        bias: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            conv_out_channels: number of output channels of each atrous conv.
                The final number of output channels is conv_out_channels * len(kernel_sizes).
            kernel_sizes: a non-empty sequence of convolutional kernel sizes, one per
                parallel branch. Defaults to (1, 3, 3, 3) for four (dilated) convolutions.
            dilations: a sequence of convolutional dilation parameters, one per parallel
                branch; must have the same length as ``kernel_sizes``.
                Defaults to (1, 2, 4, 6) for four (dilated) convolutions.
            norm_type: final kernel-size-one convolution normalization type.
                Defaults to batch norm.
            acti_type: final kernel-size-one convolution activation type.
                Defaults to leaky ReLU.
            bias: whether to have a bias term in the final kernel-size-one convolution block.
                Defaults to False. According to `Performance Tuning Guide
                <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
                if a conv layer is directly followed by a batch norm layer, bias should be False.
                The parallel branch convolutions are created without an explicit ``bias``
                argument, so this flag does not affect them.

        Raises:
            ValueError: When ``kernel_sizes`` length differs from ``dilations``.
            ValueError: When ``kernel_sizes`` and ``dilations`` are empty.

        See also:

            :py:class:`monai.networks.layers.Act`
            :py:class:`monai.networks.layers.Conv`
            :py:class:`monai.networks.layers.Norm`

        """
        super().__init__()
        if len(kernel_sizes) != len(dilations):
            raise ValueError(
                "kernel_sizes and dilations length must match, "
                f"got kernel_sizes={len(kernel_sizes)} dilations={len(dilations)}."
            )
        num_branches = len(kernel_sizes)
        if num_branches == 0:
            # With no branches, forward() could never succeed because torch.cat
            # rejects an empty list; fail here with a clear message instead.
            raise ValueError("kernel_sizes and dilations must not be empty.")

        # The convolution class depends only on spatial_dims, so resolve it once.
        conv_type = Conv[Conv.CONV, spatial_dims]
        self.convs = nn.ModuleList(
            [
                conv_type(
                    in_channels=in_channels,
                    out_channels=conv_out_channels,
                    kernel_size=k,
                    dilation=d,
                    padding=same_padding(k, d),
                )
                for k, d in zip(kernel_sizes, dilations)
            ]
        )

        out_channels = conv_out_channels * num_branches  # final conv. output channels
        self.conv_k1 = Convolution(
            spatial_dims=spatial_dims,
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=1,
            act=acti_type,
            norm=norm_type,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: in shape (batch, channel, spatial_1[, spatial_2, ...]).

        Returns:
            The output of the final kernel-size-one convolution block applied to the
            channel-wise concatenation of all parallel branch outputs.
        """
        # Every branch sees the same input; outputs are stacked along the channel axis.
        branch_outputs = [conv(x) for conv in self.convs]
        return self.conv_k1(torch.cat(branch_outputs, dim=1))