Source code for monai.networks.nets.varautoencoder

# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from monai.networks.layers.convutils import calculate_out_shape, same_padding
from monai.networks.layers.factories import Act, Norm
from monai.networks.nets import AutoEncoder


class VarAutoEncoder(AutoEncoder):
    """Variational Autoencoder based on the paper - https://arxiv.org/abs/1312.6114

    The convolutional ``encode``/``intermediate``/``decode`` stages are
    inherited from ``AutoEncoder``. This class adds three linear layers around
    them: ``mu`` and ``logvar`` map the flattened encoder output to the latent
    distribution parameters, and ``decodeL`` maps a latent sample back to the
    flattened feature size expected by the decoder.

    Args:
        dimensions: forwarded unchanged to ``AutoEncoder.__init__``.
        in_shape: input shape as ``(channels, *spatial_sizes)`` with no batch
            dimension; the first entry is used as the number of input channels
            and the remainder as the spatial size.
        out_channels: forwarded unchanged to ``AutoEncoder.__init__``.
        latent_size: number of features in the latent vector ``z``.
        channels: forwarded to ``AutoEncoder.__init__``.
        strides: one stride per encoder step; also used here to compute the
            spatial size of the encoder output.
        kernel_size: forwarded to ``AutoEncoder.__init__``.
        up_kernel_size: forwarded to ``AutoEncoder.__init__``.
        num_res_units: forwarded to ``AutoEncoder.__init__``.
        inter_channels: forwarded to ``AutoEncoder.__init__``.
        inter_dilations: forwarded to ``AutoEncoder.__init__``.
        num_inter_units: forwarded to ``AutoEncoder.__init__``.
        act: forwarded to ``AutoEncoder.__init__``.
        norm: forwarded to ``AutoEncoder.__init__``.
        dropout: forwarded to ``AutoEncoder.__init__``.
    """

    def __init__(
        self,
        dimensions: int,
        in_shape: Sequence[int],
        out_channels: int,
        latent_size: int,
        channels: Sequence[int],
        strides: Sequence[int],
        kernel_size: Union[Sequence[int], int] = 3,
        up_kernel_size: Union[Sequence[int], int] = 3,
        num_res_units: int = 0,
        inter_channels: Optional[list] = None,
        inter_dilations: Optional[list] = None,
        num_inter_units: int = 2,
        act: Optional[Union[Tuple, str]] = Act.PRELU,
        norm: Union[Tuple, str] = Norm.INSTANCE,
        dropout: Optional[Union[Tuple, str, float]] = None,
    ) -> None:
        # NOTE(review): these plain attributes are assigned ahead of
        # super().__init__(), exactly as before; confirm before reordering.
        self.in_channels, *self.in_shape = in_shape
        self.latent_size = latent_size
        # Builtin ``int`` replaces the removed ``np.int`` alias (same dtype).
        self.final_size = np.asarray(self.in_shape, dtype=int)

        super().__init__(
            dimensions,
            self.in_channels,
            out_channels,
            channels,
            strides,
            kernel_size,
            up_kernel_size,
            num_res_units,
            inter_channels,
            inter_dilations,
            num_inter_units,
            act,
            norm,
            dropout,
        )

        # ``self.kernel_size`` and ``self.encoded_channels`` are not assigned in
        # this class; they are expected to be provided by AutoEncoder.__init__.
        padding = same_padding(self.kernel_size)

        # Track how each strided step shrinks the spatial size so the linear
        # layers can be sized for the flattened encoder output.
        for s in strides:
            self.final_size = calculate_out_shape(self.final_size, self.kernel_size, s, padding)

        # ``np.prod`` replaces ``np.product``, which was removed in NumPy 2.0.
        linear_size = int(np.prod(self.final_size)) * self.encoded_channels
        self.mu = nn.Linear(linear_size, self.latent_size)
        self.logvar = nn.Linear(linear_size, self.latent_size)
        self.decodeL = nn.Linear(self.latent_size, linear_size)

    def encode_forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``x`` and return the latent distribution parameters.

        Args:
            x: input batch; passed through ``self.encode`` and
                ``self.intermediate`` and then flattened per sample.

        Returns:
            A ``(mu, logvar)`` tuple produced by the ``mu`` and ``logvar``
            linear layers from the flattened features.
        """
        x = self.encode(x)
        x = self.intermediate(x)
        x = x.view(x.shape[0], -1)
        mu = self.mu(x)
        logvar = self.logvar(x)
        return mu, logvar

    def decode_forward(self, z: torch.Tensor, use_sigmoid: bool = True) -> torch.Tensor:
        """Decode a latent sample ``z`` back through the decoder.

        Args:
            z: latent batch whose last dimension is ``latent_size``.
            use_sigmoid: if True, apply ``torch.sigmoid`` to the decoder output.

        Returns:
            The output of ``self.decode``, optionally passed through a sigmoid.
        """
        x = F.relu(self.decodeL(z))
        # ``decodeL`` was sized as prod(final_size) * encoded_channels in
        # __init__, so reshape with that same channel count. The previous
        # ``self.channels[-1]`` only works when the two values coincide.
        x = x.view(x.shape[0], self.encoded_channels, *self.final_size)
        x = self.decode(x)
        if use_sigmoid:
            x = torch.sigmoid(x)
        return x

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Combine ``mu`` and ``logvar`` into a latent value.

        Computes ``std = exp(0.5 * logvar)``. While ``self.training`` is True
        the result is ``mu + eps * std`` with ``eps`` drawn by
        ``torch.randn_like``; otherwise no noise is drawn and the result is
        ``mu + std``.

        Args:
            mu: mean tensor.
            logvar: log-variance tensor, same shape as ``mu``.

        Returns:
            The latent tensor ``z``.
        """
        std = torch.exp(0.5 * logvar)

        if self.training:  # multiply random noise with std only during training
            std = torch.randn_like(std).mul(std)

        # Out-of-place add: in eval mode ``std`` is the direct output of
        # ``torch.exp``, which autograd keeps for the backward pass, so an
        # in-place ``add_`` would invalidate it if gradients are requested.
        return std + mu

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the full encode / reparameterize / decode pass.

        Args:
            x: input batch.

        Returns:
            A tuple ``(reconstruction, mu, logvar, z)`` where ``reconstruction``
            is ``decode_forward(z)`` with its default sigmoid applied.
        """
        mu, logvar = self.encode_forward(x)
        z = self.reparameterize(mu, logvar)
        return self.decode_forward(z), mu, logvar, z