# Source code for monai.networks.nets.segresnet

# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks.segresnet_block import ResBlock, get_conv_layer, get_norm_layer, get_upsample_layer
from monai.networks.layers.factories import Act, Dropout
from monai.utils import UpsampleMode


class SegResNet(nn.Module):
    """
    SegResNet based on `3D MRI brain tumor segmentation using autoencoder regularization
    <https://arxiv.org/pdf/1810.11654.pdf>`_.
    The module does not include the variational autoencoder (VAE).
    The model supports 2D or 3D inputs.

    Args:
        spatial_dims: spatial dimension of the input data. Defaults to 3.
        init_filters: number of output channels for the initial convolution layer. Defaults to 8.
        in_channels: number of input channels for the network. Defaults to 1.
        out_channels: number of output channels for the network. Defaults to 2.
        dropout_prob: probability of an element to be zeroed. Defaults to ``None``.
        norm_name: feature normalization type; this module only supports group norm,
            batch norm and instance norm. Defaults to ``group``.
        num_groups: number of groups to separate the channels into. Defaults to 8.
        use_conv_final: whether to add a final convolution block to the output. Defaults to ``True``.
        blocks_down: number of downsample blocks in each layer. Defaults to ``[1, 2, 2, 4]``.
        blocks_up: number of upsample blocks in each layer. Defaults to ``[1, 1, 1]``.
        upsample_mode: [``"deconv"``, ``"nontrainable"``, ``"pixelshuffle"``]
            The mode of upsampling manipulations.
            Using the ``nontrainable`` mode cannot guarantee the model's reproducibility.
            Defaults to ``nontrainable``.

            - ``deconv``, uses transposed convolution layers.
            - ``nontrainable``, uses non-trainable `linear` interpolation.
            - ``pixelshuffle``, uses :py:class:`monai.networks.blocks.SubpixelUpsample`.
    """

    def __init__(
        self,
        spatial_dims: int = 3,
        init_filters: int = 8,
        in_channels: int = 1,
        out_channels: int = 2,
        dropout_prob: Optional[float] = None,
        norm_name: str = "group",
        num_groups: int = 8,
        use_conv_final: bool = True,
        blocks_down: tuple = (1, 2, 2, 4),
        blocks_up: tuple = (1, 1, 1),
        upsample_mode: Union[UpsampleMode, str] = UpsampleMode.NONTRAINABLE,
    ):
        super().__init__()

        if spatial_dims not in (2, 3):
            raise AssertionError("spatial_dims can only be 2 or 3.")

        self.spatial_dims = spatial_dims
        self.init_filters = init_filters
        self.blocks_down = blocks_down
        self.blocks_up = blocks_up
        self.dropout_prob = dropout_prob
        self.norm_name = norm_name
        self.num_groups = num_groups
        self.upsample_mode = UpsampleMode(upsample_mode)
        self.use_conv_final = use_conv_final
        self.convInit = get_conv_layer(spatial_dims, in_channels, init_filters)
        self.down_layers = self._make_down_layers()
        self.up_layers, self.up_samples = self._make_up_layers()
        self.relu = Act[Act.RELU](inplace=True)
        self.conv_final = self._make_final_conv(out_channels)

        if dropout_prob is not None:
            self.dropout = Dropout[Dropout.DROPOUT, spatial_dims](dropout_prob)

    def _make_down_layers(self):
        down_layers = nn.ModuleList()
        blocks_down, spatial_dims, filters, norm_name, num_groups = (
            self.blocks_down,
            self.spatial_dims,
            self.init_filters,
            self.norm_name,
            self.num_groups,
        )
        for i in range(len(blocks_down)):
            # each stage doubles the channels; all stages except the first halve the spatial size
            layer_in_channels = filters * 2 ** i
            pre_conv = (
                get_conv_layer(spatial_dims, layer_in_channels // 2, layer_in_channels, stride=2)
                if i > 0
                else nn.Identity()
            )
            down_layer = nn.Sequential(
                pre_conv,
                *[
                    ResBlock(spatial_dims, layer_in_channels, norm_name=norm_name, num_groups=num_groups)
                    for _ in range(blocks_down[i])
                ],
            )
            down_layers.append(down_layer)
        return down_layers

    def _make_up_layers(self):
        up_layers, up_samples = nn.ModuleList(), nn.ModuleList()
        upsample_mode, blocks_up, spatial_dims, filters, norm_name, num_groups = (
            self.upsample_mode,
            self.blocks_up,
            self.spatial_dims,
            self.init_filters,
            self.norm_name,
            self.num_groups,
        )
        n_up = len(blocks_up)
        for i in range(n_up):
            # each stage halves the channels with a 1x1 convolution, then upsamples by 2
            sample_in_channels = filters * 2 ** (n_up - i)
            up_layers.append(
                nn.Sequential(
                    *[
                        ResBlock(spatial_dims, sample_in_channels // 2, norm_name=norm_name, num_groups=num_groups)
                        for _ in range(blocks_up[i])
                    ]
                )
            )
            up_samples.append(
                nn.Sequential(
                    *[
                        get_conv_layer(spatial_dims, sample_in_channels, sample_in_channels // 2, kernel_size=1),
                        get_upsample_layer(spatial_dims, sample_in_channels // 2, upsample_mode=upsample_mode),
                    ]
                )
            )
        return up_layers, up_samples

    def _make_final_conv(self, out_channels: int):
        return nn.Sequential(
            get_norm_layer(self.spatial_dims, self.init_filters, norm_name=self.norm_name, num_groups=self.num_groups),
            self.relu,
            get_conv_layer(self.spatial_dims, self.init_filters, out_channels=out_channels, kernel_size=1, bias=True),
        )
    def forward(self, x):
        x = self.convInit(x)
        if self.dropout_prob is not None:
            x = self.dropout(x)

        # encoder: keep each stage's output for the decoder skip connections
        down_x = []
        for down in self.down_layers:
            x = down(x)
            down_x.append(x)
        down_x.reverse()

        # decoder: upsample, then add the matching encoder feature map
        for i, (up, upl) in enumerate(zip(self.up_samples, self.up_layers)):
            x = up(x) + down_x[i + 1]
            x = upl(x)

        if self.use_conv_final:
            x = self.conv_final(x)

        return x
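
# A minimal usage sketch: instantiate the network with its defaults and run a
# forward pass. The helper name and the 64 ** 3 input size are illustrative
# assumptions; spatial sizes should be divisible by 2 ** (len(blocks_down) - 1)
# so that the upsampled decoder features align with the encoder skip connections.
def _demo_segresnet():
    net = SegResNet(spatial_dims=3, in_channels=1, out_channels=2)
    x = torch.rand(1, 1, 64, 64, 64)  # (batch, channel, D, H, W)
    y = net(x)
    return y.shape  # torch.Size([1, 2, 64, 64, 64])
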
class SegResNetVAE(SegResNet):
    """
    SegResNetVAE based on `3D MRI brain tumor segmentation using autoencoder regularization
    <https://arxiv.org/pdf/1810.11654.pdf>`_.
    The module contains the variational autoencoder (VAE).
    The model supports 2D or 3D inputs.

    Args:
        input_image_size: the size of images to input into the network. It is used to
            determine the ``in_features`` of the fc layers in the VAE, so the spatial
            size must be known at construction time.
        vae_estimate_std: whether to estimate the standard deviations in VAE. Defaults to ``False``.
        vae_default_std: the default standard deviation used when ``vae_estimate_std``
            is ``False``. Defaults to 0.3.
        vae_nz: number of latent variables in VAE. Defaults to 256, where 128 represent
            the mean and 128 represent the standard deviation.
        spatial_dims: spatial dimension of the input data. Defaults to 3.
        init_filters: number of output channels for the initial convolution layer. Defaults to 8.
        in_channels: number of input channels for the network. Defaults to 1.
        out_channels: number of output channels for the network. Defaults to 2.
        dropout_prob: probability of an element to be zeroed. Defaults to ``None``.
        norm_name: feature normalization type; this module only supports group norm,
            batch norm and instance norm. Defaults to ``group``.
        num_groups: number of groups to separate the channels into. Defaults to 8.
        use_conv_final: whether to add a final convolution block to the output. Defaults to ``True``.
        blocks_down: number of downsample blocks in each layer. Defaults to ``[1, 2, 2, 4]``.
        blocks_up: number of upsample blocks in each layer. Defaults to ``[1, 1, 1]``.
        upsample_mode: [``"deconv"``, ``"nontrainable"``, ``"pixelshuffle"``]
            The mode of upsampling manipulations.
            Using the ``nontrainable`` mode cannot guarantee the model's reproducibility.
            Defaults to ``nontrainable``.

            - ``deconv``, uses transposed convolution layers.
            - ``nontrainable``, uses non-trainable `linear` interpolation.
            - ``pixelshuffle``, uses :py:class:`monai.networks.blocks.SubpixelUpsample`.
    """

    def __init__(
        self,
        input_image_size: Sequence[int],
        vae_estimate_std: bool = False,
        vae_default_std: float = 0.3,
        vae_nz: int = 256,
        spatial_dims: int = 3,
        init_filters: int = 8,
        in_channels: int = 1,
        out_channels: int = 2,
        dropout_prob: Optional[float] = None,
        norm_name: str = "group",
        num_groups: int = 8,
        use_conv_final: bool = True,
        blocks_down: tuple = (1, 2, 2, 4),
        blocks_up: tuple = (1, 1, 1),
        upsample_mode: Union[UpsampleMode, str] = "nontrainable",
    ):
        super(SegResNetVAE, self).__init__(
            spatial_dims=spatial_dims,
            init_filters=init_filters,
            in_channels=in_channels,
            out_channels=out_channels,
            dropout_prob=dropout_prob,
            norm_name=norm_name,
            num_groups=num_groups,
            use_conv_final=use_conv_final,
            blocks_down=blocks_down,
            blocks_up=blocks_up,
            upsample_mode=upsample_mode,
        )
        self.input_image_size = input_image_size
        self.smallest_filters = 16

        zoom = 2 ** (len(self.blocks_down) - 1)
        self.fc_insize = [s // (2 * zoom) for s in self.input_image_size]

        self.vae_estimate_std = vae_estimate_std
        self.vae_default_std = vae_default_std
        self.vae_nz = vae_nz
        self._prepare_vae_modules()
        self.vae_conv_final = self._make_final_conv(in_channels)

    def _prepare_vae_modules(self):
        zoom = 2 ** (len(self.blocks_down) - 1)
        v_filters = self.init_filters * zoom
        total_elements = int(self.smallest_filters * np.prod(self.fc_insize))

        self.vae_down = nn.Sequential(
            get_norm_layer(self.spatial_dims, v_filters, norm_name=self.norm_name, num_groups=self.num_groups),
            self.relu,
            get_conv_layer(self.spatial_dims, v_filters, self.smallest_filters, stride=2, bias=True),
            get_norm_layer(
                self.spatial_dims, self.smallest_filters, norm_name=self.norm_name, num_groups=self.num_groups
            ),
            self.relu,
        )
        self.vae_fc1 = nn.Linear(total_elements, self.vae_nz)
        self.vae_fc2 = nn.Linear(total_elements, self.vae_nz)
        self.vae_fc3 = nn.Linear(self.vae_nz, total_elements)

        self.vae_fc_up_sample = nn.Sequential(
            get_conv_layer(self.spatial_dims, self.smallest_filters, v_filters, kernel_size=1),
            get_upsample_layer(self.spatial_dims, v_filters, upsample_mode=self.upsample_mode),
            get_norm_layer(self.spatial_dims, v_filters, norm_name=self.norm_name, num_groups=self.num_groups),
            self.relu,
        )

    def _get_vae_loss(self, net_input: torch.Tensor, vae_input: torch.Tensor):
        """
        Args:
            net_input: the original input of the network.
            vae_input: the input of VAE module, which is also the output of the network's encoder.
        """
        x_vae = self.vae_down(vae_input)
        x_vae = x_vae.view(-1, self.vae_fc1.in_features)
        z_mean = self.vae_fc1(x_vae)

        z_mean_rand = torch.randn_like(z_mean)
        z_mean_rand.requires_grad_(False)

        if self.vae_estimate_std:
            z_sigma = self.vae_fc2(x_vae)
            z_sigma = F.softplus(z_sigma)
            # KL-style regularizer pulling the latent distribution towards a standard normal
            vae_reg_loss = 0.5 * torch.mean(z_mean ** 2 + z_sigma ** 2 - torch.log(1e-8 + z_sigma ** 2) - 1)

            x_vae = z_mean + z_sigma * z_mean_rand
        else:
            z_sigma = self.vae_default_std
            vae_reg_loss = torch.mean(z_mean ** 2)

            x_vae = z_mean + z_sigma * z_mean_rand

        x_vae = self.vae_fc3(x_vae)
        x_vae = self.relu(x_vae)
        x_vae = x_vae.view([-1, self.smallest_filters] + self.fc_insize)
        x_vae = self.vae_fc_up_sample(x_vae)

        # decode without skip connections to reconstruct the network input
        for up, upl in zip(self.up_samples, self.up_layers):
            x_vae = up(x_vae)
            x_vae = upl(x_vae)

        x_vae = self.vae_conv_final(x_vae)
        vae_mse_loss = F.mse_loss(net_input, x_vae)
        vae_loss = vae_reg_loss + vae_mse_loss
        return vae_loss
    def forward(self, x):
        net_input = x
        x = self.convInit(x)
        if self.dropout_prob is not None:
            x = self.dropout(x)

        down_x = []
        for down in self.down_layers:
            x = down(x)
            down_x.append(x)
        down_x.reverse()

        # the deepest encoder feature map feeds the VAE branch during training
        vae_input = x
        for i, (up, upl) in enumerate(zip(self.up_samples, self.up_layers)):
            x = up(x) + down_x[i + 1]
            x = upl(x)

        if self.use_conv_final:
            x = self.conv_final(x)

        if self.training:
            vae_loss = self._get_vae_loss(net_input, vae_input)
            return x, vae_loss

        return x, None
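
# A minimal usage sketch: ``input_image_size`` is required so the fully connected
# VAE layers can be sized at construction time. In training mode ``forward``
# returns ``(prediction, vae_loss)``; in eval mode it returns ``(prediction, None)``.
# The helper name and the 64 ** 3 input size are illustrative assumptions.
def _demo_segresnet_vae():
    net = SegResNetVAE(input_image_size=(64, 64, 64), spatial_dims=3, in_channels=1, out_channels=2)
    x = torch.rand(1, 1, 64, 64, 64)
    net.train()
    seg, vae_loss = net(x)  # vae_loss = latent regularization + reconstruction MSE
    net.eval()
    seg_eval, no_loss = net(x)  # no_loss is None outside training
    return seg.shape, vae_loss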