# Source code for monai.networks.nets.transchex

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import shutil
import tarfile
import tempfile
from typing import List, Sequence, Tuple, Union

import torch
from torch import nn

from monai.utils import optional_import

# The `transformers` package is resolved through `optional_import` rather than a plain
# `import`, so it is looked up by name at this point instead of being a hard import.
# NOTE(review): the bindings below take element [0] of the value returned by
# `optional_import`; presumably that value is an (object, found_flag) pair -- confirm
# against monai.utils.
# NOTE(review): unlike the other bindings, this one is not indexed with [0], so
# `transformers` holds the whole return value of `optional_import` -- confirm intended.
transformers = optional_import("transformers")
# Helpers used by `BertPreTrainedModel.from_pretrained` (TF checkpoint loading, archive caching).
load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert")[0]
cached_path = optional_import("transformers.file_utils", name="cached_path")[0]
# Building blocks used by `MultiModal` for the embedding and the language/vision encoders.
BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0]
BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0]

# Public names exported by this module.
__all__ = ["BertPreTrainedModel", "BertAttention", "BertOutput", "BertMixedLayer", "Pooler", "MultiModal", "Transchex"]


class BertPreTrainedModel(nn.Module):
    """Module to load BERT pre-trained weights.
    Based on:
    LXMERT
    https://github.com/airsplay/lxmert
    BERT (pytorch-transformer)
    https://github.com/huggingface/transformers
    """

    def __init__(self, *inputs, **kwargs) -> None:
        # Positional/keyword arguments are accepted for subclass convenience but ignored here.
        super().__init__()

    def init_bert_weights(self, module):
        """Initialise a single sub-module in place (intended for use with ``nn.Module.apply``).

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a normal distribution with
        ``std=self.config.initializer_range``; ``LayerNorm`` is reset to weight 1 / bias 0;
        any ``nn.Linear`` bias is zeroed. Other module types are left untouched.

        Args:
            module: the sub-module to initialise.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)  # type: ignore
        elif isinstance(module, torch.nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @staticmethod
    def _is_within_directory(directory, target):
        """Return True when ``target`` resolves to ``directory`` itself or a path beneath it."""
        abs_directory = os.path.abspath(directory)
        abs_target = os.path.abspath(target)
        try:
            # commonpath compares whole path components; a character-wise prefix test
            # (commonprefix) would wrongly accept a sibling such as "/tmp/abcd" for "/tmp/abc".
            return os.path.commonpath([abs_directory, abs_target]) == abs_directory
        except ValueError:
            # Raised when the two paths cannot be compared (e.g. different drives on Windows).
            return False

    @staticmethod
    def _safe_extract(tar, path=".", members=None, *, numeric_owner=False):
        """Extract ``tar`` into ``path`` after checking that no member escapes ``path``.

        Raises:
            Exception: if any member name would resolve outside of ``path``.
        """
        for member in tar.getmembers():
            member_path = os.path.join(path, member.name)
            if not BertPreTrainedModel._is_within_directory(path, member_path):
                raise Exception("Attempted Path Traversal in Tar File")

        tar.extractall(path, members, numeric_owner=numeric_owner)

    @classmethod
    def from_pretrained(
        cls,
        num_language_layers,
        num_vision_layers,
        num_mixed_layers,
        bert_config,
        state_dict=None,
        cache_dir=None,
        from_tf=False,
        *inputs,
        **kwargs,
    ):
        """Build a model of type ``cls`` and load the ``bert-base-uncased`` weights into it.

        Args:
            num_language_layers: forwarded to the ``cls`` constructor.
            num_vision_layers: forwarded to the ``cls`` constructor.
            num_mixed_layers: forwarded to the ``cls`` constructor.
            bert_config: forwarded to the ``cls`` constructor.
            state_dict: optional pre-loaded weights; when given, no weight file is read.
                The mapping passed in is not modified.
            cache_dir: forwarded to ``cached_path`` when resolving the archive.
            from_tf: when True, weights are loaded from a TensorFlow checkpoint via
                ``load_tf_weights_in_bert`` and its result is returned.
            inputs: extra positional arguments forwarded to the ``cls`` constructor.
            kwargs: extra keyword arguments forwarded to the ``cls`` constructor.

        Returns:
            The constructed model (or the return value of ``load_tf_weights_in_bert`` when
            ``from_tf`` is True).

        Raises:
            Exception: if the downloaded archive contains a member that would be extracted
                outside of the temporary directory.
        """
        archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz"
        resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        tempdir = None
        try:
            if os.path.isdir(resolved_archive_file) or from_tf:
                serialization_dir = resolved_archive_file
            else:
                tempdir = tempfile.mkdtemp()
                with tarfile.open(resolved_archive_file, "r:gz") as archive:
                    cls._safe_extract(archive, tempdir)
                serialization_dir = tempdir
            model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs)
            if state_dict is None and not from_tf:
                weights_path = os.path.join(serialization_dir, "pytorch_model.bin")
                # NOTE(security): torch.load unpickles the file; only use with trusted archives.
                state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None)
        finally:
            # Remove the extraction directory even when extraction/construction/loading fails.
            if tempdir:
                shutil.rmtree(tempdir)
        if from_tf:
            weights_path = os.path.join(serialization_dir, "model.ckpt")
            return load_tf_weights_in_bert(model, weights_path)

        # Work on a shallow copy so a caller-supplied state_dict is never mutated;
        # `_metadata` is a plain attribute and has to be carried over by hand.
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        # Rename legacy LayerNorm parameter names ("gamma"/"beta") to "weight"/"bias".
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if "gamma" in key:
                new_key = key.replace("gamma", "weight")
            if "beta" in key:
                new_key = key.replace("beta", "bias")
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        # Collected by `_load_from_state_dict` but never inspected afterwards: loading is
        # best-effort, parameters without a matching key keep their initial values.
        missing_keys: List = []
        unexpected_keys: List = []
        error_msgs: List = []

        def load(module, prefix=""):
            # Recursively load each sub-module's own parameters using its dotted prefix.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
            )
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        # Checkpoint keys may be namespaced under "bert." while the model has no `bert` attribute.
        start_prefix = ""
        if not hasattr(model, "bert") and any(s.startswith("bert.") for s in state_dict.keys()):
            start_prefix = "bert."
        load(model, prefix=start_prefix)
        return model


class BertAttention(nn.Module):
    """BERT attention layer.
    Based on: BERT (pytorch-transformer)
    https://github.com/huggingface/transformers

    Multi-head scaled dot-product attention in which the queries come from
    ``hidden_states`` and the keys/values come from ``context``.
    """

    def __init__(self, config) -> None:
        """
        Args:
            config: object exposing ``num_attention_heads``, ``hidden_size`` and
                ``attention_probs_dropout_prob``.
        """
        super().__init__()
        width = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(width / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(width, self.all_head_size)
        self.key = nn.Linear(width, self.all_head_size)
        self.value = nn.Linear(width, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dimension into (heads, head_size) and move the head axis to position 1."""
        per_head_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*per_head_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states, context):
        """Attend from ``hidden_states`` over ``context`` and return the merged head outputs."""
        queries = self.transpose_for_scores(self.query(hidden_states))
        keys = self.transpose_for_scores(self.key(context))
        values = self.transpose_for_scores(self.value(context))

        # Scaled dot-product scores, normalised over the context positions.
        scores = torch.matmul(queries, keys.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        weights = self.dropout(torch.softmax(scores, dim=-1))

        # Weighted sum of values, then fold the heads back into one feature axis.
        per_head = torch.matmul(weights, values).permute(0, 2, 1, 3).contiguous()
        merged_shape = per_head.size()[:-2] + (self.all_head_size,)
        return per_head.view(*merged_shape)


class BertOutput(nn.Module):
    """BERT output layer.
    Based on: BERT (pytorch-transformer)
    https://github.com/huggingface/transformers

    Applies a linear projection and dropout, then adds the residual input and layer-normalises.
    """

    def __init__(self, config) -> None:
        """
        Args:
            config: object exposing ``hidden_size`` and ``hidden_dropout_prob``.
        """
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = torch.nn.LayerNorm(width, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)


class BertMixedLayer(nn.Module):
    """BERT cross attention layer.
    Based on: BERT (pytorch-transformer)
    https://github.com/huggingface/transformers

    Holds one attention + output pair per stream so that each of the two inputs
    attends over the other one.
    """

    def __init__(self, config) -> None:
        """
        Args:
            config: configuration object forwarded to ``BertAttention`` and ``BertOutput``.
        """
        super().__init__()
        self.att_x = BertAttention(config)
        self.output_x = BertOutput(config)
        self.att_y = BertAttention(config)
        self.output_y = BertOutput(config)

    def forward(self, x, y):
        """Cross-attend ``x`` over ``y`` and ``y`` over ``x``; return the updated ``(x, y)`` pair."""
        # Both attentions read the *original* x and y; neither sees the other's update.
        x_attends_y = self.att_x(x, y)
        y_attends_x = self.att_y(y, x)
        updated_x = self.output_x(x_attends_y, x)
        updated_y = self.output_y(y_attends_x, y)
        return updated_x, updated_y


class Pooler(nn.Module):
    """BERT pooler layer.
    Based on: BERT (pytorch-transformer)
    https://github.com/huggingface/transformers

    Summarises a sequence by passing its first position through a linear layer and ``tanh``.
    """

    def __init__(self, hidden_size) -> None:
        """
        Args:
            hidden_size: input and output feature size of the dense layer.
        """
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return ``tanh(dense(hidden_states[:, 0]))``."""
        leading_token = hidden_states[:, 0]
        return self.activation(self.dense(leading_token))


class MultiModal(BertPreTrainedModel):
    """
    Multimodal Transformers From Pretrained BERT Weights
    """

    def __init__(
        self, num_language_layers: int, num_vision_layers: int, num_mixed_layers: int, bert_config: dict
    ) -> None:
        """
        Args:
            num_language_layers: number of language transformer layers.
            num_vision_layers: number of vision transformer layers.
            num_mixed_layers: number of mixed (cross-attention) transformer layers.
            bert_config: configuration for bert language transformer encoder.

        """
        super().__init__()
        # Expose the dict entries as attributes, which is how the layers below read them.
        self.config = type("obj", (object,), bert_config)
        cfg = self.config
        self.embeddings = BertEmbeddings(cfg)
        self.language_encoder = nn.ModuleList(BertLayer(cfg) for _ in range(num_language_layers))
        self.vision_encoder = nn.ModuleList(BertLayer(cfg) for _ in range(num_vision_layers))
        self.mixed_encoder = nn.ModuleList(BertMixedLayer(cfg) for _ in range(num_mixed_layers))
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, vision_feats=None, attention_mask=None):
        """Encode each modality on its own, then fuse them with the mixed layers.

        Returns:
            Tuple ``(language_features, vision_feats)`` after the last mixed layer.
        """
        language_features = self.embeddings(input_ids, token_type_ids)

        # Vision stream: no attention mask is passed to these layers.
        for vision_block in self.vision_encoder:
            vision_feats = vision_block(vision_feats, None)[0]

        # Language stream: masked with the caller-provided attention mask.
        for language_block in self.language_encoder:
            language_features = language_block(language_features, attention_mask)[0]

        # Cross-modal fusion updates both streams at every layer.
        for fusion_block in self.mixed_encoder:
            language_features, vision_feats = fusion_block(language_features, vision_feats)

        return language_features, vision_feats


class Transchex(torch.nn.Module):
    """
    TransChex based on: "Hatamizadeh et al.,TransCheX: Self-Supervised Pretraining of Vision-Language
    Transformers for Chest X-ray Analysis"
    """

    def __init__(
        self,
        in_channels: int,
        img_size: Union[Sequence[int], int],
        patch_size: Union[int, Tuple[int, int]],
        num_classes: int,
        num_language_layers: int,
        num_vision_layers: int,
        num_mixed_layers: int,
        hidden_size: int = 768,
        drop_out: float = 0.0,
        attention_probs_dropout_prob: float = 0.1,
        gradient_checkpointing: bool = False,
        hidden_act: str = "gelu",
        hidden_dropout_prob: float = 0.1,
        initializer_range: float = 0.02,
        intermediate_size: int = 3072,
        layer_norm_eps: float = 1e-12,
        max_position_embeddings: int = 512,
        model_type: str = "bert",
        num_attention_heads: int = 12,
        num_hidden_layers: int = 12,
        pad_token_id: int = 0,
        position_embedding_type: str = "absolute",
        transformers_version: str = "4.10.2",
        type_vocab_size: int = 2,
        use_cache: bool = True,
        vocab_size: int = 30522,
        chunk_size_feed_forward: int = 0,
        is_decoder: bool = False,
        add_cross_attention: bool = False,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            img_size: dimension of input image. A single int is used for both spatial dimensions.
            patch_size: dimension of patch size. A single int is used for both spatial dimensions.
            num_classes: number of classes if classification is used.
            num_language_layers: number of language transformer layers.
            num_vision_layers: number of vision transformer layers.
            num_mixed_layers: number of mixed transformer layers.
            hidden_size: feature size of the transformer layers and of the patch projection.
            drop_out: fraction of the input units to drop.

        The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`.

        Raises:
            ValueError: if ``drop_out`` is outside ``[0, 1]`` or ``img_size`` is not divisible
                by ``patch_size``.

        Examples:

        .. code-block:: python

            # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers,
            # 2 vision layers, 2 mixed modality layers and dropout of 0.2 in the classification head
            net = Transchex(in_channels=3,
                                 img_size=(224, 224),
                                 patch_size=(32, 32),
                                 num_classes=3,
                                 num_language_layers=2,
                                 num_vision_layers=2,
                                 num_mixed_layers=2,
                                 drop_out=0.2)

        """
        super().__init__()
        bert_config = {
            "attention_probs_dropout_prob": attention_probs_dropout_prob,
            "classifier_dropout": None,
            "gradient_checkpointing": gradient_checkpointing,
            "hidden_act": hidden_act,
            "hidden_dropout_prob": hidden_dropout_prob,
            "hidden_size": hidden_size,
            "initializer_range": initializer_range,
            "intermediate_size": intermediate_size,
            "layer_norm_eps": layer_norm_eps,
            "max_position_embeddings": max_position_embeddings,
            "model_type": model_type,
            "num_attention_heads": num_attention_heads,
            "num_hidden_layers": num_hidden_layers,
            "pad_token_id": pad_token_id,
            "position_embedding_type": position_embedding_type,
            "transformers_version": transformers_version,
            "type_vocab_size": type_vocab_size,
            "use_cache": use_cache,
            "vocab_size": vocab_size,
            "chunk_size_feed_forward": chunk_size_feed_forward,
            "is_decoder": is_decoder,
            "add_cross_attention": add_cross_attention,
        }
        if not (0 <= drop_out <= 1):
            raise ValueError("drop_out should be between 0 and 1.")

        # The signature accepts plain ints, so normalise both sizes before indexing them.
        img_dims = self._as_pair(img_size)
        patch_dims = self._as_pair(patch_size)
        if (img_dims[0] % patch_dims[0] != 0) or (img_dims[1] % patch_dims[1] != 0):
            raise ValueError("img_size should be divisible by patch_size.")

        # Validation runs first so invalid arguments fail before any weights are fetched.
        self.multimodal = MultiModal.from_pretrained(
            num_language_layers=num_language_layers,
            num_vision_layers=num_vision_layers,
            num_mixed_layers=num_mixed_layers,
            bert_config=bert_config,
        )

        self.patch_size = patch_dims
        self.num_patches = (img_dims[0] // self.patch_size[0]) * (img_dims[1] // self.patch_size[1])
        # Non-overlapping patch embedding: kernel and stride both equal the patch size.
        self.vision_proj = nn.Conv2d(
            in_channels=in_channels, out_channels=hidden_size, kernel_size=self.patch_size, stride=self.patch_size
        )
        self.norm_vision_pos = nn.LayerNorm(hidden_size)
        self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, hidden_size))
        self.pooler = Pooler(hidden_size=hidden_size)
        self.drop = torch.nn.Dropout(drop_out)
        self.cls_head = torch.nn.Linear(hidden_size, num_classes)

    @staticmethod
    def _as_pair(value):
        """Return ``(value, value)`` for an int, otherwise ``value`` converted to a tuple."""
        if isinstance(value, int):
            return (value, value)
        return tuple(value)

    def forward(self, input_ids, token_type_ids=None, vision_feats=None):
        """Classify from paired text tokens and an image.

        Args:
            input_ids: token ids for the language stream.
            token_type_ids: optional token type ids forwarded to the multimodal encoder.
            vision_feats: input image tensor, fed to the 2D patch projection.

        Returns:
            Classification logits computed from the pooled language features.
        """
        # Built from all ones, so the additive mask below is all zeros (nothing is masked).
        attention_mask = torch.ones_like(input_ids).unsqueeze(1).unsqueeze(2)
        attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)
        attention_mask = (1.0 - attention_mask) * -10000.0

        # Patchify the image, flatten the patch grid into a sequence, then add position embeddings.
        vision_feats = self.vision_proj(vision_feats).flatten(2).transpose(1, 2)
        vision_feats = self.norm_vision_pos(vision_feats)
        vision_feats = vision_feats + self.pos_embed_vis

        hidden_state_lang, hidden_state_vis = self.multimodal(
            input_ids=input_ids, token_type_ids=token_type_ids, vision_feats=vision_feats, attention_mask=attention_mask
        )
        # Only the language stream feeds the classification head.
        pooled_features = self.pooler(hidden_state_lang)
        logits = self.cls_head(self.drop(pooled_features))
        return logits