Source code for monai.networks.nets.basic_unet

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Union

import torch
import torch.nn as nn

from monai.networks.blocks import Convolution, UpSample
from monai.networks.layers.factories import Conv, Pool
from monai.utils import deprecated_arg, ensure_tuple_rep

__all__ = ["BasicUnet", "Basicunet", "basicunet", "BasicUNet"]


class TwoConv(nn.Sequential):
    """two convolutions."""

    def __init__(
        self,
        spatial_dims: int,
        in_chns: int,
        out_chns: int,
        act: Union[str, tuple],
        norm: Union[str, tuple],
        bias: bool,
        dropout: Union[float, tuple] = 0.0,
    ):
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_chns: number of input channels.
            out_chns: number of output channels.
            act: activation type and arguments.
            norm: feature normalization type and arguments.
            bias: whether to have a bias term in convolution blocks.
            dropout: dropout ratio. Defaults to no dropout.
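
        Example (a minimal sketch, argument values are illustrative)::

            # two padded 3x3 convolutions: the spatial size is preserved, only the channel count changes
            >>> block = TwoConv(spatial_dims=2, in_chns=16, out_chns=32, act="LeakyReLU", norm="instance", bias=True)
            >>> block(torch.randn(1, 16, 64, 64)).shape
            torch.Size([1, 32, 64, 64])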

        """
        super().__init__()

        conv_0 = Convolution(spatial_dims, in_chns, out_chns, act=act, norm=norm, dropout=dropout, bias=bias, padding=1)
        conv_1 = Convolution(
            spatial_dims, out_chns, out_chns, act=act, norm=norm, dropout=dropout, bias=bias, padding=1
        )
        self.add_module("conv_0", conv_0)
        self.add_module("conv_1", conv_1)


class Down(nn.Sequential):
    """maxpooling downsampling and two convolutions."""

    def __init__(
        self,
        spatial_dims: int,
        in_chns: int,
        out_chns: int,
        act: Union[str, tuple],
        norm: Union[str, tuple],
        bias: bool,
        dropout: Union[float, tuple] = 0.0,
    ):
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_chns: number of input channels.
            out_chns: number of output channels.
            act: activation type and arguments.
            norm: feature normalization type and arguments.
            bias: whether to have a bias term in convolution blocks.
            dropout: dropout ratio. Defaults to no dropout.
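
        Example (a minimal sketch, argument values are illustrative)::

            # the 2x max pooling halves each spatial edge before the two convolutions
            >>> down = Down(spatial_dims=2, in_chns=32, out_chns=64, act="LeakyReLU", norm="instance", bias=True)
            >>> down(torch.randn(1, 32, 64, 64)).shape
            torch.Size([1, 64, 32, 32])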

        """
        super().__init__()
        max_pooling = Pool["MAX", spatial_dims](kernel_size=2)
        convs = TwoConv(spatial_dims, in_chns, out_chns, act, norm, bias, dropout)
        self.add_module("max_pooling", max_pooling)
        self.add_module("convs", convs)


class UpCat(nn.Module):
    """upsampling, concatenation with the encoder feature map, two convolutions"""

    def __init__(
        self,
        spatial_dims: int,
        in_chns: int,
        cat_chns: int,
        out_chns: int,
        act: Union[str, tuple],
        norm: Union[str, tuple],
        bias: bool,
        dropout: Union[float, tuple] = 0.0,
        upsample: str = "deconv",
        pre_conv: Optional[Union[nn.Module, str]] = "default",
        interp_mode: str = "linear",
        align_corners: Optional[bool] = True,
        halves: bool = True,
        is_pad: bool = True,
    ):
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_chns: number of input channels to be upsampled.
            cat_chns: number of channels from the encoder.
            out_chns: number of output channels.
            act: activation type and arguments.
            norm: feature normalization type and arguments.
            bias: whether to have a bias term in convolution blocks.
            dropout: dropout ratio. Defaults to no dropout.
            upsample: upsampling mode, available options are
                ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
            pre_conv: a conv block applied before upsampling.
                Only used in the "nontrainable" or "pixelshuffle" mode.
            interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                Only used in the "nontrainable" mode.
            align_corners: set the align_corners parameter for upsample. Defaults to True.
                Only used in the "nontrainable" mode.
            halves: whether to halve the number of channels during upsampling.
                This parameter has no effect in ``nontrainable`` mode when ``pre_conv`` is ``None``.
            is_pad: whether to pad upsampling features to fit features from encoder. Defaults to True.
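
        Example (a minimal sketch, argument values are illustrative)::

            # with the default "deconv" mode and halves=True, the 64 decoder channels are halved
            # to 32 while upsampling, concatenated with the 32 encoder channels, and then reduced
            # to out_chns by the two convolutions
            >>> upcat = UpCat(spatial_dims=2, in_chns=64, cat_chns=32, out_chns=32, act="LeakyReLU", norm="instance", bias=True)
            >>> x = torch.randn(1, 64, 16, 16)    # decoder features to be upsampled
            >>> x_e = torch.randn(1, 32, 32, 32)  # skip features from the encoder
            >>> upcat(x, x_e).shape
            torch.Size([1, 32, 32, 32])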

        """
        super().__init__()
        if upsample == "nontrainable" and pre_conv is None:
            up_chns = in_chns
        else:
            up_chns = in_chns // 2 if halves else in_chns
        self.upsample = UpSample(
            spatial_dims,
            in_chns,
            up_chns,
            2,
            mode=upsample,
            pre_conv=pre_conv,
            interp_mode=interp_mode,
            align_corners=align_corners,
        )
        self.convs = TwoConv(spatial_dims, cat_chns + up_chns, out_chns, act, norm, bias, dropout)
        self.is_pad = is_pad

    def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]):
        """

        Args:
            x: features to be upsampled.
            x_e: features from the encoder.
        """
        x_0 = self.upsample(x)

        if x_e is not None:
            if self.is_pad:
                # handling spatial shapes due to the 2x maxpooling with odd edge lengths.
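                # sp follows torch.nn.functional.pad's convention: (before, after) pairs starting
                # from the last spatial dimension; only the "after" side of a mismatched dimension
                # is padded by one. E.g. an encoder edge of 65 becomes 32 after the next 2x pooling
                # and 64 after upsampling, so x_0 is padded from 64 back to 65 to match x_e.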
                dimensions = len(x.shape) - 2
                sp = [0] * (dimensions * 2)
                for i in range(dimensions):
                    if x_e.shape[-i - 1] != x_0.shape[-i - 1]:
                        sp[i * 2 + 1] = 1
                x_0 = torch.nn.functional.pad(x_0, sp, "replicate")
            x = self.convs(torch.cat([x_e, x_0], dim=1))  # input channels: (cat_chns + up_chns)
        else:
            x = self.convs(x_0)

        return x


class BasicUNet(nn.Module):
    @deprecated_arg(
        name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead."
    )
    def __init__(
        self,
        spatial_dims: int = 3,
        in_channels: int = 1,
        out_channels: int = 2,
        features: Sequence[int] = (32, 32, 64, 128, 256, 32),
        act: Union[str, tuple] = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}),
        norm: Union[str, tuple] = ("instance", {"affine": True}),
        bias: bool = True,
        dropout: Union[float, tuple] = 0.0,
        upsample: str = "deconv",
        dimensions: Optional[int] = None,
    ):
        """
        A UNet implementation with 1D/2D/3D supports.

        Based on:

            Falk et al. "U-Net – Deep Learning for Cell Counting, Detection, and
            Morphometry". Nature Methods 16, 67–70 (2019), DOI:
            http://dx.doi.org/10.1038/s41592-018-0261-2

        Args:
            spatial_dims: number of spatial dimensions. Defaults to 3 for spatial 3D inputs.
            in_channels: number of input channels. Defaults to 1.
            out_channels: number of output channels. Defaults to 2.
            features: six integers as numbers of features.
                Defaults to ``(32, 32, 64, 128, 256, 32)``,

                - the first five values correspond to the five-level encoder feature sizes.
                - the last value corresponds to the feature size after the last upsampling.

            act: activation type and arguments. Defaults to LeakyReLU.
            norm: feature normalization type and arguments. Defaults to instance norm.
            bias: whether to have a bias term in convolution blocks. Defaults to True.
                According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
                if a conv layer is directly followed by a batch norm layer, bias should be False.
            dropout: dropout ratio. Defaults to no dropout.
            upsample: upsampling mode, available options are
                ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.

        .. deprecated:: 0.6.0
            ``dimensions`` is deprecated, use ``spatial_dims`` instead.

        Examples::

            # for spatial 2D
            >>> net = BasicUNet(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128))

            # for spatial 2D, with group norm
            >>> net = BasicUNet(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4}))

            # for spatial 3D
            >>> net = BasicUNet(spatial_dims=3, features=(32, 32, 64, 128, 256, 32))

        See Also

            - :py:class:`monai.networks.nets.DynUNet`
            - :py:class:`monai.networks.nets.UNet`

        """
        super().__init__()
        if dimensions is not None:
            spatial_dims = dimensions

        fea = ensure_tuple_rep(features, 6)
        print(f"BasicUNet features: {fea}.")

        self.conv_0 = TwoConv(spatial_dims, in_channels, features[0], act, norm, bias, dropout)
        self.down_1 = Down(spatial_dims, fea[0], fea[1], act, norm, bias, dropout)
        self.down_2 = Down(spatial_dims, fea[1], fea[2], act, norm, bias, dropout)
        self.down_3 = Down(spatial_dims, fea[2], fea[3], act, norm, bias, dropout)
        self.down_4 = Down(spatial_dims, fea[3], fea[4], act, norm, bias, dropout)

        self.upcat_4 = UpCat(spatial_dims, fea[4], fea[3], fea[3], act, norm, bias, dropout, upsample)
        self.upcat_3 = UpCat(spatial_dims, fea[3], fea[2], fea[2], act, norm, bias, dropout, upsample)
        self.upcat_2 = UpCat(spatial_dims, fea[2], fea[1], fea[1], act, norm, bias, dropout, upsample)
        self.upcat_1 = UpCat(spatial_dims, fea[1], fea[0], fea[5], act, norm, bias, dropout, upsample, halves=False)

        self.final_conv = Conv["conv", spatial_dims](fea[5], out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: input should have spatially N dimensions
                ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N-1])``, N is defined by `spatial_dims`.
                It is recommended to have ``dim_n % 16 == 0`` to ensure all maxpooling inputs have
                even edge lengths.

        Returns:
            A torch Tensor of "raw" predictions in shape
            ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N-1])``.
        """
        x0 = self.conv_0(x)

        x1 = self.down_1(x0)
        x2 = self.down_2(x1)
        x3 = self.down_3(x2)
        x4 = self.down_4(x3)

        u4 = self.upcat_4(x4, x3)
        u3 = self.upcat_3(u4, x2)
        u2 = self.upcat_2(u3, x1)
        u1 = self.upcat_1(u2, x0)

        logits = self.final_conv(u1)
        return logits

BasicUnet = Basicunet = basicunet = BasicUNet
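
# A minimal forward-pass sketch (argument values are illustrative):
#
#     net = BasicUNet(spatial_dims=2, in_channels=1, out_channels=2)
#     logits = net(torch.randn(1, 1, 64, 64))  # (1, 2, 64, 64) raw logits
#
# Spatial edge lengths that are multiples of 16 keep every 2x max pooling input even, so the
# replicate padding in UpCat is never triggered on such inputs.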