Source code for monai.networks.nets.generator

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from import Sequence

import numpy as np
import torch
import torch.nn as nn

from monai.networks.blocks import Convolution, ResidualUnit
from monai.networks.layers.factories import Act, Norm
from monai.networks.layers.simplelayers import Reshape
from monai.utils import ensure_tuple, ensure_tuple_rep

[docs] class Generator(nn.Module): """ Defines a simple generator network accepting a latent vector and through a sequence of convolution layers constructs an output tensor of greater size and high dimensionality. The method `_get_layer` is used to create each of these layers, override this method to define layers beyond the default :py:class:`monai.networks.blocks.Convolution` or :py:class:`monai.networks.blocks.ResidualUnit` layers. The layers are constructed using the values in the `channels` and `strides` arguments, the number being defined by the length of these (which must match). Input is first passed through a :py:class:`torch.nn.Linear` layer to convert the input vector to an image tensor with dimensions `start_shape`. This passes through the convolution layers and is progressively upsampled if the `strides` values are greater than 1 using transpose convolutions. The size of the final output is defined by the `start_shape` dimension and the amount of upsampling done through strides. In the default definition the size of the output's spatial dimensions will be that of `start_shape` multiplied by the product of `strides`, thus the example network below upsamples an starting size of (64, 8, 8) to (1, 64, 64) since its `strides` are (2, 2, 2). Args: latent_shape: tuple of integers stating the dimension of the input latent vector (minus batch dimension) start_shape: tuple of integers stating the dimension of the tensor to pass to convolution subnetwork channels: tuple of integers stating the output channels of each convolutional layer strides: tuple of integers stating the stride (upscale factor) of each convolutional layer kernel_size: integer or tuple of integers stating size of convolutional kernels num_res_units: integer stating number of convolutions in residual units, 0 means no residual units act: name or type defining activation layers norm: name or type defining normalization layers dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout bias: boolean stating if convolution layers should have a bias component Examples:: # 3 layers, latent input vector of shape (42, 24), output volume of shape (1, 64, 64) net = Generator((42, 24), (64, 8, 8), (32, 16, 1), (2, 2, 2)) """ def __init__( self, latent_shape: Sequence[int], start_shape: Sequence[int], channels: Sequence[int], strides: Sequence[int], kernel_size: Sequence[int] | int = 3, num_res_units: int = 2, act=Act.PRELU, norm=Norm.INSTANCE, dropout: float | None = None, bias: bool = True, ) -> None: super().__init__() self.in_channels, *self.start_shape = ensure_tuple(start_shape) self.dimensions = len(self.start_shape) self.latent_shape = ensure_tuple(latent_shape) self.channels = ensure_tuple(channels) self.strides = ensure_tuple(strides) self.kernel_size = ensure_tuple_rep(kernel_size, self.dimensions) self.num_res_units = num_res_units self.act = act self.norm = norm self.dropout = dropout self.bias = bias self.flatten = nn.Flatten() self.linear = nn.Linear(int(, int( self.reshape = Reshape(*start_shape) self.conv = nn.Sequential() echannel = self.in_channels # transform tensor of shape `start_shape' into output shape through transposed convolutions and residual units for i, (c, s) in enumerate(zip(channels, strides)): is_last = i == len(channels) - 1 layer = self._get_layer(echannel, c, s, is_last) self.conv.add_module("layer_%i" % i, layer) echannel = c def _get_layer( self, in_channels: int, out_channels: int, strides: int, is_last: bool ) -> Convolution | nn.Sequential: """ Returns a layer accepting inputs with `in_channels` number of channels and producing outputs of `out_channels` number of channels. The `strides` indicates upsampling factor, ie. transpose convolutional stride. If `is_last` is True this is the final layer and is not expected to include activation and normalization layers. """ layer: Convolution | nn.Sequential layer = Convolution( in_channels=in_channels, strides=strides, is_transposed=True, conv_only=is_last or self.num_res_units > 0, spatial_dims=self.dimensions, out_channels=out_channels, kernel_size=self.kernel_size, act=self.act, norm=self.norm, dropout=self.dropout, bias=self.bias, ) if self.num_res_units > 0: ru = ResidualUnit( in_channels=out_channels, subunits=self.num_res_units, last_conv_only=is_last, spatial_dims=self.dimensions, out_channels=out_channels, kernel_size=self.kernel_size, act=self.act, norm=self.norm, dropout=self.dropout, bias=self.bias, ) layer = nn.Sequential(layer, ru) return layer
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.flatten(x) x = self.linear(x) x = self.reshape(x) x = self.conv(x) return x