Source code for monai.networks.nets.varautoencoder

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from monai.networks.layers.convutils import calculate_out_shape, same_padding
from monai.networks.layers.factories import Act, Norm
from monai.networks.nets import AutoEncoder

__all__ = ["VarAutoEncoder"]


class VarAutoEncoder(AutoEncoder):
    """
    Variational Autoencoder based on the paper - https://arxiv.org/abs/1312.6114

    Args:
        spatial_dims: number of spatial dimensions.
        in_shape: shape of input data starting with channel dimension.
        out_channels: number of output channels.
        latent_size: size of the latent variable.
        channels: sequence of channels. Top block first. The length of `channels` should be no less than 2.
        strides: sequence of convolution strides. Every stride contributes one downsampling step when the
            encoded spatial size is computed; the example below passes one stride per entry of `channels`.
        kernel_size: convolution kernel size, the value(s) should be odd. If sequence,
            its length should equal to dimensions. Defaults to 3.
        up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If sequence,
            its length should equal to dimensions. Defaults to 3.
        num_res_units: number of residual units. Defaults to 0.
        inter_channels: sequence of channels defining the blocks in the intermediate layer between encode and decode.
        inter_dilations: defines the dilation value for each block of the intermediate layer. Defaults to 1.
        num_inter_units: number of residual units for each block of the intermediate layer. Defaults to 2.
        act: activation type and arguments. Defaults to PReLU.
        norm: feature normalization type and arguments. Defaults to instance norm.
        dropout: dropout ratio. Defaults to no dropout.
        bias: whether to have a bias term in convolution blocks. Defaults to True.
            According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
            if a conv layer is directly followed by a batch norm layer, bias should be False.
        use_sigmoid: whether to use the sigmoid function on final output. Defaults to True.

    Examples::

        from monai.networks.nets import VarAutoEncoder

        # 3 layer network accepting images with dimensions (1, 32, 32) and using a latent vector with 2 values
        model = VarAutoEncoder(
            spatial_dims=2,
            in_shape=(1, 32, 32),  # channel dimension first, then the image spatial shape
            out_channels=1,
            latent_size=2,
            channels=(16, 32, 64),
            strides=(1, 2, 2),
        )

    see also:

        - Variational autoencoder network with MedNIST Dataset
          https://github.com/Project-MONAI/tutorials/blob/master/modules/varautoencoder_mednist.ipynb
    """

    def __init__(
        self,
        spatial_dims: int,
        in_shape: Sequence[int],
        out_channels: int,
        latent_size: int,
        channels: Sequence[int],
        strides: Sequence[int],
        kernel_size: Sequence[int] | int = 3,
        up_kernel_size: Sequence[int] | int = 3,
        num_res_units: int = 0,
        inter_channels: list | None = None,
        inter_dilations: list | None = None,
        num_inter_units: int = 2,
        act: tuple | str | None = Act.PRELU,
        norm: tuple | str = Norm.INSTANCE,
        dropout: tuple | str | float | None = None,
        bias: bool = True,
        use_sigmoid: bool = True,
    ) -> None:
        # The first entry of `in_shape` is the channel count, the rest is the spatial shape.
        self.in_channels, *self.in_shape = in_shape
        self.use_sigmoid = use_sigmoid
        self.latent_size = latent_size
        self.final_size = np.asarray(self.in_shape, dtype=int)

        super().__init__(
            spatial_dims,
            self.in_channels,
            out_channels,
            channels,
            strides,
            kernel_size,
            up_kernel_size,
            num_res_units,
            inter_channels,
            inter_dilations,
            num_inter_units,
            act,
            norm,
            dropout,
            bias,
        )

        padding = same_padding(self.kernel_size)

        # Track how each strided convolution shrinks the spatial size so the
        # flattened encoder output can be connected to the linear layers.
        for s in strides:
            self.final_size = calculate_out_shape(self.final_size, self.kernel_size, s, padding)  # type: ignore

        # `np.prod` replaces the `np.product` alias, which was removed in NumPy 2.0.
        linear_size = int(np.prod(self.final_size)) * self.encoded_channels
        self.mu = nn.Linear(linear_size, self.latent_size)
        self.logvar = nn.Linear(linear_size, self.latent_size)
        self.decodeL = nn.Linear(self.latent_size, linear_size)

    def encode_forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the encoder and intermediate layers, flatten the result and project it
        to the latent distribution parameters.

        Args:
            x: input batch.

        Returns:
            The pair ``(mu, logvar)`` produced by the two linear heads.
        """
        x = self.encode(x)
        x = self.intermediate(x)
        x = x.view(x.shape[0], -1)
        mu = self.mu(x)
        logvar = self.logvar(x)
        return mu, logvar

    def decode_forward(self, z: torch.Tensor, use_sigmoid: bool = True) -> torch.Tensor:
        """
        Map a latent sample back to the output space.

        Args:
            z: latent batch with ``latent_size`` features per item.
            use_sigmoid: whether to apply a sigmoid to the decoder output.

        Returns:
            The decoded tensor.
        """
        x = F.relu(self.decodeL(z))
        # `decodeL` was sized with `self.encoded_channels`, so the same channel count is
        # used here; it can differ from `self.channels[-1]` and would then break the view.
        x = x.view(x.shape[0], self.encoded_channels, *self.final_size)
        x = self.decode(x)
        if use_sigmoid:
            x = torch.sigmoid(x)
        return x

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Combine ``mu`` with the standard deviation derived from ``logvar``.

        Args:
            mu: first output of `encode_forward`.
            logvar: second output of `encode_forward`.

        Returns:
            ``mu + eps * std`` with random normal ``eps`` while training, ``mu + std`` otherwise.
        """
        std = torch.exp(0.5 * logvar)

        if self.training:  # multiply random noise with std only during training
            std = torch.randn_like(std).mul(std)

        # Out-of-place addition: outside training `std` is the direct output of `exp`,
        # which autograd keeps for the backward pass, so it must not be modified in place.
        return std + mu

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encode ``x``, draw a latent sample and decode it.

        Args:
            x: input batch.

        Returns:
            The tuple ``(decoded output, mu, logvar, z)``.
        """
        mu, logvar = self.encode_forward(x)
        z = self.reparameterize(mu, logvar)
        return self.decode_forward(z, self.use_sigmoid), mu, logvar, z