# Source code for monai.networks.nets.efficientnet

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math
import operator
import re
from functools import reduce
from typing import NamedTuple

import torch
from torch import nn
from torch.utils import model_zoo

from monai.networks.blocks import BaseEncoder
from monai.networks.layers.factories import Act, Conv, Pad, Pool
from monai.networks.layers.utils import get_norm_layer
from monai.utils.module import look_up_option

# Names exported by ``from monai.networks.nets.efficientnet import *``.
__all__ = [
    "EfficientNet", "EfficientNetBN", "get_efficientnet_image_size", "drop_connect",
    "EfficientNetBNFeatures", "BlockArgs", "EfficientNetEncoder",
]

# Per-variant settings, looked up by model name:
#   (width_mult, depth_mult, image_size, dropout_rate, dropconnect_rate)
efficientnet_params: dict[str, tuple[float, float, int, float, float]] = {
    "efficientnet-b0": (1.0, 1.0, 224, 0.2, 0.2),
    "efficientnet-b1": (1.0, 1.1, 240, 0.2, 0.2),
    "efficientnet-b2": (1.1, 1.2, 260, 0.3, 0.2),
    "efficientnet-b3": (1.2, 1.4, 300, 0.3, 0.2),
    "efficientnet-b4": (1.4, 1.8, 380, 0.4, 0.2),
    "efficientnet-b5": (1.6, 2.2, 456, 0.4, 0.2),
    "efficientnet-b6": (1.8, 2.6, 528, 0.5, 0.2),
    "efficientnet-b7": (2.0, 3.1, 600, 0.5, 0.2),
    "efficientnet-b8": (2.2, 3.6, 672, 0.5, 0.2),
    "efficientnet-l2": (4.3, 5.3, 800, 0.5, 0.2),
}

# Pretrained checkpoint locations (EfficientNet-PyTorch 1.0 release), keyed by model name.
_EFFICIENTNET_RELEASE_URL = "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/"

url_map = {
    "efficientnet-b0": _EFFICIENTNET_RELEASE_URL + "efficientnet-b0-355c32eb.pth",
    "efficientnet-b1": _EFFICIENTNET_RELEASE_URL + "efficientnet-b1-f1951068.pth",
    "efficientnet-b2": _EFFICIENTNET_RELEASE_URL + "efficientnet-b2-8bb594d6.pth",
    "efficientnet-b3": _EFFICIENTNET_RELEASE_URL + "efficientnet-b3-5fb5a3c3.pth",
    "efficientnet-b4": _EFFICIENTNET_RELEASE_URL + "efficientnet-b4-6ed6700e.pth",
    "efficientnet-b5": _EFFICIENTNET_RELEASE_URL + "efficientnet-b5-b6417697.pth",
    "efficientnet-b6": _EFFICIENTNET_RELEASE_URL + "efficientnet-b6-c76e70fd.pth",
    "efficientnet-b7": _EFFICIENTNET_RELEASE_URL + "efficientnet-b7-dcc49843.pth",
    # weights trained with adversarial examples; keys are shortened to "<variant>-ap"
    "b0-ap": _EFFICIENTNET_RELEASE_URL + "adv-efficientnet-b0-b64d5a18.pth",
    "b1-ap": _EFFICIENTNET_RELEASE_URL + "adv-efficientnet-b1-0f3ce85a.pth",
    "b2-ap": _EFFICIENTNET_RELEASE_URL + "adv-efficientnet-b2-6e9d97e5.pth",
    "b3-ap": _EFFICIENTNET_RELEASE_URL + "adv-efficientnet-b3-cdd7c0f4.pth",
    "b4-ap": _EFFICIENTNET_RELEASE_URL + "adv-efficientnet-b4-44fb3a87.pth",
    "b5-ap": _EFFICIENTNET_RELEASE_URL + "adv-efficientnet-b5-86493f6b.pth",
    "b6-ap": _EFFICIENTNET_RELEASE_URL + "adv-efficientnet-b6-ac80338e.pth",
    "b7-ap": _EFFICIENTNET_RELEASE_URL + "adv-efficientnet-b7-4652b6dd.pth",
    "b8-ap": _EFFICIENTNET_RELEASE_URL + "adv-efficientnet-b8-22a8fe65.pth",
}


class MBConvBlock(nn.Module):
    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        image_size: list[int],
        expand_ratio: int,
        se_ratio: float | None,
        id_skip: bool | None = True,
        norm: str | tuple = ("batch", {"eps": 1e-3, "momentum": 0.01}),
        drop_connect_rate: float | None = 0.2,
    ) -> None:
        """
        Mobile Inverted Residual Bottleneck Block.

        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: size of the kernel for the depthwise conv (all other convs use kernel size 1).
            stride: stride of the depthwise conv (all other convs use stride 1).
            image_size: input image resolution, one entry per spatial dim; used to precompute the
                static SAME padding of each conv.
            expand_ratio: expansion ratio for the inverted bottleneck; when 1, the expansion conv is skipped.
            se_ratio: squeeze-excitation ratio; SE layers are only built when it lies in ``(0, 1]``.
            id_skip: whether to use the skip connection. The skip is only applied when
                ``stride == 1`` and ``in_channels == out_channels``.
            norm: feature normalization type and arguments. Defaults to batch norm.
            drop_connect_rate: rate passed to ``drop_connect`` on the residual branch when the skip
                connection is applied; ``None`` or 0 disables it.

        References:
            [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
            [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
            [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
        """
        super().__init__()

        # select the type of N-Dimensional layers to use
        # these are based on spatial dims and selected from MONAI factories
        conv_type = Conv["conv", spatial_dims]
        adaptivepool_type = Pool["adaptiveavg", spatial_dims]

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.id_skip = id_skip
        self.stride = stride
        self.expand_ratio = expand_ratio
        self.drop_connect_rate = drop_connect_rate

        if (se_ratio is not None) and (0.0 < se_ratio <= 1.0):
            self.has_se = True
            self.se_ratio = se_ratio
        else:
            self.has_se = False

        # Expansion phase (Inverted Bottleneck)
        inp = in_channels  # number of input channels
        oup = in_channels * expand_ratio  # number of output channels
        if self.expand_ratio != 1:
            self._expand_conv = conv_type(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._expand_conv_padding = _make_same_padder(self._expand_conv, image_size)

            self._bn0 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=oup)
        else:
            # Placeholders so the attributes always exist; they are needed to avoid the JIT error
            #   "Module 'MBConvBlock' has no attribute '_expand_conv'"
            # FIXME: find a better way to bypass JIT error
            self._expand_conv = nn.Identity()
            self._expand_conv_padding = nn.Identity()
            self._bn0 = nn.Identity()

        # Depthwise convolution phase
        self._depthwise_conv = conv_type(
            in_channels=oup,
            out_channels=oup,
            groups=oup,  # groups makes it depthwise
            kernel_size=kernel_size,
            stride=self.stride,
            bias=False,
        )
        self._depthwise_conv_padding = _make_same_padder(self._depthwise_conv, image_size)
        self._bn1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=oup)
        # the depthwise conv is the only strided op, so later layers see the reduced size
        image_size = _calculate_output_image_size(image_size, self.stride)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            self._se_adaptpool = adaptivepool_type(1)
            num_squeezed_channels = max(1, int(in_channels * self.se_ratio))
            # the SE convs run on the pooled map, which has size 1 in every spatial dim
            se_unit_size = [1] * spatial_dims
            self._se_reduce = conv_type(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
            self._se_reduce_padding = _make_same_padder(self._se_reduce, se_unit_size)
            self._se_expand = conv_type(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
            self._se_expand_padding = _make_same_padder(self._se_expand, se_unit_size)

        # Pointwise convolution phase
        final_oup = out_channels
        self._project_conv = conv_type(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._project_conv_padding = _make_same_padder(self._project_conv, image_size)
        self._bn2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=final_oup)

        # swish activation to use - using memory efficient swish by default
        # can be switched to normal swish using self.set_swish() function call
        self._swish = Act["memswish"](inplace=True)

    def forward(self, inputs: torch.Tensor):
        """MBConvBlock's forward function.

        Args:
            inputs: Input tensor.

        Returns:
            Output of this block after processing.
        """
        # Expansion and Depthwise Convolution
        x = inputs
        if self.expand_ratio != 1:
            x = self._expand_conv(self._expand_conv_padding(x))
            x = self._bn0(x)
            x = self._swish(x)

        x = self._depthwise_conv(self._depthwise_conv_padding(x))
        x = self._bn1(x)
        x = self._swish(x)

        # Squeeze and Excitation: channel-wise gating computed from the globally pooled features
        if self.has_se:
            x_squeezed = self._se_adaptpool(x)
            x_squeezed = self._se_reduce(self._se_reduce_padding(x_squeezed))
            x_squeezed = self._swish(x_squeezed)
            x_squeezed = self._se_expand(self._se_expand_padding(x_squeezed))
            x = torch.sigmoid(x_squeezed) * x

        # Pointwise Convolution
        x = self._project_conv(self._project_conv_padding(x))
        x = self._bn2(x)

        # Skip connection and drop connect (only when input and output shapes match)
        if self.id_skip and self.stride == 1 and self.in_channels == self.out_channels:
            # the combination of skip connection and drop connect brings about stochastic depth.
            if self.drop_connect_rate:
                x = drop_connect(x, p=self.drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient: bool = True) -> None:
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = Act["memswish"](inplace=True) if memory_efficient else Act["swish"](alpha=1.0)


class EfficientNet(nn.Module):
    def __init__(
        self,
        blocks_args_str: list[str],
        spatial_dims: int = 2,
        in_channels: int = 3,
        num_classes: int = 1000,
        width_coefficient: float = 1.0,
        depth_coefficient: float = 1.0,
        dropout_rate: float = 0.2,
        image_size: int = 224,
        norm: str | tuple = ("batch", {"eps": 1e-3, "momentum": 0.01}),
        drop_connect_rate: float = 0.2,
        depth_divisor: int = 8,
    ) -> None:
        """
        EfficientNet based on `Rethinking Model Scaling for Convolutional Neural Networks <https://arxiv.org/pdf/1905.11946.pdf>`_.
        Adapted from `EfficientNet-PyTorch <https://github.com/lukemelas/EfficientNet-PyTorch>`_.

        Args:
            blocks_args_str: block definitions, one ``BlockArgs`` string per stage
                (e.g. ``"r1_k3_s11_e1_i32_o16_se0.25"``).
            spatial_dims: number of spatial dimensions (1, 2 or 3).
            in_channels: number of input channels.
            num_classes: number of output classes.
            width_coefficient: width multiplier coefficient (w in paper).
            depth_coefficient: depth multiplier coefficient (d in paper).
            dropout_rate: dropout rate applied before the final linear layer.
            image_size: input image resolution, same for every spatial dim; used to precompute SAME padding.
            norm: feature normalization type and arguments.
            drop_connect_rate: base drop-connect rate; block ``i`` (0-based, counted across all stages)
                receives ``drop_connect_rate * i / num_blocks``.
            depth_divisor: depth divisor for channel rounding.

        Raises:
            ValueError: if ``spatial_dims`` is not 1, 2 or 3, if ``blocks_args_str`` is empty, or if the
                number of created blocks does not match the expected total.
        """
        super().__init__()

        if spatial_dims not in (1, 2, 3):
            raise ValueError("spatial_dims can only be 1, 2 or 3.")

        # select the type of N-Dimensional layers to use
        # these are based on spatial dims and selected from MONAI factories
        conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv["conv", spatial_dims]
        adaptivepool_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[
            "adaptiveavg", spatial_dims
        ]

        # decode blocks args into arguments for MBConvBlock
        blocks_args = [BlockArgs.from_string(s) for s in blocks_args_str]
        if not blocks_args:
            raise ValueError("block_args must be non-empty")

        self._blocks_args = blocks_args
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.drop_connect_rate = drop_connect_rate

        # expand input image dimensions to list
        current_image_size = [image_size] * spatial_dims

        # Stem
        stride = 2
        out_channels = _round_filters(32, width_coefficient, depth_divisor)  # number of output channels
        self._conv_stem = conv_type(self.in_channels, out_channels, kernel_size=3, stride=stride, bias=False)
        self._conv_stem_padding = _make_same_padder(self._conv_stem, current_image_size)
        self._bn0 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=out_channels)
        current_image_size = _calculate_output_image_size(current_image_size, stride)

        # build MBConv blocks
        num_blocks = 0
        self._blocks = nn.Sequential()
        # stage indices after which a feature map is extracted (used by feature-extracting subclasses)
        self.extract_stacks: list[int] = []

        # update baseline blocks to input/output filters and number of repeats based on width and depth multipliers.
        for idx, block_args in enumerate(self._blocks_args):
            block_args = block_args._replace(
                input_filters=_round_filters(block_args.input_filters, width_coefficient, depth_divisor),
                output_filters=_round_filters(block_args.output_filters, width_coefficient, depth_divisor),
                num_repeat=_round_repeats(block_args.num_repeat, depth_coefficient),
            )
            self._blocks_args[idx] = block_args

            # calculate the total number of blocks - needed for drop_connect estimation
            num_blocks += block_args.num_repeat

            # a strided stage reduces resolution, so the output of the previous stage is a feature level
            if block_args.stride > 1:
                self.extract_stacks.append(idx)

        self.extract_stacks.append(len(self._blocks_args))

        # create and add MBConvBlocks to self._blocks
        idx = 0  # block index counter
        for stack_idx, block_args in enumerate(self._blocks_args):
            blk_drop_connect_rate = self.drop_connect_rate

            # scale drop connect_rate linearly with the block's global position
            if blk_drop_connect_rate:
                blk_drop_connect_rate *= float(idx) / num_blocks

            sub_stack = nn.Sequential()
            # the first block needs to take care of stride and filter size increase.
            sub_stack.add_module(
                str(idx),
                MBConvBlock(
                    spatial_dims=spatial_dims,
                    in_channels=block_args.input_filters,
                    out_channels=block_args.output_filters,
                    kernel_size=block_args.kernel_size,
                    stride=block_args.stride,
                    image_size=current_image_size,
                    expand_ratio=block_args.expand_ratio,
                    se_ratio=block_args.se_ratio,
                    id_skip=block_args.id_skip,
                    norm=norm,
                    drop_connect_rate=blk_drop_connect_rate,
                ),
            )
            idx += 1  # increment blocks index counter

            current_image_size = _calculate_output_image_size(current_image_size, block_args.stride)
            if block_args.num_repeat > 1:  # modify block_args to keep same output size
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)

            # add remaining block repeated num_repeat times
            for _ in range(block_args.num_repeat - 1):
                blk_drop_connect_rate = self.drop_connect_rate

                # scale drop connect_rate
                if blk_drop_connect_rate:
                    blk_drop_connect_rate *= float(idx) / num_blocks

                # add blocks
                sub_stack.add_module(
                    str(idx),
                    MBConvBlock(
                        spatial_dims=spatial_dims,
                        in_channels=block_args.input_filters,
                        out_channels=block_args.output_filters,
                        kernel_size=block_args.kernel_size,
                        stride=block_args.stride,
                        image_size=current_image_size,
                        expand_ratio=block_args.expand_ratio,
                        se_ratio=block_args.se_ratio,
                        id_skip=block_args.id_skip,
                        norm=norm,
                        drop_connect_rate=blk_drop_connect_rate,
                    ),
                )
                idx += 1  # increment blocks index counter

            self._blocks.add_module(str(stack_idx), sub_stack)

        # sanity check to see if len(self._blocks) equal expected num_blocks
        if idx != num_blocks:
            raise ValueError("total number of blocks created != num_blocks")

        # Head
        head_in_channels = block_args.output_filters
        out_channels = _round_filters(1280, width_coefficient, depth_divisor)
        self._conv_head = conv_type(head_in_channels, out_channels, kernel_size=1, bias=False)
        self._conv_head_padding = _make_same_padder(self._conv_head, current_image_size)
        self._bn1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=out_channels)

        # final linear layer
        self._avg_pooling = adaptivepool_type(1)
        self._dropout = nn.Dropout(dropout_rate)
        self._fc = nn.Linear(out_channels, self.num_classes)

        # swish activation to use - using memory efficient swish by default
        # can be switched to normal swish using self.set_swish() function call
        self._swish = Act["memswish"]()

        # initialize weights using Tensorflow's init method from official impl.
        self._initialize_weights()

    def set_swish(self, memory_efficient: bool = True) -> None:
        """
        Sets swish function as memory efficient (for training) or standard (for JIT export).

        Args:
            memory_efficient: whether to use memory-efficient version of swish.
        """
        self._swish = Act["memswish"]() if memory_efficient else Act["swish"](alpha=1.0)
        for sub_stack in self._blocks:
            for block in sub_stack:
                block.set_swish(memory_efficient)

    def forward(self, inputs: torch.Tensor):
        """
        Args:
            inputs: input should have spatially N dimensions
            ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`.

        Returns:
            a torch Tensor of classification prediction in shape ``(Batch, num_classes)``.
        """
        # Stem
        x = self._conv_stem(self._conv_stem_padding(inputs))
        x = self._swish(self._bn0(x))
        # Blocks
        x = self._blocks(x)
        # Head
        x = self._conv_head(self._conv_head_padding(x))
        x = self._swish(self._bn1(x))

        # Pooling and final linear layer
        x = self._avg_pooling(x)

        x = x.flatten(start_dim=1)
        x = self._dropout(x)
        x = self._fc(x)
        return x

    def _initialize_weights(self) -> None:
        """
        Initialize conv/linear/batchnorm layers following the weight init of the
        `official Tensorflow EfficientNet implementation
        <https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py#L61>`_.
        Adapted from `EfficientNet-PyTorch's init method
        <https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/geffnet/efficientnet_builder.py>`_.

        - conv: normal with std ``sqrt(2 / fan_out)`` where ``fan_out = prod(kernel_size) * out_channels``;
          bias zeroed.
        - batchnorm: weight 1, bias 0 (skipped when the layer has no affine parameters).
        - linear: uniform in ``[-1/sqrt(out_features), 1/sqrt(out_features)]``; bias zeroed.
        """
        for _, m in self.named_modules():
            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
                fan_out = reduce(operator.mul, m.kernel_size, 1) * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                # BatchNorm built with affine=False has weight/bias set to None
                if m.weight is not None:
                    m.weight.data.fill_(1.0)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                fan_out = m.weight.size(0)
                init_range = 1.0 / math.sqrt(fan_out)
                m.weight.data.uniform_(-init_range, init_range)
                if m.bias is not None:
                    m.bias.data.zero_()
class EfficientNetBN(EfficientNet):
    def __init__(
        self,
        model_name: str,
        pretrained: bool = True,
        progress: bool = True,
        spatial_dims: int = 2,
        in_channels: int = 3,
        num_classes: int = 1000,
        norm: str | tuple = ("batch", {"eps": 1e-3, "momentum": 0.01}),
        adv_prop: bool = False,
    ) -> None:
        """
        Build one of the named EfficientNet variants.

        The variant is selected by ``model_name`` (there is no plain "EfficientNetBN" model); its
        width/depth multipliers, input resolution, dropout and drop-connect rates are read from
        ``efficientnet_params``.

        Args:
            model_name: variant to build, one of ``efficientnet-b0`` ... ``efficientnet-b8`` or
                ``efficientnet-l2``.
            pretrained: whether to load pretrained ImageNet weights. Weights are only loaded when
                ``spatial_dims == 2``; parameters whose name or shape does not match the checkpoint
                keep their random initialization.
            progress: whether to show download progress for pretrained weights download.
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            num_classes: number of output classes.
            norm: feature normalization type and arguments.
            adv_prop: whether to use weights trained with adversarial examples.
                Only relevant when ``pretrained`` is ``True``.

        Raises:
            ValueError: if ``model_name`` is not a known variant.

        Examples::

            # for pretrained spatial 2D ImageNet
            >>> image_size = get_efficientnet_image_size("efficientnet-b0")
            >>> inputs = torch.rand(1, 3, image_size, image_size)
            >>> model = EfficientNetBN("efficientnet-b0", pretrained=True)
            >>> model.eval()
            >>> outputs = model(inputs)

            # create spatial 2D
            >>> model = EfficientNetBN("efficientnet-b0", spatial_dims=2)

            # create spatial 3D
            >>> model = EfficientNetBN("efficientnet-b0", spatial_dims=3)

            # create EfficientNetB7 for spatial 2D
            >>> model = EfficientNetBN("efficientnet-b7", spatial_dims=2)
        """
        if model_name not in efficientnet_params:
            valid_names = ", ".join(efficientnet_params)
            raise ValueError(f"invalid model_name {model_name} found, must be one of {valid_names} ")

        width_coeff, depth_coeff, input_size, dropout, dc_rate = efficientnet_params[model_name]

        # baseline (B0) stage definitions; scaled by the width/depth multipliers in EfficientNet
        stage_definitions = [
            "r1_k3_s11_e1_i32_o16_se0.25",
            "r2_k3_s22_e6_i16_o24_se0.25",
            "r2_k5_s22_e6_i24_o40_se0.25",
            "r3_k3_s22_e6_i40_o80_se0.25",
            "r3_k5_s11_e6_i80_o112_se0.25",
            "r4_k5_s22_e6_i112_o192_se0.25",
            "r1_k3_s11_e6_i192_o320_se0.25",
        ]

        super().__init__(
            blocks_args_str=stage_definitions,
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            num_classes=num_classes,
            width_coefficient=width_coeff,
            depth_coefficient=depth_coeff,
            dropout_rate=dropout,
            image_size=input_size,
            drop_connect_rate=dc_rate,
            norm=norm,
        )

        # pretrained weights are only provided for 2D networks
        if not pretrained or spatial_dims != 2:
            return
        _load_state_dict(self, model_name, progress, adv_prop)
class EfficientNetBNFeatures(EfficientNet):
    def __init__(
        self,
        model_name: str,
        pretrained: bool = True,
        progress: bool = True,
        spatial_dims: int = 2,
        in_channels: int = 3,
        num_classes: int = 1000,
        norm: str | tuple = ("batch", {"eps": 1e-3, "momentum": 0.01}),
        adv_prop: bool = False,
    ) -> None:
        """
        EfficientNet variant used as a multi-scale feature backbone (e.g. as the encoder of
        segmentation or detection models).

        Construction is identical to ``EfficientNetBN``; only ``forward`` differs, returning
        intermediate feature maps instead of class logits.
        This class refers to `PyTorch image models <https://github.com/rwightman/pytorch-image-models>`_.

        Args:
            model_name: variant to build, a key of ``efficientnet_params``.
            pretrained: whether to load pretrained ImageNet weights (only when ``spatial_dims == 2``).
            progress: whether to show download progress for pretrained weights download.
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            num_classes: number of output classes.
            norm: feature normalization type and arguments.
            adv_prop: whether to use weights trained with adversarial examples.

        Raises:
            ValueError: if ``model_name`` is not a known variant.
        """
        if model_name not in efficientnet_params:
            valid_names = ", ".join(efficientnet_params)
            raise ValueError(f"invalid model_name {model_name} found, must be one of {valid_names} ")

        width_coeff, depth_coeff, input_size, dropout, dc_rate = efficientnet_params[model_name]

        stage_definitions = [
            "r1_k3_s11_e1_i32_o16_se0.25",
            "r2_k3_s22_e6_i16_o24_se0.25",
            "r2_k5_s22_e6_i24_o40_se0.25",
            "r3_k3_s22_e6_i40_o80_se0.25",
            "r3_k5_s11_e6_i80_o112_se0.25",
            "r4_k5_s22_e6_i112_o192_se0.25",
            "r1_k3_s11_e6_i192_o320_se0.25",
        ]

        super().__init__(
            blocks_args_str=stage_definitions,
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            num_classes=num_classes,
            width_coefficient=width_coeff,
            depth_coefficient=depth_coeff,
            dropout_rate=dropout,
            image_size=input_size,
            drop_connect_rate=dc_rate,
            norm=norm,
        )

        # pretrained weights are only provided for 2D networks
        if not pretrained or spatial_dims != 2:
            return
        _load_state_dict(self, model_name, progress, adv_prop)

    def forward(self, inputs: torch.Tensor):
        """
        Args:
            inputs: input should have spatially N dimensions
            ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`.

        Returns:
            a list of torch Tensors: the stem output if ``0`` is in ``self.extract_stacks``, followed by
            the output of every stage ``k`` (0-based) for which ``k + 1`` is in ``self.extract_stacks``.
        """
        # Stem
        x = self._swish(self._bn0(self._conv_stem(self._conv_stem_padding(inputs))))

        features: list[torch.Tensor] = [x] if 0 in self.extract_stacks else []
        for stage_number, stage in enumerate(self._blocks, start=1):
            x = stage(x)
            if stage_number in self.extract_stacks:
                features.append(x)
        return features
class EfficientNetEncoder(EfficientNetBNFeatures, BaseEncoder):
    """
    Wrap the original efficientnet to an encoder for flexible-unet.
    """

    backbone_names = [
        "efficientnet-b0",
        "efficientnet-b1",
        "efficientnet-b2",
        "efficientnet-b3",
        "efficientnet-b4",
        "efficientnet-b5",
        "efficientnet-b6",
        "efficientnet-b7",
        "efficientnet-b8",
        "efficientnet-l2",
    ]

    @classmethod
    def get_encoder_parameters(cls) -> list[dict]:
        """
        Get the initialization parameter for efficientnet backbones.
        """
        parameter_list = []
        for backbone_name in cls.backbone_names:
            parameter_list.append(
                {
                    "model_name": backbone_name,
                    "pretrained": True,
                    "progress": True,
                    "spatial_dims": 2,
                    "in_channels": 3,
                    "num_classes": 1000,
                    "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}),
                    # NOTE(review): none of `backbone_names` contains "ap", so this is always False.
                    "adv_prop": "ap" in backbone_name,
                }
            )
        return parameter_list

    @classmethod
    def num_channels_per_output(cls) -> list[tuple[int, ...]]:
        """
        Get number of efficientnet backbone output feature maps' channel.
        """
        return [
            (16, 24, 40, 112, 320),
            (16, 24, 40, 112, 320),
            (16, 24, 48, 120, 352),
            (24, 32, 48, 136, 384),
            (24, 32, 56, 160, 448),
            (24, 40, 64, 176, 512),
            (32, 40, 72, 200, 576),
            (32, 48, 80, 224, 640),
            (32, 56, 88, 248, 704),
            (72, 104, 176, 480, 1376),
        ]

    @classmethod
    def num_outputs(cls) -> list[int]:
        """
        Get number of efficientnet backbone output feature maps.
        Since every backbone contains the same 5 output feature maps,
        the number list should be `[5] * 10`.
        """
        return [5] * 10

    @classmethod
    def get_encoder_names(cls) -> list[str]:
        """
        Get names of efficient backbone.
        """
        return cls.backbone_names


def get_efficientnet_image_size(model_name: str) -> int:
    """
    Get the input image size for a given efficientnet model.

    Args:
        model_name: name of model to initialize, can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2].

    Returns:
        Image size for single spatial dimension as integer.

    Raises:
        ValueError: if ``model_name`` is not a known variant.
    """
    # check if model_name is valid model
    if model_name not in efficientnet_params:
        model_name_string = ", ".join(efficientnet_params.keys())
        raise ValueError(f"invalid model_name {model_name} found, must be one of {model_name_string} ")

    # return input image size (all dims equal so only need to return for one dim)
    _, _, res, _, _ = efficientnet_params[model_name]
    return res


def drop_connect(inputs: torch.Tensor, p: float, training: bool) -> torch.Tensor:
    """
    Randomly zero whole samples of ``inputs`` and rescale the kept ones by ``1 / (1 - p)``.

    One Bernoulli draw is made per sample (the mask has shape ``[B, 1, 1, ...]``), so every element of a
    dropped sample is zeroed together. Applied to a residual branch this gives stochastic depth, as in
    `Deep Networks with Stochastic Depth <https://arxiv.org/pdf/1603.09382.pdf>`_.
    Adapted from `Official Tensorflow EfficientNet utils
    <https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/utils.py>`_.

    This function is generalized for MONAI's N-Dimensional spatial activations
    e.g. 1D activations [B, C, H], 2D activations [B, C, H, W] and 3D activations [B, C, H, W, D]

    Args:
        inputs: input tensor with [B, C, dim_1, dim_2, ..., dim_N] where N=spatial_dims.
        p: probability of dropping each sample, in ``[0, 1]``.
        training: whether in training or evaluation mode; in evaluation mode ``inputs`` is returned unchanged.

    Returns:
        output: output tensor after applying drop connection.

    Raises:
        ValueError: if ``p`` is outside ``[0, 1]``.
    """
    if p < 0.0 or p > 1.0:
        raise ValueError(f"p must be in range of [0, 1], found {p}")

    # eval mode: drop_connect is switched off - so return input without modifying
    if not training:
        return inputs

    keep_prob: float = 1 - p
    if keep_prob == 0.0:
        # every sample is dropped; avoid the 0 / 0 = NaN of the rescaling below.
        # Multiplying (rather than returning a fresh tensor) keeps `inputs` in the autograd graph.
        return inputs * 0.0

    # train mode: calculate and apply drop_connect
    batch_size: int = inputs.shape[0]
    num_dims: int = len(inputs.shape) - 2

    # build dimensions for random tensor, use num_dims to populate appropriate spatial dims
    random_tensor_shape: list[int] = [batch_size, 1] + [1] * num_dims

    # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
    random_tensor: torch.Tensor = torch.rand(random_tensor_shape, dtype=inputs.dtype, device=inputs.device)
    random_tensor += keep_prob

    # round to form binary tensor
    binary_tensor: torch.Tensor = torch.floor(random_tensor)

    # drop connect using binary tensor
    output: torch.Tensor = inputs / keep_prob * binary_tensor
    return output


def _load_state_dict(model: nn.Module, arch: str, progress: bool, adv_prop: bool) -> None:
    """
    Download the pretrained checkpoint for ``arch`` and copy every matching parameter into ``model``.

    Parameters whose (remapped) name is missing from the checkpoint, or whose shape differs, keep their
    current values. If no URL is known for ``arch`` a message is printed and ``model`` is left untouched.
    """
    if adv_prop:
        arch = arch.split("efficientnet-")[-1] + "-ap"
    model_url = look_up_option(arch, url_map, None)
    if model_url is None:
        print(f"pretrained weights of {arch} is not provided")
        return

    # load state dict from url
    pretrain_state_dict = model_zoo.load_url(model_url, progress=progress)
    model_state_dict = model.state_dict()
    # drop the stage index from nested keys, e.g. "_blocks.1.2._bn0.weight" -> "_blocks.2._bn0.weight";
    # presumably this matches the flat block numbering used by the pretrained checkpoints
    pattern = re.compile(r"(.+)\.\d+(\.\d+\..+)")
    for key, value in model_state_dict.items():
        pretrain_key = pattern.sub(r"\1\2", key)
        pretrained_value = pretrain_state_dict.get(pretrain_key)
        if pretrained_value is not None and value.shape == pretrained_value.shape:
            model_state_dict[key] = pretrained_value

    model.load_state_dict(model_state_dict)


def _get_same_padding_conv_nd(
    image_size: list[int], kernel_size: tuple[int, ...], dilation: tuple[int, ...], stride: tuple[int, ...]
) -> list[int]:
    """
    Helper for getting padding (nn.ConstantPadNd) to be used to get SAME padding
    conv operations similar to Tensorflow's SAME padding.

    This function is generalized for MONAI's N-Dimensional spatial operations (e.g. Conv1D, Conv2D, Conv3D)

    Args:
        image_size: input image/feature spatial size.
        kernel_size: conv kernel's spatial size.
        dilation: conv dilation rate for Atrous conv.
        stride: stride for conv operation.

    Returns:
        paddings for ConstantPadNd padder to be used on input tensor to conv op.
    """
    # get number of spatial dimensions, corresponds to kernel size length
    num_dims = len(kernel_size)

    # additional checks to populate dilation and stride (in case they are single entry tuples)
    if len(dilation) == 1:
        dilation = dilation * num_dims

    if len(stride) == 1:
        stride = stride * num_dims

    # equation to calculate (pad^+ + pad^-) size
    _pad_size: list[int] = [
        max((math.ceil(_i_s / _s) - 1) * _s + (_k_s - 1) * _d + 1 - _i_s, 0)
        for _i_s, _k_s, _d, _s in zip(image_size, kernel_size, dilation, stride)
    ]
    # distribute paddings into pad^+ and pad^- following Tensorflow's same padding strategy
    _paddings: list[tuple[int, int]] = [(_p // 2, _p - _p // 2) for _p in _pad_size]

    # unroll list of tuples to tuples, and then to list
    # reversed as nn.ConstantPadNd expects paddings starting with last dimension
    _paddings_ret: list[int] = [outer for inner in reversed(_paddings) for outer in inner]
    return _paddings_ret


def _make_same_padder(conv_op: nn.Conv1d | nn.Conv2d | nn.Conv3d, image_size: list[int]):
    """
    Helper for initializing ConstantPadNd with SAME padding similar to Tensorflow.
    Uses output of _get_same_padding_conv_nd() to get the padding size.

    This function is generalized for MONAI's N-Dimensional spatial operations (e.g. Conv1D, Conv2D, Conv3D)

    Args:
        conv_op: nn.ConvNd operation to extract parameters for op from
        image_size: input image/feature spatial size

    Returns:
        If padding required then nn.ConstantPadNd() padder initialized to paddings otherwise nn.Identity()
    """
    # calculate padding required
    padding: list[int] = _get_same_padding_conv_nd(image_size, conv_op.kernel_size, conv_op.dilation, conv_op.stride)

    # initialize and return padder
    padder = Pad["constantpad", len(padding) // 2]
    if sum(padding) > 0:
        return padder(padding=padding, value=0.0)
    return nn.Identity()


def _round_filters(filters: int, width_coefficient: float | None, depth_divisor: float) -> int:
    """
    Calculate and round number of filters based on width coefficient multiplier and depth divisor.

    Args:
        filters: number of input filters.
        width_coefficient: width coefficient for model.
        depth_divisor: depth divisor to use.

    Returns:
        new_filters: new number of filters after calculation.
    """

    if not width_coefficient:
        return filters

    multiplier: float = width_coefficient
    divisor: float = depth_divisor
    filters_float: float = filters * multiplier

    # follow the formula transferred from official TensorFlow implementation
    new_filters: float = max(divisor, int(filters_float + divisor / 2) // divisor * divisor)
    if new_filters < 0.9 * filters_float:  # prevent rounding by more than 10%
        new_filters += divisor
    return int(new_filters)


def _round_repeats(repeats: int, depth_coefficient: float | None) -> int:
    """
    Re-calculate module's repeat number of a block based on depth coefficient multiplier.

    Args:
        repeats: number of original repeats.
        depth_coefficient: depth coefficient for model.

    Returns:
        new repeat: new number of repeat after calculating.
    """
    if not depth_coefficient:
        return repeats

    # follow the formula transferred from official TensorFlow impl.
    return int(math.ceil(depth_coefficient * repeats))


def _calculate_output_image_size(input_image_size: list[int], stride: int | tuple[int]):
    """
    Calculates the output image size when using _make_same_padder with a stride.
    Required for static padding.

    Args:
        input_image_size: input image/feature spatial size.
        stride: Conv2d operation's stride.

    Returns:
        output_image_size: output image/feature spatial size.

    Raises:
        ValueError: if ``stride`` is a tuple with unequal entries.
    """

    # checks to extract integer stride in case tuple was received
    if isinstance(stride, tuple):
        all_strides_equal = all(stride[0] == s for s in stride)
        if not all_strides_equal:
            raise ValueError(f"unequal strides are not possible, got {stride}")

        stride = stride[0]

    # return output image size
    return [int(math.ceil(im_sz / stride)) for im_sz in input_image_size]
class BlockArgs(NamedTuple):
    """
    BlockArgs object to assist in decoding string notation
        of arguments for MBConvBlock definition.
    """

    num_repeat: int
    kernel_size: int
    stride: int
    expand_ratio: int
    input_filters: int
    output_filters: int
    id_skip: bool
    se_ratio: float | None = None

    @staticmethod
    def from_string(block_string: str):
        """
        Get a BlockArgs object from a string notation of arguments.

        Args:
            block_string (str): A string notation of arguments.
                Examples: "r1_k3_s11_e1_i32_o16_se0.25".

        Returns:
            BlockArgs: the decoded arguments.

        Raises:
            ValueError: if the stride option is missing, has more than 3 entries, or its entries differ.
        """
        ops = block_string.split("_")
        options = {}
        for op in ops:
            # split "<key><value>" at the first digit, e.g. "se0.25" -> ("se", "0.25")
            splits = re.split(r"(\d.*)", op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # the stride is written once per spatial dim ("s2", "s22", "s222"); all entries must agree.
        # A missing "s" option used to surface as a KeyError; it is now reported as invalid.
        stride_str = options.get("s", "")
        if not (1 <= len(stride_str) <= 3 and len(set(stride_str)) == 1):
            raise ValueError("invalid stride option received")

        return BlockArgs(
            num_repeat=int(options["r"]),
            kernel_size=int(options["k"]),
            stride=int(stride_str[0]),
            expand_ratio=int(options["e"]),
            input_filters=int(options["i"]),
            output_filters=int(options["o"]),
            id_skip=("noskip" not in block_string),
            se_ratio=float(options["se"]) if "se" in options else None,
        )

    def to_string(self):
        """
        Return a block string notation for current BlockArgs object

        The stride is always written for two spatial dims, and the ``_se`` field is omitted when
        ``se_ratio`` is None so the string round-trips through ``from_string``.

        Returns:
            A string notation of BlockArgs object arguments.
                Example: "r1_k3_s11_e1_i32_o16_se0.25_noskip".
        """
        string = (
            f"r{self.num_repeat}_k{self.kernel_size}_s{self.stride}{self.stride}"
            f"_e{self.expand_ratio}_i{self.input_filters}_o{self.output_filters}"
        )
        if self.se_ratio is not None:
            string += f"_se{self.se_ratio}"

        if not self.id_skip:
            string += "_noskip"
        return string