# Source code for monai.networks.nets.netadapter

# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Tuple, Union

import torch

from monai.networks.layers import Conv, get_pool_layer


class NetAdapter(torch.nn.Module):
    """
    Wrapper to replace the last layer of model by convolutional layer or FC layer.
    This module expects the output of `model layers[0: -2]` is a feature map with shape [B, C, spatial dims],
    then replaces the model's last two layers with an optional `pooling` and a `conv` or `linear` layer.

    Args:
        model: a PyTorch model, support both 2D and 3D models. typically, it can be a pretrained model in Torchvision,
            like: ``resnet18``, ``resnet34``, ``resnet50``, ``resnet101``, ``resnet152``, etc.
            more details: https://pytorch.org/vision/stable/models.html.
        n_classes: number of classes for the last classification layer. Defaults to 1.
        dim: number of spatial dimensions, defaults to 2.
        in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer.
        use_conv: whether to use a convolutional layer to replace the last layer, defaults to False.
        pool: parameters for the pooling layer, it should be a tuple, the first item is name of the pooling layer,
            the second item is dictionary of the initialization args. if None, will not replace the `layers[-2]`
            (only the last layer is removed). default to `("avg", {"kernel_size": 7, "stride": 1})`.
        bias: the bias value when replacing the last layer. if False, the layer will not learn an additive bias,
            default to True.

    Raises:
        ValueError: when `in_channels` is None and the model's last layer has no `in_features` attribute.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        n_classes: int = 1,
        dim: int = 2,
        in_channels: Optional[int] = None,
        use_conv: bool = False,
        pool: Optional[Tuple[str, Dict[str, Any]]] = ("avg", {"kernel_size": 7, "stride": 1}),
        bias: bool = True,
    ):
        super().__init__()
        layers = list(model.children())
        orig_fc = layers[-1]

        # resolve the input width of the new last layer
        in_channels_: int
        if in_channels is None:
            if not hasattr(orig_fc, "in_features"):
                raise ValueError("please specify the input channels of last layer with arg `in_channels`.")
            in_channels_ = orig_fc.in_features  # type: ignore
        else:
            in_channels_ = in_channels

        # decide how many trailing layers to drop; this is the ONLY place `self.features` is built
        self.pool: Optional[torch.nn.Module]
        if pool is None:
            self.pool = None
            # remove only the last layer; any pooling layer of the original model stays in `features`
            self.features = torch.nn.Sequential(*layers[:-1])
        else:
            self.pool = get_pool_layer(name=pool, spatial_dims=dim)
            # remove the last 2 layers; `self.pool` takes the place of `layers[-2]`
            self.features = torch.nn.Sequential(*layers[:-2])

        self.fc: Union[torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv3d]
        if use_conv:
            # add 1x1 conv (it behaves like a FC layer)
            self.fc = Conv[Conv.CONV, dim](
                in_channels=in_channels_,
                out_channels=n_classes,
                kernel_size=1,
                bias=bias,
            )
        else:
            # `self.features` is deliberately NOT rebuilt here: rebuilding it as `layers[:-1]` would keep the
            # original pooling layer when `pool` is given, so `self.pool` would be stacked on top of it.
            # replace the out_features of FC layer
            self.fc = torch.nn.Linear(
                in_features=in_channels_,
                out_features=n_classes,
                bias=bias,
            )
        self.use_conv = use_conv

    def forward(self, x):
        """
        Run the retained feature layers, the optional new pooling layer and the new last layer.

        Args:
            x: input tensor fed to the retained layers of the wrapped model.

        Returns:
            when ``use_conv`` is False, the features are flattened from dim 1 before the linear layer,
            giving a ``[B, n_classes]`` tensor; when ``use_conv`` is True, the 1x1 convolution output is
            returned without flattening, so spatial dimensions are kept.
        """
        x = self.features(x)
        if self.pool is not None:
            x = self.pool(x)
        if not self.use_conv:
            x = torch.flatten(x, 1)
        x = self.fc(x)
        return x