Source code for monai.networks.nets.attentionunet

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence

import torch
import torch.nn as nn

from monai.networks.blocks.convolutions import Convolution
from monai.networks.layers.factories import Norm

__all__ = ["AttentionUnet"]


class ConvBlock(nn.Module):
    """Two chained convolutions, each followed by batch normalization, dropout and ReLU ("NDA" ordering)."""

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        strides: int = 1,
        dropout=0.0,
    ):
        super().__init__()
        layers = [
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                strides=strides,
                padding=None,
                adn_ordering="NDA",
                act="relu",
                norm=Norm.BATCH,
                dropout=dropout,
            ),
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                strides=1,
                padding=None,
                adn_ordering="NDA",
                act="relu",
                norm=Norm.BATCH,
                dropout=dropout,
            ),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_c: torch.Tensor = self.conv(x)
        return x_c
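
# Illustrative example (not part of the module): with the default stride of 1
# the block preserves spatial size, while ``strides=2`` halves each spatial
# dimension in the first convolution.
#
#     >>> block = ConvBlock(spatial_dims=2, in_channels=1, out_channels=8, strides=2)
#     >>> block(torch.randn(1, 1, 32, 32)).shape
#     torch.Size([1, 8, 16, 16])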


class UpConv(nn.Module):
    """A single transposed convolution that upsamples the input, followed by batch normalization, dropout and ReLU."""

    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size=3, strides=2, dropout=0.0):
        super().__init__()
        self.up = Convolution(
            spatial_dims,
            in_channels,
            out_channels,
            strides=strides,
            kernel_size=kernel_size,
            act="relu",
            adn_ordering="NDA",
            norm=Norm.BATCH,
            dropout=dropout,
            is_transposed=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_u: torch.Tensor = self.up(x)
        return x_u
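
# Illustrative example (not part of the module): the default stride of 2
# doubles each spatial dimension, mirroring a stride-2 downsampling step.
#
#     >>> up = UpConv(spatial_dims=2, in_channels=8, out_channels=4)
#     >>> up(torch.randn(1, 8, 16, 16)).shape
#     torch.Size([1, 4, 32, 32])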


class AttentionBlock(nn.Module):
    """
    Attention gate from Oktay et al.: computes additive attention coefficients
    ``sigmoid(psi(relu(W_g(g) + W_x(x))))`` and returns the gated skip connection ``x * psi``.
    """

    def __init__(self, spatial_dims: int, f_int: int, f_g: int, f_l: int, dropout=0.0):
        super().__init__()
        self.W_g = nn.Sequential(
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=f_g,
                out_channels=f_int,
                kernel_size=1,
                strides=1,
                padding=0,
                dropout=dropout,
                conv_only=True,
            ),
            Norm[Norm.BATCH, spatial_dims](f_int),
        )

        self.W_x = nn.Sequential(
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=f_l,
                out_channels=f_int,
                kernel_size=1,
                strides=1,
                padding=0,
                dropout=dropout,
                conv_only=True,
            ),
            Norm[Norm.BATCH, spatial_dims](f_int),
        )

        self.psi = nn.Sequential(
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=f_int,
                out_channels=1,
                kernel_size=1,
                strides=1,
                padding=0,
                dropout=dropout,
                conv_only=True,
            ),
            Norm[Norm.BATCH, spatial_dims](1),
            nn.Sigmoid(),
        )

        self.relu = nn.ReLU()

    def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi: torch.Tensor = self.relu(g1 + x1)
        psi = self.psi(psi)

        return x * psi
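
# Illustrative example (not part of the module): the gate rescales the skip
# connection ``x`` voxel-wise by a single-channel attention map, so the output
# keeps the shape of ``x``.
#
#     >>> gate = AttentionBlock(spatial_dims=2, f_int=8, f_g=16, f_l=16)
#     >>> g = torch.randn(1, 16, 32, 32)  # gating signal from the coarser level
#     >>> x = torch.randn(1, 16, 32, 32)  # skip connection from the encoder
#     >>> gate(g=g, x=x).shape
#     torch.Size([1, 16, 32, 32])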


class AttentionLayer(nn.Module):
    """
    One decoder level: upsamples the submodule output, uses it to gate the skip
    connection ``x`` through the attention block, then merges the gated skip and
    the upsampled features by concatenation and convolution.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        submodule: nn.Module,
        up_kernel_size=3,
        strides=2,
        dropout=0.0,
    ):
        super().__init__()
        self.attention = AttentionBlock(
            spatial_dims=spatial_dims, f_g=in_channels, f_l=in_channels, f_int=in_channels // 2
        )
        self.upconv = UpConv(
            spatial_dims=spatial_dims,
            in_channels=out_channels,
            out_channels=in_channels,
            strides=strides,
            kernel_size=up_kernel_size,
        )
        self.merge = Convolution(
            spatial_dims=spatial_dims, in_channels=2 * in_channels, out_channels=in_channels, dropout=dropout
        )
        self.submodule = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        fromlower = self.upconv(self.submodule(x))
        att = self.attention(g=fromlower, x=x)
        att_m: torch.Tensor = self.merge(torch.cat((att, fromlower), dim=1))
        return att_m
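
# Illustrative example (not part of the module): one decoder level wrapping a
# bottom ``ConvBlock``. The submodule halves the resolution and doubles the
# channels; the layer upsamples the result, gates the skip connection, and
# merges both back to ``in_channels``.
#
#     >>> layer = AttentionLayer(
#     ...     spatial_dims=2,
#     ...     in_channels=16,
#     ...     out_channels=32,
#     ...     submodule=ConvBlock(spatial_dims=2, in_channels=16, out_channels=32, strides=2),
#     ... )
#     >>> layer(torch.randn(1, 16, 64, 64)).shape
#     torch.Size([1, 16, 64, 64])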


class AttentionUnet(nn.Module):
    """
    Attention U-Net based on Oktay et al.,
    "Attention U-Net: Learning Where to Look for the Pancreas"
    https://arxiv.org/abs/1804.03999

    Args:
        spatial_dims: number of spatial dimensions of the input image.
        in_channels: number of input channels.
        out_channels: number of output classes.
        channels (Sequence[int]): sequence of channels. Top block first. The length of `channels` should be no less than 2.
        strides (Sequence[int]): strides to use for the down-sampling convolutions. Its length should equal `len(channels) - 1`.
        kernel_size: convolution kernel size.
        up_kernel_size: convolution kernel size for transposed convolution layers.
        dropout: dropout ratio. Defaults to no dropout.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        channels: Sequence[int],
        strides: Sequence[int],
        kernel_size: Sequence[int] | int = 3,
        up_kernel_size: Sequence[int] | int = 3,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.dimensions = spatial_dims
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.strides = strides
        self.kernel_size = kernel_size
        self.dropout = dropout

        head = ConvBlock(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=channels[0], dropout=dropout)
        reduce_channels = Convolution(
            spatial_dims=spatial_dims,
            in_channels=channels[0],
            out_channels=out_channels,
            kernel_size=1,
            strides=1,
            padding=0,
            conv_only=True,
        )
        self.up_kernel_size = up_kernel_size

        def _create_block(channels: Sequence[int], strides: Sequence[int]) -> nn.Module:
            if len(channels) > 2:
                subblock = _create_block(channels[1:], strides[1:])
                return AttentionLayer(
                    spatial_dims=spatial_dims,
                    in_channels=channels[0],
                    out_channels=channels[1],
                    submodule=nn.Sequential(
                        ConvBlock(
                            spatial_dims=spatial_dims,
                            in_channels=channels[0],
                            out_channels=channels[1],
                            strides=strides[0],
                            dropout=self.dropout,
                        ),
                        subblock,
                    ),
                    up_kernel_size=self.up_kernel_size,
                    strides=strides[0],
                    dropout=dropout,
                )
            else:
                # the next layer is the bottom so stop recursion,
                # create the bottom layer as the subblock for this layer
                return self._get_bottom_layer(channels[0], channels[1], strides[0])

        encdec = _create_block(self.channels, self.strides)
        self.model = nn.Sequential(head, encdec, reduce_channels)

    def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int) -> nn.Module:
        return AttentionLayer(
            spatial_dims=self.dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            submodule=ConvBlock(
                spatial_dims=self.dimensions,
                in_channels=in_channels,
                out_channels=out_channels,
                strides=strides,
                dropout=self.dropout,
            ),
            up_kernel_size=self.up_kernel_size,
            strides=strides,
            dropout=self.dropout,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_m: torch.Tensor = self.model(x)
        return x_m
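

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the module). Builds a 2D
    # Attention U-Net with three resolution levels; the channel and stride
    # values here are arbitrary example choices.
    net = AttentionUnet(
        spatial_dims=2,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64),
        strides=(2, 2),
    )
    net.eval()
    with torch.no_grad():
        logits = net(torch.randn(1, 1, 96, 96))
    print(logits.shape)  # torch.Size([1, 2, 96, 96])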