# Source code for monai.networks.nets.unet

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import torch.nn as nn

from monai.networks.blocks.convolutions import Convolution, ResidualUnit
from monai.networks.layers.factories import Norm, Act
from monai.networks.layers.simplelayers import SkipConnection
from monai.utils import export, alias


@export("monai.networks.nets")
@alias("Unet")
class UNet(nn.Module):
    """
    UNet assembled recursively from ``Convolution`` / ``ResidualUnit`` blocks.

    Every resolution level is ``nn.Sequential(down, SkipConnection(subblock), up)``, where
    ``subblock`` is the next level down, or the bottom layer for the deepest level.

    Args:
        dimensions: dimensionality argument forwarded to every ``Convolution`` / ``ResidualUnit``
            (presumably the number of spatial dimensions).
        in_channels: number of channels entering the first down-sampling layer.
        out_channels: number of channels produced by the final up-sampling layer.
        channels: feature channels per level, listed top to bottom. The last entry is the output
            channel count of the bottom layer. Must contain at least 2 values.
        strides: stride of each level's down-sampling layer. The same stride is used for the
            matching transposed convolution on the way up. At least ``len(channels) - 1`` values
            are required; any extra values are ignored and a warning is issued.
        kernel_size: kernel size of the down-path layers and of the residual units in the up path.
            Defaults to 3.
        up_kernel_size: kernel size of the transposed convolutions in the up path. Defaults to 3.
        num_res_units: if > 0, down-path layers are ``ResidualUnit`` s with this many subunits, and
            each up-path transposed convolution is followed by a single-subunit ``ResidualUnit``.
            If 0, plain ``Convolution`` layers are used throughout. Defaults to 0.
        act: activation type forwarded to every layer. Defaults to ``Act.PRELU``.
        norm: normalization type forwarded to every layer. Defaults to ``Norm.INSTANCE``.
        dropout: dropout setting forwarded to every layer. Defaults to 0.

    Raises:
        ValueError: when ``len(channels) < 2`` or ``len(strides) < len(channels) - 1``.
    """

    def __init__(
        self,
        dimensions,
        in_channels,
        out_channels,
        channels,
        strides,
        kernel_size=3,
        up_kernel_size=3,
        num_res_units=0,
        act=Act.PRELU,
        norm=Norm.INSTANCE,
        dropout=0,
    ):
        super().__init__()

        # Validate up front: otherwise the recursion in ``_create_block`` fails with a bare
        # IndexError (``channels[1]`` / ``strides[0]``) partway through construction.
        if len(channels) < 2:
            raise ValueError("the length of `channels` should be no less than 2.")
        delta = len(strides) - (len(channels) - 1)
        if delta < 0:
            raise ValueError("the length of `strides` should be no less than `len(channels) - 1`.")
        if delta > 0:
            # Only the first len(channels) - 1 strides are consumed; the bottom layer uses stride 1.
            warnings.warn(
                f"`len(strides) > len(channels) - 1`, the last {delta} values of strides will not be used.",
                stacklevel=2,
            )

        self.dimensions = dimensions
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.strides = strides
        self.kernel_size = kernel_size
        self.up_kernel_size = up_kernel_size
        self.num_res_units = num_res_units
        self.act = act
        self.norm = norm
        self.dropout = dropout

        self.model = self._create_block(in_channels, out_channels, self.channels, self.strides, True)

    def _create_block(self, inc, outc, channels, strides, is_top):
        """
        Build one UNet level and, recursively, every level beneath it.

        Args:
            inc: input channels of this level's down layer.
            outc: output channels of this level's up layer.
            channels: remaining per-level channel counts; ``channels[0]`` belongs to this level.
            strides: remaining per-level strides; ``strides[0]`` belongs to this level.
            is_top: True only for the outermost level.

        Returns:
            ``nn.Sequential(down, SkipConnection(subblock), up)``.
        """
        c = channels[0]
        s = strides[0]

        # Sub-modules are created in the order subblock -> down -> up. This keeps the
        # parameter-initialisation order stable across refactors.
        if len(channels) > 2:
            # Continue recursing: the inner level both takes and returns ``c`` channels.
            subblock = self._create_block(c, c, channels[1:], strides[1:], False)
            # Assumes SkipConnection concatenates its input (c) with the subblock output (c)
            # along the channel axis, giving 2 * c channels for the up layer.
            upc = c * 2
        else:
            # The next layer is the bottom, so stop recursing and use it as this level's subblock.
            subblock = self._get_bottom_layer(c, channels[1])
            upc = c + channels[1]

        down = self._get_down_layer(inc, c, s, is_top)  # layer in the down-sampling path
        up = self._get_up_layer(upc, outc, s, is_top)  # layer in the up-sampling path

        return nn.Sequential(down, SkipConnection(subblock), up)

    def _get_down_layer(self, in_channels, out_channels, strides, is_top):
        """
        Create one down-path layer.

        Returns a ``ResidualUnit`` with ``num_res_units`` subunits when ``num_res_units > 0``,
        otherwise a plain ``Convolution``. ``is_top`` is accepted for signature symmetry with
        ``_get_up_layer`` but is not used here.
        """
        if self.num_res_units > 0:
            return ResidualUnit(
                self.dimensions,
                in_channels,
                out_channels,
                strides,
                self.kernel_size,
                self.num_res_units,
                self.act,
                self.norm,
                self.dropout,
            )
        return Convolution(
            self.dimensions, in_channels, out_channels, strides, self.kernel_size, self.act, self.norm, self.dropout
        )

    def _get_bottom_layer(self, in_channels, out_channels):
        """Create the bottom layer: a down-path layer with stride 1."""
        return self._get_down_layer(in_channels, out_channels, 1, False)

    def _get_up_layer(self, in_channels, out_channels, strides, is_top):
        """
        Create one up-path layer.

        The layer is a transposed ``Convolution`` using ``up_kernel_size``. When
        ``num_res_units > 0``, it is followed by a single-subunit ``ResidualUnit`` and the pair
        is returned as an ``nn.Sequential``.
        """
        conv = Convolution(
            self.dimensions,
            in_channels,
            out_channels,
            strides,
            self.up_kernel_size,
            self.act,
            self.norm,
            self.dropout,
            # For the outermost level, presumably drops norm/activation on the network output
            # (only when no residual unit follows) — semantics of conv_only defined in Convolution.
            conv_only=is_top and self.num_res_units == 0,
            is_transposed=True,
        )

        if self.num_res_units > 0:
            ru = ResidualUnit(
                self.dimensions,
                out_channels,
                out_channels,
                1,
                self.kernel_size,
                1,
                self.act,
                self.norm,
                self.dropout,
                last_conv_only=is_top,
            )
            return nn.Sequential(conv, ru)
        return conv

    def forward(self, x):
        """Pass ``x`` through the assembled encoder/decoder and return the result."""
        return self.model(x)
# Module-level alternative spellings that refer to the same UNet class object.
Unet = UNet
unet = UNet