Source code for monai.networks.nets.generator

# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Union

import numpy as np
import torch
import torch.nn as nn

from monai.networks.blocks import Convolution, ResidualUnit
from monai.networks.layers.factories import Act, Norm
from monai.networks.layers.simplelayers import Reshape
from monai.utils import ensure_tuple, ensure_tuple_rep


class Generator(nn.Module):
    """
    Defines a simple generator network accepting a latent vector and through a sequence of convolution layers
    constructs an output tensor of greater size and high dimensionality. The method `_get_layer` is used to
    create each of these layers, override this method to define layers beyond the default Convolution or
    ResidualUnit layers.

    For example, a generator accepting a latent vector of shape (42,24) and producing an output volume of
    shape (1,64,64) can be constructed as:

        gen = Generator((42, 24), (64, 8, 8), (32, 16, 1), (2, 2, 2))
    """

    def __init__(
        self,
        latent_shape: Sequence[int],
        start_shape: Sequence[int],
        channels: Sequence[int],
        strides: Sequence[int],
        kernel_size: Union[Sequence[int], int] = 3,
        num_res_units: int = 2,
        act=Act.PRELU,
        norm=Norm.INSTANCE,
        dropout: Optional[float] = None,
        bias: bool = True,
    ) -> None:
        """
        Construct the generator network with the number of layers defined by `channels` and `strides`. In the
        forward pass a `nn.Linear` layer relates the input latent vector to a tensor of dimensions `start_shape`,
        this is then fed forward through the sequence of convolutional layers. The number of layers is defined by
        the length of `channels` and `strides` which must match, each layer having the number of output channels
        given in `channels` and an upsample factor given in `strides` (ie. a transpose convolution with that
        stride size).

        Args:
            latent_shape: tuple of integers stating the dimension of the input latent vector (minus batch dimension)
            start_shape: tuple of integers stating the dimension of the tensor to pass to convolution subnetwork,
                the first value is the channel count and the remaining values are the spatial dimensions
            channels: tuple of integers stating the output channels of each convolutional layer
            strides: tuple of integers stating the stride (upscale factor) of each convolutional layer
            kernel_size: integer or tuple of integers stating size of convolutional kernels
            num_res_units: integer stating number of convolutions in residual units, 0 means no residual units
            act: name or type defining activation layers
            norm: name or type defining normalization layers
            dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout
            bias: boolean stating if convolution layers should have a bias component

        Raises:
            ValueError: if `channels` and `strides` do not have the same length.
        """
        super().__init__()

        # Normalise the shape/sequence arguments once and use the normalised values everywhere below, so the
        # stored attributes and the constructed layers are always derived from the same data.
        start_shape = ensure_tuple(start_shape)
        self.in_channels, *self.start_shape = start_shape
        self.dimensions = len(self.start_shape)

        self.latent_shape = ensure_tuple(latent_shape)
        self.channels = ensure_tuple(channels)
        self.strides = ensure_tuple(strides)

        # `zip` would silently drop the surplus entries of the longer sequence, and with fewer strides than
        # channels no layer would ever be flagged as the last one, so reject the mismatch explicitly.
        if len(self.channels) != len(self.strides):
            raise ValueError(
                "`channels` and `strides` must have the same length, "
                f"got {len(self.channels)} and {len(self.strides)}."
            )

        self.kernel_size = ensure_tuple_rep(kernel_size, self.dimensions)
        self.num_res_units = num_res_units
        self.act = act
        self.norm = norm
        self.dropout = dropout
        self.bias = bias

        # latent vector -> flat vector -> linear projection -> tensor of shape `start_shape`
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(int(np.prod(self.latent_shape)), int(np.prod(start_shape)))
        self.reshape = Reshape(*start_shape)
        self.conv = nn.Sequential()

        echannel = self.in_channels

        # transform tensor of shape `start_shape` into output shape through transposed convolutions and
        # residual units
        num_layers = len(self.channels)
        for i, (c, s) in enumerate(zip(self.channels, self.strides)):
            is_last = i == num_layers - 1
            layer = self._get_layer(echannel, c, s, is_last)
            self.conv.add_module("layer_%i" % i, layer)
            echannel = c  # output channels of this layer feed the next one

    def _get_layer(
        self, in_channels: int, out_channels: int, strides: int, is_last: bool
    ) -> Union[Convolution, nn.Sequential]:
        """
        Returns a layer accepting inputs with `in_channels` number of channels and producing outputs of
        `out_channels` number of channels. The `strides` indicates upsampling factor, ie. transpose convolutional
        stride. If `is_last` is True this is the final layer and is not expected to include activation and
        normalization layers.

        Args:
            in_channels: number of channels of the layer's input
            out_channels: number of channels of the layer's output
            strides: stride of the transposed convolution
            is_last: True if this is the final layer of the network

        Returns:
            The transposed `Convolution` alone when `num_res_units` is 0, otherwise an `nn.Sequential` of the
            transposed `Convolution` followed by a `ResidualUnit`.
        """
        layer: Union[Convolution, nn.Sequential]

        # The transposed convolution is built conv-only when it is the last layer, or whenever a ResidualUnit
        # is appended after it.
        layer = Convolution(
            in_channels=in_channels,
            strides=strides,
            is_transposed=True,
            conv_only=is_last or self.num_res_units > 0,
            dimensions=self.dimensions,
            out_channels=out_channels,
            kernel_size=self.kernel_size,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
            bias=self.bias,
        )

        if self.num_res_units > 0:
            # The residual unit keeps the channel count unchanged; for the last layer its final convolution
            # is requested as conv-only as well.
            ru = ResidualUnit(
                in_channels=out_channels,
                subunits=self.num_res_units,
                last_conv_only=is_last,
                dimensions=self.dimensions,
                out_channels=out_channels,
                kernel_size=self.kernel_size,
                act=self.act,
                norm=self.norm,
                dropout=self.dropout,
                bias=self.bias,
            )

            layer = nn.Sequential(layer, ru)

        return layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Generate an output tensor from the latent input `x`.

        The input is flattened, projected by the linear layer, passed through the `Reshape` layer built from
        `start_shape`, and then fed through the sequence of layers in `self.conv`.

        Args:
            x: batch of latent inputs (presumably of shape `(batch, *latent_shape)`; confirm against callers)

        Returns:
            The output of the final layer of `self.conv`.
        """
        x = self.flatten(x)
        x = self.linear(x)
        x = self.reshape(x)
        x = self.conv(x)
        return x