# Source code for monai.networks.nets.vitautoenc

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math
from collections.abc import Sequence

import torch
import torch.nn as nn

from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.networks.blocks.transformerblock import TransformerBlock
from monai.networks.layers import Conv
from monai.utils import deprecated_arg, ensure_tuple_rep, is_sqrt

__all__ = ["ViTAutoEnc"]


class ViTAutoEnc(nn.Module):
    """
    Vision Transformer (ViT), based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    Modified to also give same dimension outputs as the input size of the image: after the
    transformer encoder, two transposed convolutions upsample the patch grid back to the
    input spatial size.
    """

    @deprecated_arg(
        name="pos_embed", since="1.2", removed="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead."
    )
    def __init__(
        self,
        in_channels: int,
        img_size: Sequence[int] | int,
        patch_size: Sequence[int] | int,
        out_channels: int = 1,
        deconv_chns: int = 16,
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_layers: int = 12,
        num_heads: int = 12,
        pos_embed: str = "conv",
        proj_type: str = "conv",
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
        qkv_bias: bool = False,
        save_attn: bool = False,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels or the number of channels for input.
            img_size: dimension of input image.
            patch_size: dimension of patch size. Every entry must be a perfect square, because the
                decoder upsamples with two transposed convolutions whose kernel size and stride
                are both ``sqrt(patch_size)``.
            out_channels: number of output channels. Defaults to 1.
            deconv_chns: number of channels for the deconvolution layers. Defaults to 16.
            hidden_size: dimension of hidden layer. Defaults to 768.
            mlp_dim: dimension of feedforward layer. Defaults to 3072.
            num_layers: number of transformer blocks. Defaults to 12.
            num_heads: number of attention heads. Defaults to 12.
            pos_embed: deprecated alias of ``proj_type``; see the deprecation note below.
            proj_type: type of the patch embedding projection layer, forwarded to
                ``PatchEmbeddingBlock``. Defaults to "conv".
            dropout_rate: fraction of the input units to drop. Defaults to 0.0.
            spatial_dims: number of spatial dimensions. Defaults to 3.
            qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False.
            save_attn: to make accessible the attention in self attention block. Defaults to False.

        Raises:
            ValueError: when ``patch_size`` is not a square number.
            ValueError: when any entry of ``img_size`` is not divisible by the matching ``patch_size`` entry.

        .. deprecated:: 1.2
            ``pos_embed`` is deprecated in favor of ``proj_type`` and is scheduled for removal in 1.4.

        Examples::

            # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
            # It will provide an output of same size as that of the input
            >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), proj_type='conv')

            # for 3-channel with image size of (128,128,128), output will be same size as of input
            >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), proj_type='conv')

        """
        super().__init__()

        # The decoder upsamples by sqrt(p) twice, so p must be a perfect square to recover p exactly.
        if not is_sqrt(patch_size):
            raise ValueError(f"patch_size should be square number, got {patch_size}.")
        self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
        self.img_size = ensure_tuple_rep(img_size, spatial_dims)
        self.spatial_dims = spatial_dims
        # The image must tile exactly into patches along every spatial axis.
        for m, p in zip(self.img_size, self.patch_size):
            if m % p != 0:
                raise ValueError(f"img_size={img_size} should be divisible by patch_size={patch_size}.")

        self.patch_embedding = PatchEmbeddingBlock(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=patch_size,
            hidden_size=hidden_size,
            num_heads=num_heads,
            proj_type=proj_type,
            dropout_rate=dropout_rate,
            spatial_dims=self.spatial_dims,
        )
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
                for _ in range(num_layers)
            ]
        )
        self.norm = nn.LayerNorm(hidden_size)

        conv_trans = Conv[Conv.CONVTRANS, self.spatial_dims]
        # self.conv3d_transpose* is to be compatible with existing 3d model weights.
        # Each transposed conv upsamples by sqrt(p); applied twice, the total factor is p.
        up_kernel_size = [int(math.sqrt(i)) for i in self.patch_size]
        self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=up_kernel_size, stride=up_kernel_size)
        self.conv3d_transpose_1 = conv_trans(
            in_channels=deconv_chns, out_channels=out_channels, kernel_size=up_kernel_size, stride=up_kernel_size
        )

    def forward(self, x):
        """
        Args:
            x: input tensor laid out as ``[batch_size, channels, *spatial_size]``. Each spatial
                dimension must be divisible by the matching ``patch_size`` entry. The spatial
                dimensions do not need to be isotropic; the patch grid is computed per dimension.
                Note (review): the spatial size presumably also has to equal the ``img_size`` given
                at construction, because ``PatchEmbeddingBlock`` is built for that size. Confirm
                this against ``PatchEmbeddingBlock``.

        Returns:
            A tuple ``(output, hidden_states_out)``:

            - ``output`` has ``out_channels`` channels and is upsampled from the patch grid
              by ``patch_size`` per dimension.
            - ``hidden_states_out`` is a list with the output of every transformer block,
              taken before the final layer norm.
        """
        spatial_size = x.shape[2:]
        x = self.patch_embedding(x)
        hidden_states_out = []
        for blk in self.blocks:
            x = blk(x)
            hidden_states_out.append(x)
        x = self.norm(x)
        # Move the feature axis to dim 1 so the token sequence can be folded back into a spatial grid.
        x = x.transpose(1, 2)
        d = [s // p for s, p in zip(spatial_size, self.patch_size)]
        x = torch.reshape(x, [x.shape[0], x.shape[1], *d])
        x = self.conv3d_transpose(x)
        x = self.conv3d_transpose_1(x)
        return x, hidden_states_out