from typing import Optional, Sequence, Union

import numpy as np
import torch
import torch.nn as nn

from monai.networks.blocks import Convolution, ResidualUnit
from monai.networks.layers.factories import Act, Norm
from monai.networks.layers.simplelayers import Reshape
from monai.utils import ensure_tuple, ensure_tuple_rep

class Generator(nn.Module):
    """
    Defines a simple generator network accepting a latent vector and, through a sequence of convolution layers,
    constructing an output tensor of greater size and higher dimensionality. The method `_get_layer` is used to
    create each of these layers; override this method to define layers beyond the default Convolution or
    ResidualUnit layers.

    For example, a generator accepting a latent vector of shape (42,24) and producing an output volume of
    shape (1,64,64) can be constructed as:

        gen = Generator((42, 24), (64, 8, 8), (32, 16, 1), (2, 2, 2))
    """

    def __init__(
        self,
        latent_shape: Sequence[int],
        start_shape: Sequence[int],
        channels: Sequence[int],
        strides: Sequence[int],
        kernel_size: Union[Sequence[int], int] = 3,
        num_res_units: int = 2,
        act=Act.PRELU,
        norm=Norm.INSTANCE,
        dropout: Optional[float] = None,
        bias: bool = True,
    ) -> None:
        """
        Construct the generator network with the number of layers defined by channels and strides. In the
        forward pass a nn.Linear layer relates the input latent vector to a tensor of dimensions start_shape,
        this is then fed forward through the sequence of convolutional layers. The number of layers is defined by
        the length of channels and strides which must match, each layer having the number of output channels
        given in channels and an upsample factor given in strides (ie. a transpose convolution with that stride
        size).

        Args:
            latent_shape: tuple of integers stating the dimension of the input latent vector (minus batch dimension)
            start_shape: tuple of integers stating the dimension of the tensor to pass to convolution subnetwork;
                the first entry is the channel count, the remaining entries are the spatial dimensions
            channels: tuple of integers stating the output channels of each convolutional layer
            strides: tuple of integers stating the stride (upscale factor) of each convolutional layer
            kernel_size: integer or tuple of integers stating size of convolutional kernels
            num_res_units: integer stating number of convolutions in residual units, 0 means no residual units
            act: name or type defining activation layers
            norm: name or type defining normalization layers
            dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout
            bias: boolean stating if convolution layers should have a bias component

        Raises:
            ValueError: if ``channels`` and ``strides`` do not have the same length.
        """
        super().__init__()

        # first entry of start_shape is the channel count, the rest are spatial dims
        self.in_channels, *self.start_shape = ensure_tuple(start_shape)
        self.dimensions = len(self.start_shape)

        self.latent_shape = ensure_tuple(latent_shape)
        self.channels = ensure_tuple(channels)
        self.strides = ensure_tuple(strides)
        if len(self.channels) != len(self.strides):
            # zip() would otherwise silently drop the trailing layers
            raise ValueError(
                f"channels and strides must have the same length, got {len(self.channels)} and {len(self.strides)}."
            )
        self.kernel_size = ensure_tuple_rep(kernel_size, self.dimensions)
        self.num_res_units = num_res_units
        self.act = act
        self.norm = norm
        self.dropout = dropout
        self.bias = bias

        self.flatten = nn.Flatten()
        self.linear = nn.Linear(int(np.prod(self.latent_shape)), int(np.prod(start_shape)))
        self.reshape = Reshape(*start_shape)
        self.conv = nn.Sequential()

        echannel = self.in_channels

        # transform tensor of shape `start_shape` into output shape through transposed convolutions and residual units
        for i, (c, s) in enumerate(zip(self.channels, self.strides)):
            is_last = i == len(self.channels) - 1
            layer = self._get_layer(echannel, c, s, is_last)
            # register the layer so forward() applies it and its parameters are tracked by the module
            self.conv.add_module(f"layer_{i}", layer)
            echannel = c

    def _get_layer(
        self, in_channels: int, out_channels: int, strides: int, is_last: bool
    ) -> Union[Convolution, nn.Sequential]:
        """
        Returns a layer accepting inputs with in_channels number of channels and producing outputs of out_channels
        number of channels. The strides indicates upsampling factor, ie. transpose convolutional stride. If is_last
        is True this is the final layer and is not expected to include activation and normalization layers.

        When ``self.num_res_units > 0`` the transposed convolution is built conv-only and followed by a
        ResidualUnit with ``self.num_res_units`` subunits; otherwise the transposed convolution alone is returned.
        """
        layer: Union[Convolution, nn.Sequential]

        layer = Convolution(
            in_channels=in_channels,
            strides=strides,
            is_transposed=True,
            # activation/normalization are omitted on the final layer, or deferred to the residual unit below
            conv_only=is_last or self.num_res_units > 0,
            dimensions=self.dimensions,
            out_channels=out_channels,
            kernel_size=self.kernel_size,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
            bias=self.bias,
        )

        if self.num_res_units > 0:
            ru = ResidualUnit(
                in_channels=out_channels,
                subunits=self.num_res_units,
                last_conv_only=is_last,
                dimensions=self.dimensions,
                out_channels=out_channels,
                kernel_size=self.kernel_size,
                act=self.act,
                norm=self.norm,
                dropout=self.dropout,
                bias=self.bias,
            )

            layer = nn.Sequential(layer, ru)

        return layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Map a batch of latent vectors to generated output tensors.

        The input is flattened from dim 1 onward (nn.Flatten default), projected by the linear layer to
        ``prod(start_shape)`` features, reshaped to ``start_shape``, then passed through the upsampling
        convolutional layers.

        Args:
            x: latent batch; presumably of shape (batch, *latent_shape) — the flattened per-sample size
                must equal ``prod(latent_shape)`` for the linear layer to accept it.
        """
        x = self.flatten(x)
        x = self.linear(x)
        x = self.reshape(x)
        x = self.conv(x)
        return x
