# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
from monai.networks.blocks.convolutions import Convolution, ResidualUnit
from monai.networks.layers.factories import Norm, Act
from monai.networks.layers.simplelayers import SkipConnection
from monai.utils import export
from monai.utils.aliases import alias
@export("monai.networks.nets")
@alias("Unet")
class UNet(nn.Module):
    """
    Recursively built UNet with optional residual units.

    The network is assembled as nested levels. Each level is
    ``nn.Sequential(down, SkipConnection(subblock), up)``, where ``subblock`` is the
    next-deeper level, or the bottom layer at the deepest level.

    Args:
        dimensions: dimensionality forwarded to every ``Convolution``/``ResidualUnit``
            (presumably the number of spatial dimensions -- confirm against those blocks).
        in_channels: number of input channels consumed by the top-level down layer.
        out_channels: number of output channels produced by the top-level up layer.
        channels: channel counts per level, ordered top to bottom. The last entry is the
            output width of the bottom layer. Must contain at least two entries.
        strides: stride used by each level's down layer and its matching transposed up
            layer. Must contain exactly ``len(channels) - 1`` entries.
        kernel_size: kernel size of the down-path layers, the bottom layer and the residual
            units appended in the up path.
        up_kernel_size: kernel size of the transposed convolutions in the up path.
        num_res_units: when > 0, down/bottom layers are ``ResidualUnit`` blocks with this many
            units, and each up layer appends a single-unit ``ResidualUnit``. When 0, plain
            ``Convolution`` layers are used.
        act: activation type forwarded to every layer.
        norm: normalization type forwarded to every layer.
        dropout: dropout forwarded to every layer.

    Raises:
        ValueError: if ``channels`` has fewer than two entries, or if
            ``len(channels) != len(strides) + 1``.
    """

    def __init__(self, dimensions, in_channels, out_channels, channels, strides, kernel_size=3, up_kernel_size=3,
                 num_res_units=0, act=Act.PRELU, norm=Norm.INSTANCE, dropout=0):
        super().__init__()

        # Validate with explicit exceptions rather than ``assert`` so the checks are not
        # stripped when Python runs with ``-O``.
        if len(channels) < 2:
            # One down/up level plus the bottom layer is the minimum. A single entry would
            # otherwise fail later with an IndexError on ``strides[0]``.
            raise ValueError("channels must contain at least 2 entries, got {}.".format(len(channels)))
        if len(channels) != len(strides) + 1:
            raise ValueError(
                "len(channels) must equal len(strides) + 1, got len(channels)={} and len(strides)={}.".format(
                    len(channels), len(strides)
                )
            )

        self.dimensions = dimensions
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.strides = strides
        self.kernel_size = kernel_size
        self.up_kernel_size = up_kernel_size
        self.num_res_units = num_res_units
        self.act = act
        self.norm = norm
        self.dropout = dropout

        def _create_block(inc, outc, channels, strides, is_top):
            """
            Builds the UNet structure from the bottom up by recursing down to the bottom block, then creating
            sequential blocks containing the downsample path, a skip connection around the previous block, and the
            upsample path.

            Args:
                inc: input channels of this level's down layer.
                outc: output channels of this level's up layer.
                channels: remaining channel counts, this level's count first.
                strides: remaining strides, this level's stride first.
                is_top: True only for the outermost level.

            Returns:
                ``nn.Sequential(down, SkipConnection(subblock), up)`` for this level.
            """
            c = channels[0]
            s = strides[0]

            if len(channels) > 2:
                # Continue recursion down. Inner levels map c -> c so their output width
                # matches this level's down-layer output.
                subblock = _create_block(c, c, channels[1:], strides[1:], False)
                # Up-layer input width: this level's c features plus the subblock's c outputs
                # (assumes SkipConnection concatenates along the channel axis).
                upc = c * 2
            else:
                # The next layer is the bottom, so stop recursion and create the bottom layer
                # as the subblock for this layer.
                subblock = self._get_bottom_layer(c, channels[1])
                upc = c + channels[1]

            down = self._get_down_layer(inc, c, s, is_top)  # create layer in downsampling path
            up = self._get_up_layer(upc, outc, s, is_top)  # create layer in upsampling path

            return nn.Sequential(down, SkipConnection(subblock), up)

        self.model = _create_block(in_channels, out_channels, self.channels, self.strides, True)

    def _get_down_layer(self, in_channels, out_channels, strides, is_top):
        """
        Build one down-path layer.

        Returns a ``ResidualUnit`` with ``self.num_res_units`` units when
        ``self.num_res_units > 0``; otherwise returns a single ``Convolution``.
        ``is_top`` is accepted for symmetry with ``_get_up_layer`` but is not used here.
        """
        if self.num_res_units > 0:
            return ResidualUnit(self.dimensions, in_channels, out_channels, strides, self.kernel_size,
                                self.num_res_units, self.act, self.norm, self.dropout)
        return Convolution(self.dimensions, in_channels, out_channels, strides, self.kernel_size, self.act, self.norm,
                           self.dropout)

    def _get_bottom_layer(self, in_channels, out_channels):
        """Build the bottom layer as a stride-1, non-top down layer."""
        return self._get_down_layer(in_channels, out_channels, 1, False)

    def _get_up_layer(self, in_channels, out_channels, strides, is_top):
        """
        Build one up-path layer.

        The layer is a transposed ``Convolution`` with the given stride. When
        ``self.num_res_units > 0``, it is followed by a single-unit ``ResidualUnit`` and
        both are wrapped in ``nn.Sequential``.

        At the top level, the final sub-layer is built with ``conv_only=True`` (no residual
        units) or ``last_conv_only=True`` (with residual units). Presumably this leaves the
        network output without norm/activation; confirm against ``Convolution``/``ResidualUnit``.
        """
        conv = Convolution(self.dimensions, in_channels, out_channels, strides, self.up_kernel_size, self.act,
                           self.norm, self.dropout, conv_only=is_top and self.num_res_units == 0, is_transposed=True)

        if self.num_res_units > 0:
            ru = ResidualUnit(self.dimensions, out_channels, out_channels, 1, self.kernel_size, 1, self.act,
                              self.norm, self.dropout, last_conv_only=is_top)
            return nn.Sequential(conv, ru)
        return conv

    def forward(self, x):
        """Run ``x`` through the assembled nested down/skip/up model and return the result."""
        x = self.model(x)
        return x
# Module-level aliases so the network can also be referenced as ``Unet`` or ``unet``.
Unet = UNet
unet = UNet