# Source code for monai.networks.blocks.aspp

# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
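# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.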
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Sequence

import torch
import torch.nn as nn

from monai.networks.blocks.convolutions import Convolution
from monai.networks.layers import same_padding
from monai.networks.layers.factories import Act, Conv, Norm

class SimpleASPP(nn.Module):
    """
    A simplified version of the atrous spatial pyramid pooling (ASPP) module.

    The input is fed through several parallel (optionally dilated) convolutions.
    Their outputs are concatenated along the channel dimension and then fused by a
    final kernel-size-one ``Convolution`` with normalization and activation.

    Chen et al., Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation.

    Wang et al., A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions
    from CT Images.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        conv_out_channels: int,
        kernel_sizes: Sequence[int] = (1, 3, 3, 3),
        dilations: Sequence[int] = (1, 2, 4, 6),
        norm_type=Norm.BATCH,
        acti_type=Act.LEAKYRELU,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            conv_out_channels: number of output channels of each atrous conv.
                The final number of output channels is conv_out_channels * len(kernel_sizes).
            kernel_sizes: a sequence of convolutional kernel sizes, one per parallel branch.
                Must have the same length as ``dilations``.
                Defaults to (1, 3, 3, 3) for four (dilated) convolutions.
            dilations: a sequence of convolutional dilation parameters, one per parallel branch.
                Must have the same length as ``kernel_sizes``.
                Defaults to (1, 2, 4, 6) for four (dilated) convolutions.
            norm_type: final kernel-size-one convolution normalization type.
                Defaults to batch norm.
            acti_type: final kernel-size-one convolution activation type.
                Defaults to leaky ReLU.

        Raises:
            ValueError: When ``kernel_sizes`` length differs from ``dilations``.

        See also:

            :py:class:`monai.networks.layers.Act`
            :py:class:`monai.networks.layers.Conv`
            :py:class:`monai.networks.layers.Norm`

        """
        super().__init__()
        if len(kernel_sizes) != len(dilations):
            raise ValueError(
                "kernel_sizes and dilations length must match, "
                f"got kernel_sizes={len(kernel_sizes)} dilations={len(dilations)}."
            )
        # One padding per branch. Using `same_padding` presumably keeps every branch's
        # spatial size equal, which the channel-wise `torch.cat` in `forward` requires.
        pads = tuple(same_padding(k, d) for k, d in zip(kernel_sizes, dilations))

        self.convs = nn.ModuleList()
        for k, d, p in zip(kernel_sizes, dilations, pads):
            _conv = Conv[Conv.CONV, spatial_dims](
                in_channels=in_channels, out_channels=conv_out_channels, kernel_size=k, dilation=d, padding=p
            )
            self.convs.append(_conv)

        out_channels = conv_out_channels * len(pads)  # final conv. output channels
        # Kernel-size-one fusion of the concatenated branch outputs; keeps the channel count.
        self.conv_k1 = Convolution(
            dimensions=spatial_dims,
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=1,
            act=acti_type,
            norm=norm_type,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: in shape (batch, channel, spatial_1[, spatial_2, ...]).

        Returns:
            The fused feature map with ``conv_out_channels * len(kernel_sizes)`` channels.
        """
        # Run every parallel branch on the same input and stack the results on the channel axis.
        x_out = torch.cat([conv(x) for conv in self.convs], dim=1)
        x_out = self.conv_k1(x_out)
        return x_out