# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py
# which has the following license...
# https://github.com/pytorch/vision/blob/main/LICENSE
# BSD 3-Clause License
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
"""
Part of this script is adapted from
https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py
"""
from __future__ import annotations
import math
import warnings
from collections.abc import Callable, Sequence
from typing import Any, Dict
import torch
from torch import Tensor, nn
from monai.networks.blocks.backbone_fpn_utils import BackboneWithFPN, _resnet_fpn_extractor
from monai.networks.layers.factories import Conv
from monai.networks.nets import resnet
from monai.utils import ensure_tuple_rep, look_up_option, optional_import
# ``optional_import`` returns a pair; only the first item (the torchvision helper
# ``_validate_trainable_layers``) is kept and the second is discarded.
# NOTE(review): not used in the classes below; presumably used by code later in this file — confirm.
_validate_trainable_layers, _ = optional_import(
    "torchvision.models.detection.backbone_utils", name="_validate_trainable_layers"
)
class RetinaNetClassificationHead(nn.Module):
    """
    A classification head for use in RetinaNet.

    This head takes a list of feature maps as inputs, and outputs a list of classification maps.
    Each output map has same spatial size with the corresponding input feature map,
    and the number of output channel is num_anchors * num_classes.

    Args:
        in_channels: number of channels of the input feature. The four intermediate convolutions keep
            this channel count, and it must be divisible by 8 because each one is followed by
            ``nn.GroupNorm(num_groups=8)``.
        num_anchors: number of anchors to be predicted
        num_classes: number of classes to be predicted
        spatial_dims: spatial dimension of the network, should be 2 or 3.
        prior_probability: prior probability to initialize classification convolutional layers.
            Must lie strictly between 0 and 1. The bias of the final convolution is set to
            ``-log((1 - prior_probability) / prior_probability)``.

    Raises:
        ValueError: when ``prior_probability`` is not in the open interval (0, 1).
    """

    def __init__(
        self, in_channels: int, num_anchors: int, num_classes: int, spatial_dims: int, prior_probability: float = 0.01
    ):
        super().__init__()

        # Validate before building any layer. p == 0 would divide by zero and p >= 1 (or p < 0) would
        # take the log of a non-positive number in the bias initialization below. NaN also fails here.
        if not 0.0 < prior_probability < 1.0:
            raise ValueError(f"prior_probability should be in the open interval (0, 1), got {prior_probability}.")

        conv_type: Callable = Conv[Conv.CONV, spatial_dims]

        # four (conv -> GroupNorm -> ReLU) stages that preserve the channel count
        conv = []
        for _ in range(4):
            conv.append(conv_type(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
            conv.append(nn.GroupNorm(num_groups=8, num_channels=in_channels))
            conv.append(nn.ReLU())
        self.conv = nn.Sequential(*conv)

        for layer in self.conv.children():
            if isinstance(layer, conv_type):  # type: ignore
                torch.nn.init.normal_(layer.weight, std=0.01)
                torch.nn.init.constant_(layer.bias, 0)

        self.cls_logits = conv_type(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
        torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
        # bias = log(p / (1 - p)), so sigmoid(bias) == prior_probability at initialization
        torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))

        self.num_classes = num_classes
        self.num_anchors = num_anchors

    def forward(self, x: list[Tensor]) -> list[Tensor]:
        """
        It takes a list of feature maps as inputs, and outputs a list of classification maps.
        Each output classification map has same spatial size with the corresponding input feature map,
        and the number of output channel is num_anchors * num_classes.

        Args:
            x: list of feature map, x[i] is a (B, in_channels, H_i, W_i) or (B, in_channels, H_i, W_i, D_i) Tensor.
                A single Tensor is also accepted and treated as a one-element list.

        Return:
            cls_logits_maps, list of classification map. cls_logits_maps[i] is a
            (B, num_anchors * num_classes, H_i, W_i) or (B, num_anchors * num_classes, H_i, W_i, D_i) Tensor.

        Raises:
            ValueError: if a classification map contains NaN or Inf while gradient computation is enabled.
                When gradients are disabled, a warning is emitted instead and the maps are still returned.
        """
        cls_logits_maps = []

        if isinstance(x, Tensor):
            feature_maps = [x]
        else:
            feature_maps = x

        for features in feature_maps:
            cls_logits = self.conv(features)
            cls_logits = self.cls_logits(cls_logits)

            cls_logits_maps.append(cls_logits)

            # fail fast while training (grad enabled); only warn during inference
            if torch.isnan(cls_logits).any() or torch.isinf(cls_logits).any():
                if torch.is_grad_enabled():
                    raise ValueError("cls_logits is NaN or Inf.")
                else:
                    warnings.warn("cls_logits is NaN or Inf.")

        return cls_logits_maps
class RetinaNetRegressionHead(nn.Module):
    """
    A regression head for use in RetinaNet.

    This head takes a list of feature maps as inputs, and outputs a list of box regression maps.
    Each output box regression map has same spatial size with the corresponding input feature map,
    and the number of output channel is num_anchors * 2 * spatial_dims.

    Args:
        in_channels: number of channels of the input feature. The four intermediate convolutions keep
            this channel count, and it must be divisible by 8 because each one is followed by
            ``nn.GroupNorm(num_groups=8)``.
        num_anchors: number of anchors to be predicted
        spatial_dims: spatial dimension of the network, should be 2 or 3.
    """

    def __init__(self, in_channels: int, num_anchors: int, spatial_dims: int):
        super().__init__()

        conv_type: Callable = Conv[Conv.CONV, spatial_dims]

        # four (conv -> GroupNorm -> ReLU) stages that preserve the channel count
        conv = []
        for _ in range(4):
            conv.append(conv_type(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
            conv.append(nn.GroupNorm(num_groups=8, num_channels=in_channels))
            conv.append(nn.ReLU())
        self.conv = nn.Sequential(*conv)

        # 2 * spatial_dims regression values per anchor
        self.bbox_reg = conv_type(in_channels, num_anchors * 2 * spatial_dims, kernel_size=3, stride=1, padding=1)
        torch.nn.init.normal_(self.bbox_reg.weight, std=0.01)
        torch.nn.init.zeros_(self.bbox_reg.bias)

        # NOTE: the final layer is initialized before the intermediate convs; keep this order so that
        # seeded initializations stay reproducible.
        for layer in self.conv.children():
            if isinstance(layer, conv_type):  # type: ignore
                torch.nn.init.normal_(layer.weight, std=0.01)
                torch.nn.init.zeros_(layer.bias)

    def forward(self, x: list[Tensor]) -> list[Tensor]:
        """
        It takes a list of feature maps as inputs, and outputs a list of box regression maps.
        Each output box regression map has same spatial size with the corresponding input feature map,
        and the number of output channel is num_anchors * 2 * spatial_dims.

        Args:
            x: list of feature map, x[i] is a (B, in_channels, H_i, W_i) or (B, in_channels, H_i, W_i, D_i) Tensor.
                A single Tensor is also accepted and treated as a one-element list.

        Return:
            box_regression_maps, list of box regression map. box_regression_maps[i] is a
            (B, num_anchors * 2 * spatial_dims, H_i, W_i) or (B, num_anchors * 2 * spatial_dims, H_i, W_i, D_i) Tensor.

        Raises:
            ValueError: if a box regression map contains NaN or Inf while gradient computation is enabled.
                When gradients are disabled, a warning is emitted instead and the maps are still returned.
        """
        box_regression_maps = []

        if isinstance(x, Tensor):
            feature_maps = [x]
        else:
            feature_maps = x

        for features in feature_maps:
            box_regression = self.conv(features)
            box_regression = self.bbox_reg(box_regression)

            box_regression_maps.append(box_regression)

            # fail fast while training (grad enabled); only warn during inference
            if torch.isnan(box_regression).any() or torch.isinf(box_regression).any():
                if torch.is_grad_enabled():
                    raise ValueError("box_regression is NaN or Inf.")
                else:
                    warnings.warn("box_regression is NaN or Inf.")

        return box_regression_maps
class RetinaNet(nn.Module):
    """
    The network used in RetinaNet.

    It takes an image tensor as inputs, and outputs either 1) a dictionary ``head_outputs``.
    ``head_outputs[self.cls_key]`` is the predicted classification maps, a list of Tensor.
    ``head_outputs[self.box_reg_key]`` is the predicted box regression maps, a list of Tensor.
    or 2) a list of 2N tensors ``head_outputs``, with first N tensors being the predicted
    classification maps and second N tensors being the predicted box regression maps.

    Args:
        spatial_dims: number of spatial dimensions of the images. We support both 2D and 3D images.
        num_classes: number of output classes of the model (excluding the background).
        num_anchors: number of anchors at each location.
        feature_extractor: a network that outputs feature maps from the input images,
            each feature map corresponds to a different resolution.
            Its output can have a format of Tensor, Dict[str, Tensor], or Sequence[Tensor].
            It must expose an ``out_channels`` attribute (assumed identical for all levels).
            It can be the output of ``resnet_fpn_feature_extractor(*args, **kwargs)``.
        size_divisible: the spatial size of the network input should be divisible by size_divisible,
            decided by the feature_extractor.
        use_list_output: default False. If False, the network outputs a dictionary ``head_outputs``,
            ``head_outputs[self.cls_key]`` is the predicted classification maps, a list of Tensor.
            ``head_outputs[self.box_reg_key]`` is the predicted box regression maps, a list of Tensor.
            If True, the network outputs a list of 2N tensors ``head_outputs``, with first N tensors being
            the predicted classification maps and second N tensors being the predicted box regression maps.

    Raises:
        ValueError: when ``feature_extractor`` has no ``out_channels`` attribute.

    Example:

        .. code-block:: python

            from monai.networks.nets import resnet
            spatial_dims = 3  # 3D network
            conv1_t_stride = (2,2,1)  # stride of first convolutional layer in backbone
            backbone = resnet.ResNet(
                spatial_dims = spatial_dims,
                block = resnet.ResNetBottleneck,
                layers = [3, 4, 6, 3],
                block_inplanes = resnet.get_inplanes(),
                n_input_channels= 1,
                conv1_t_stride = conv1_t_stride,
                conv1_t_size = (7,7,7),
            )
            # This feature_extractor outputs 4-level feature maps.
            # number of output feature maps is len(returned_layers)+1
            returned_layers = [1,2,3]  # returned layer from feature pyramid network
            feature_extractor = resnet_fpn_feature_extractor(
                backbone = backbone,
                spatial_dims = spatial_dims,
                pretrained_backbone = False,
                trainable_backbone_layers = None,
                returned_layers = returned_layers,
            )
            # This feature_extractor requires input image spatial size
            # to be divisible by (32, 32, 16).
            size_divisible = tuple(2*s*2**max(returned_layers) for s in conv1_t_stride)
            model = RetinaNet(
                spatial_dims = spatial_dims,
                num_classes = 5,
                num_anchors = 6,
                feature_extractor=feature_extractor,
                size_divisible = size_divisible,
            ).to(device)
            result = model(torch.rand(2, 1, 128,128,128))
            cls_logits_maps = result["classification"]  # a list of len(returned_layers)+1 Tensor
            box_regression_maps = result["box_regression"]  # a list of len(returned_layers)+1 Tensor
    """

    def __init__(
        self,
        spatial_dims: int,
        num_classes: int,
        num_anchors: int,
        feature_extractor: nn.Module,
        size_divisible: Sequence[int] | int = 1,
        use_list_output: bool = False,
    ):
        super().__init__()

        self.spatial_dims = look_up_option(spatial_dims, supported=[1, 2, 3])
        self.num_classes = num_classes
        self.size_divisible = ensure_tuple_rep(size_divisible, self.spatial_dims)
        self.use_list_output = use_list_output

        # both heads are sized from feature_extractor.out_channels, so it is mandatory
        if not hasattr(feature_extractor, "out_channels"):
            raise ValueError(
                "feature_extractor should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )
        self.feature_extractor = feature_extractor

        self.feature_map_channels: int = self.feature_extractor.out_channels
        self.num_anchors = num_anchors
        self.classification_head = RetinaNetClassificationHead(
            self.feature_map_channels, self.num_anchors, self.num_classes, spatial_dims=self.spatial_dims
        )
        self.regression_head = RetinaNetRegressionHead(
            self.feature_map_channels, self.num_anchors, spatial_dims=self.spatial_dims
        )

        self.cls_key: str = "classification"
        self.box_reg_key: str = "box_regression"

    def forward(self, images: Tensor) -> Any:
        """
        It takes an image tensor as inputs, and outputs predicted classification maps
        and predicted box regression maps in ``head_outputs``.

        Args:
            images: input images, sized (B, img_channels, H, W) or (B, img_channels, H, W, D).

        Return:
            1) If self.use_list_output is False, output a dictionary ``head_outputs`` with
            keys including self.cls_key and self.box_reg_key.
            ``head_outputs[self.cls_key]`` is the predicted classification maps, a list of Tensor.
            ``head_outputs[self.box_reg_key]`` is the predicted box regression maps, a list of Tensor.
            2) if self.use_list_output is True, outputs a list of 2N tensors ``head_outputs``, with first N tensors being
            the predicted classification maps and second N tensors being the predicted box regression maps.

        Raises:
            ValueError: if the feature extractor returns no feature maps, or its output is not a
                Tensor, Dict[str, Tensor], or Sequence[Tensor].
        """
        # compute features maps list from the input images.
        features = self.feature_extractor(images)
        if isinstance(features, Tensor):
            feature_maps = [features]
        elif torch.jit.isinstance(features, Dict[str, Tensor]):
            feature_maps = list(features.values())
        else:
            feature_maps = list(features)

        # an empty output would otherwise surface as a bare IndexError on feature_maps[0]
        if len(feature_maps) == 0:
            raise ValueError("feature_extractor returned no feature maps.")
        if not isinstance(feature_maps[0], Tensor):
            raise ValueError("feature_extractor output format must be Tensor, Dict[str, Tensor], or Sequence[Tensor].")

        # compute classification and box regression maps from the feature maps
        # expandable for mask prediction in the future
        if not self.use_list_output:
            # output dict
            head_outputs = {self.cls_key: self.classification_head(feature_maps)}
            head_outputs[self.box_reg_key] = self.regression_head(feature_maps)
            return head_outputs
        else:
            # output list of tensor, first half is classification, second half is box regression
            head_outputs_sequence = self.classification_head(feature_maps) + self.regression_head(feature_maps)
            return head_outputs_sequence