# Source code for monai.networks.nets.transchex

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math
import os
import shutil
import tarfile
import tempfile
from collections.abc import Sequence

import torch
from torch import nn

from monai.utils import optional_import

# Names resolved from the ``transformers`` package through ``optional_import``.
# NOTE(review): presumably this defers the hard dependency on ``transformers`` until one of
# these names is actually used -- confirm against ``monai.utils.optional_import``.
# NOTE(review): unlike the bindings below, this one is not indexed with ``[0]``, so it holds
# whatever ``optional_import`` returns as a whole -- confirm before treating it as the module.
transformers = optional_import("transformers")
# Each of the following keeps only element ``[0]`` of the value returned by ``optional_import``.
load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert")[0]
cached_path = optional_import("transformers.file_utils", name="cached_path")[0]
BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0]
BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0]

# Public names exported by this module.
__all__ = ["BertPreTrainedModel", "BertAttention", "BertOutput", "BertMixedLayer", "Pooler", "MultiModal", "Transchex"]


class BertPreTrainedModel(nn.Module):
    """Module to load BERT pre-trained weights.
    Based on:
    LXMERT
    https://github.com/airsplay/lxmert
    BERT (pytorch-transformer)
    https://github.com/huggingface/transformers
    """

    def __init__(self, *inputs, **kwargs) -> None:
        # Positional/keyword arguments are accepted but not used here.
        super().__init__()

    def init_bert_weights(self, module):
        """Initialise the parameters of a single sub-module in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a normal distribution with
        standard deviation ``self.config.initializer_range``; ``LayerNorm`` is reset to
        weight 1 / bias 0; the bias of any ``nn.Linear`` that has one is zeroed.

        Args:
            module: the sub-module to initialise.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, torch.nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @staticmethod
    def _is_within_directory(directory, target):
        """Return True when ``target`` resolves to ``directory`` itself or a path below it."""
        abs_directory = os.path.abspath(directory)
        abs_target = os.path.abspath(target)
        try:
            # ``commonpath`` compares whole path components. ``commonprefix`` compares characters,
            # so it would wrongly accept a sibling such as "/tmp/foo_evil" for "/tmp/foo".
            return os.path.commonpath([abs_directory, abs_target]) == abs_directory
        except ValueError:
            # Raised when the two paths cannot share a common path (e.g. different drives).
            return False

    @classmethod
    def _safe_extract(cls, tar, path=".", members=None, *, numeric_owner=False):
        """Extract ``tar`` into ``path`` after rejecting members that would land outside of it.

        Raises:
            Exception: if any archive member resolves to a location outside ``path``.
        """
        for member in tar.getmembers():
            member_path = os.path.join(path, member.name)
            if not cls._is_within_directory(path, member_path):
                raise Exception("Attempted Path Traversal in Tar File")

        tar.extractall(path, members, numeric_owner=numeric_owner)

    @classmethod
    def from_pretrained(
        cls,
        num_language_layers,
        num_vision_layers,
        num_mixed_layers,
        bert_config,
        state_dict=None,
        cache_dir=None,
        from_tf=False,
        *inputs,
        **kwargs,
    ):
        """Build a model and populate it from the ``bert-base-uncased`` archive.

        Args:
            num_language_layers: forwarded to the class constructor.
            num_vision_layers: forwarded to the class constructor.
            num_mixed_layers: forwarded to the class constructor.
            bert_config: forwarded to the class constructor.
            state_dict: optional state dict to load instead of the ``pytorch_model.bin``
                found in the archive.
            cache_dir: forwarded to ``cached_path`` when resolving the archive.
            from_tf: when True, the resolved archive location is used directly and the result
                of ``load_tf_weights_in_bert`` on its ``model.ckpt`` is returned.
            inputs: extra positional arguments forwarded to the class constructor.
            kwargs: extra keyword arguments forwarded to the class constructor.

        Returns:
            The constructed model after loading the weights (or the return value of
            ``load_tf_weights_in_bert`` when ``from_tf`` is True).

        Raises:
            Exception: if the archive contains a member that would be extracted outside
                the temporary directory.
        """
        archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz"
        resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        tempdir = None
        try:
            if os.path.isdir(resolved_archive_file) or from_tf:
                serialization_dir = resolved_archive_file
            else:
                tempdir = tempfile.mkdtemp()
                with tarfile.open(resolved_archive_file, "r:gz") as archive:
                    cls._safe_extract(archive, tempdir)
                serialization_dir = tempdir
            model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs)
            if state_dict is None and not from_tf:
                weights_path = os.path.join(serialization_dir, "pytorch_model.bin")
                # NOTE(security): ``torch.load`` unpickles the file; only use archives from a trusted source.
                state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None)
        finally:
            # Always remove the extraction directory, including when construction or loading fails.
            if tempdir:
                shutil.rmtree(tempdir)
        if from_tf:
            weights_path = os.path.join(serialization_dir, "model.ckpt")
            return load_tf_weights_in_bert(model, weights_path)

        # Rename legacy parameter names: "gamma" -> "weight", "beta" -> "bias".
        old_keys = []
        new_keys = []
        for key in state_dict:
            new_key = None
            if "gamma" in key:
                new_key = key.replace("gamma", "weight")
            if "beta" in key:
                new_key = key.replace("beta", "bias")
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        # NOTE(review): these lists are filled by ``_load_from_state_dict`` but never reported,
        # so missing keys and load errors are ignored -- loading is best-effort.
        missing_keys: list = []
        unexpected_keys: list = []
        error_msgs: list = []
        # ``copy()`` does not carry the ``_metadata`` attribute over, so re-attach it by hand.
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=""):
            # Load this module's own entries, then recurse into its children with an extended prefix.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
            )
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        # Checkpoint keys may be namespaced under "bert." while the model has no ``bert`` attribute.
        start_prefix = ""
        if not hasattr(model, "bert") and any(s.startswith("bert.") for s in state_dict.keys()):
            start_prefix = "bert."
        load(model, prefix=start_prefix)
        return model


class BertAttention(nn.Module):
    """Multi-head scaled dot-product attention between ``hidden_states`` and ``context``.
    Based on: BERT (pytorch-transformer)
    https://github.com/huggingface/transformers
    """

    def __init__(self, config) -> None:
        super().__init__()
        hidden_size = config.hidden_size
        head_count = config.num_attention_heads
        self.num_attention_heads = head_count
        self.attention_head_size = int(hidden_size / head_count)
        self.all_head_size = head_count * self.attention_head_size
        # Queries are projected from ``hidden_states``; keys and values from ``context``.
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dimension into (heads, head_size) and swap dimensions 1 and 2."""
        per_head = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, hidden_states, context):
        """Attend from ``hidden_states`` (queries) over ``context`` (keys and values).

        Returns:
            A tensor whose last dimension is ``all_head_size`` (the heads merged back together).
        """
        queries = self.transpose_for_scores(self.query(hidden_states))
        keys = self.transpose_for_scores(self.key(context))
        values = self.transpose_for_scores(self.value(context))

        # Scaled dot-product scores, normalised over the key dimension.
        scores = torch.matmul(queries, keys.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        weights = self.dropout(torch.softmax(scores, dim=-1))

        # Weighted sum of values, then move heads next to head_size and merge the two.
        attended = torch.matmul(weights, values).permute(0, 2, 1, 3).contiguous()
        return attended.view(*attended.size()[:-2], self.all_head_size)


class BertOutput(nn.Module):
    """Dense projection, dropout, residual addition and layer normalisation.
    Based on: BERT (pytorch-transformer)
    https://github.com/huggingface/transformers
    """

    def __init__(self, config) -> None:
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states``, apply dropout, add ``input_tensor`` and normalise the sum."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)


class BertMixedLayer(nn.Module):
    """Bidirectional cross-attention layer: ``x`` attends over ``y`` and ``y`` attends over ``x``.
    Based on: BERT (pytorch-transformer)
    https://github.com/huggingface/transformers
    """

    def __init__(self, config) -> None:
        super().__init__()
        # One attention + output pair per direction.
        self.att_x = BertAttention(config)
        self.output_x = BertOutput(config)
        self.att_y = BertAttention(config)
        self.output_y = BertOutput(config)

    def forward(self, x, y):
        """Return the updated ``(x, y)`` pair after cross-attention in both directions."""
        # Both attention passes read the *input* x and y, before either stream is updated.
        x_over_y = self.att_x(x, y)
        y_over_x = self.att_y(y, x)
        updated_x = self.output_x(x_over_y, x)
        updated_y = self.output_y(y_over_x, y)
        return updated_x, updated_y


class Pooler(nn.Module):
    """Pool a sequence by passing its first element through a dense layer and ``tanh``.
    Based on: BERT (pytorch-transformer)
    https://github.com/huggingface/transformers
    """

    def __init__(self, hidden_size) -> None:
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return ``tanh(dense(hidden_states[:, 0]))``."""
        # Only index 0 along dimension 1 contributes to the pooled output.
        leading = hidden_states[:, 0]
        return self.activation(self.dense(leading))


class MultiModal(BertPreTrainedModel):
    """
    Multimodal transformer assembled from BERT embeddings, BERT layers and mixed (cross-attention) layers.
    """

    def __init__(
        self, num_language_layers: int, num_vision_layers: int, num_mixed_layers: int, bert_config: dict
    ) -> None:
        """
        Args:
            num_language_layers: number of language transformer layers.
            num_vision_layers: number of vision transformer layers.
            num_mixed_layers: number of mixed transformer layers.
            bert_config: configuration for bert language transformer encoder.

        """
        super().__init__()
        # Expose the config dict entries as attributes of an ad-hoc class.
        config = type("obj", (object,), bert_config)
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.language_encoder = nn.ModuleList(BertLayer(config) for _ in range(num_language_layers))
        self.vision_encoder = nn.ModuleList(BertLayer(config) for _ in range(num_vision_layers))
        self.mixed_encoder = nn.ModuleList(BertMixedLayer(config) for _ in range(num_mixed_layers))
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, vision_feats=None, attention_mask=None):
        """Encode each modality separately, then exchange information through the mixed layers.

        Returns:
            A ``(language_features, vision_feats)`` tuple.
        """
        language_features = self.embeddings(input_ids, token_type_ids)
        # The vision stream is run without an attention mask.
        for vision_block in self.vision_encoder:
            vision_feats = vision_block(vision_feats, None)[0]
        for language_block in self.language_encoder:
            language_features = language_block(language_features, attention_mask)[0]
        for mixed_block in self.mixed_encoder:
            language_features, vision_feats = mixed_block(language_features, vision_feats)
        return language_features, vision_feats


class Transchex(torch.nn.Module):
    """
    TransChex based on: "Hatamizadeh et al.,TransCheX: Self-Supervised Pretraining of Vision-Language Transformers
    for Chest X-ray Analysis"
    """

    def __init__(
        self,
        in_channels: int,
        img_size: Sequence[int] | int,
        patch_size: int | tuple[int, int],
        num_classes: int,
        num_language_layers: int,
        num_vision_layers: int,
        num_mixed_layers: int,
        hidden_size: int = 768,
        drop_out: float = 0.0,
        attention_probs_dropout_prob: float = 0.1,
        gradient_checkpointing: bool = False,
        hidden_act: str = "gelu",
        hidden_dropout_prob: float = 0.1,
        initializer_range: float = 0.02,
        intermediate_size: int = 3072,
        layer_norm_eps: float = 1e-12,
        max_position_embeddings: int = 512,
        model_type: str = "bert",
        num_attention_heads: int = 12,
        num_hidden_layers: int = 12,
        pad_token_id: int = 0,
        position_embedding_type: str = "absolute",
        transformers_version: str = "4.10.2",
        type_vocab_size: int = 2,
        use_cache: bool = True,
        vocab_size: int = 30522,
        chunk_size_feed_forward: int = 0,
        is_decoder: bool = False,
        add_cross_attention: bool = False,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            img_size: dimension of input image. An int is treated as a square ``(img_size, img_size)``.
            patch_size: dimension of patch size. An int is treated as a square ``(patch_size, patch_size)``.
            num_classes: number of classes if classification is used.
            num_language_layers: number of language transformer layers.
            num_vision_layers: number of vision transformer layers.
            num_mixed_layers: number of mixed transformer layers.
            hidden_size: width of the patch projection, positional embedding, pooler and
                classification head; also passed in the `bert_config`.
            drop_out: fraction of the input units to drop.

        The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`.

        Raises:
            ValueError: if ``drop_out`` is outside ``[0, 1]`` or ``img_size`` is not divisible
                by ``patch_size``.

        Examples:

        .. code-block:: python

            # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers,
            # 2 vision layers, 2 mixed modality layers and dropout of 0.2 in the classification head
            net = Transchex(in_channels=3,
                            img_size=(224, 224),
                            patch_size=(32, 32),
                            num_classes=3,
                            num_language_layers=2,
                            num_vision_layers=2,
                            num_mixed_layers=2,
                            drop_out=0.2)

        """
        super().__init__()
        bert_config = {
            "attention_probs_dropout_prob": attention_probs_dropout_prob,
            "classifier_dropout": None,
            "gradient_checkpointing": gradient_checkpointing,
            "hidden_act": hidden_act,
            "hidden_dropout_prob": hidden_dropout_prob,
            "hidden_size": hidden_size,
            "initializer_range": initializer_range,
            "intermediate_size": intermediate_size,
            "layer_norm_eps": layer_norm_eps,
            "max_position_embeddings": max_position_embeddings,
            "model_type": model_type,
            "num_attention_heads": num_attention_heads,
            "num_hidden_layers": num_hidden_layers,
            "pad_token_id": pad_token_id,
            "position_embedding_type": position_embedding_type,
            "transformers_version": transformers_version,
            "type_vocab_size": type_vocab_size,
            "use_cache": use_cache,
            "vocab_size": vocab_size,
            "chunk_size_feed_forward": chunk_size_feed_forward,
            "is_decoder": is_decoder,
            "add_cross_attention": add_cross_attention,
        }
        if not (0 <= drop_out <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        # The annotations allow plain ints, but everything below indexes [0] and [1]:
        # expand an int into a square 2-tuple so both spellings work.
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)

        if (img_size[0] % patch_size[0] != 0) or (img_size[1] % patch_size[1] != 0):
            raise ValueError("img_size should be divisible by patch_size.")

        self.multimodal = MultiModal.from_pretrained(
            num_language_layers=num_language_layers,
            num_vision_layers=num_vision_layers,
            num_mixed_layers=num_mixed_layers,
            bert_config=bert_config,
        )

        self.patch_size = patch_size
        self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1])
        # Non-overlapping patches: kernel and stride are both the patch size.
        self.vision_proj = nn.Conv2d(
            in_channels=in_channels, out_channels=hidden_size, kernel_size=self.patch_size, stride=self.patch_size
        )
        self.norm_vision_pos = nn.LayerNorm(hidden_size)
        self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, hidden_size))
        self.pooler = Pooler(hidden_size=hidden_size)
        self.drop = torch.nn.Dropout(drop_out)
        self.cls_head = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, token_type_ids=None, vision_feats=None):
        """Compute classification logits from token ids and an image.

        Args:
            input_ids: token ids passed to the multimodal encoder.
            token_type_ids: optional token type ids passed to the multimodal encoder.
            vision_feats: image tensor fed to the 2D patch projection
                (presumably ``(batch, in_channels, H, W)`` -- confirm against callers).

        Returns:
            The output of the classification head applied to the pooled language features.
        """
        # The mask is built from all ones, so after the transform below it is all zeros
        # (an additive mask that leaves every position visible).
        attention_mask = torch.ones_like(input_ids).unsqueeze(1).unsqueeze(2)
        attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)
        attention_mask = (1.0 - attention_mask) * -10000.0
        # Project to patches, flatten the spatial dims and move them before the channel dim.
        vision_feats = self.vision_proj(vision_feats).flatten(2).transpose(1, 2)
        vision_feats = self.norm_vision_pos(vision_feats)
        vision_feats = vision_feats + self.pos_embed_vis
        hidden_state_lang, hidden_state_vis = self.multimodal(
            input_ids=input_ids, token_type_ids=token_type_ids, vision_feats=vision_feats, attention_mask=attention_mask
        )
        # Only the language stream is pooled; the vision hidden state is not used for the logits.
        pooled_features = self.pooler(hidden_state_lang)
        logits = self.cls_head(self.drop(pooled_features))
        return logits