# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# =========================================================================
# Adapted from https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/anchor_utils.py
# which has the following license...
# https://github.com/pytorch/vision/blob/main/LICENSE

# BSD 3-Clause License

# Copyright (c) Soumith Chintala 2016,
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.

# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.

# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
"""
This script is adapted from
https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/anchor_utils.py
"""

from __future__ import annotations

from typing import List, Sequence

import torch
from torch import Tensor, nn

from monai.utils import ensure_tuple
from monai.utils.misc import issequenceiterable
from monai.utils.module import look_up_option


class AnchorGenerator(nn.Module):
    """
    This module is modified from torchvision to support both 2D and 3D images.

    Module that generates anchors for a set of feature maps and image sizes.

    The module supports computing anchors at multiple sizes and aspect ratios
    per feature map.

    sizes and aspect_ratios should have the same number of elements, which should
    correspond to the number of feature maps.

    sizes[i] and aspect_ratios[i] can have an arbitrary number of elements.
    For 2D images, anchor width and height w:h = 1:aspect_ratios[i,j].
    For 3D images, anchor width, height, and depth w:h:d = 1:aspect_ratios[i,j,0]:aspect_ratios[i,j,1].

    AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
    per spatial location for feature map i.

    Args:
        sizes: base size of each anchor.
            len(sizes) is the number of feature maps, i.e., the number of output levels of
            the feature pyramid network (FPN).
            Each element of ``sizes`` is a Sequence which represents several anchor sizes for each feature map.
        aspect_ratios: the aspect ratios of anchors. ``len(aspect_ratios) = len(sizes)``.
            For 2D images, each element of ``aspect_ratios[i]`` is a Sequence of float.
            For 3D images, each element of ``aspect_ratios[i]`` is a Sequence of 2 value Sequence.
        indexing: choose from {``'ij'``, ``'xy'``}, optional,
            Matrix (``'ij'``, default and recommended) or Cartesian (``'xy'``) indexing of output.

            - Matrix (``'ij'``, default and recommended) indexing keeps the original axes unchanged.
            - To use other monai detection components, please set ``indexing = 'ij'``.
            - Cartesian (``'xy'``) indexing swaps axis 0 and 1.
            - For 2D cases, monai ``AnchorGenerator(sizes, aspect_ratios, indexing='xy')`` and
              ``torchvision.models.detection.anchor_utils.AnchorGenerator(sizes, aspect_ratios)`` are equivalent.

    Reference:
        https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/anchor_utils.py

    Example:

        .. code-block:: python

            # 2D example inputs for 2-level feature maps
            sizes = ((10, 12, 14, 16), (20, 24, 28, 32))
            base_aspect_ratios = (1., 0.5, 2.)
            aspect_ratios = (base_aspect_ratios, base_aspect_ratios)
            anchor_generator = AnchorGenerator(sizes, aspect_ratios)

            # 3D example inputs for 2-level feature maps
            sizes = ((10, 12, 14, 16), (20, 24, 28, 32))
            base_aspect_ratios = ((1., 1.), (1., 0.5), (0.5, 1.), (2., 2.))
            aspect_ratios = (base_aspect_ratios, base_aspect_ratios)
            anchor_generator = AnchorGenerator(sizes, aspect_ratios)
    """

    __annotations__ = {"cell_anchors": List[torch.Tensor]}

    def __init__(
        self,
        sizes: Sequence[Sequence[int]] = ((20, 30, 40),),
        aspect_ratios: Sequence = (((0.5, 1), (1, 0.5)),),
        indexing: str = "ij",
    ) -> None:
        super().__init__()

        if not issequenceiterable(sizes[0]):
            self.sizes = tuple((s,) for s in sizes)
        else:
            self.sizes = ensure_tuple(sizes)
        if not issequenceiterable(aspect_ratios[0]):
            aspect_ratios = (aspect_ratios,) * len(self.sizes)

        if len(self.sizes) != len(aspect_ratios):
            raise ValueError(
                "len(sizes) and len(aspect_ratios) should be equal. It represents the number of feature maps."
            )

        spatial_dims = len(ensure_tuple(aspect_ratios[0][0])) + 1
        spatial_dims = look_up_option(spatial_dims, [2, 3])
        self.spatial_dims = spatial_dims

        self.indexing = look_up_option(indexing, ["ij", "xy"])

        self.aspect_ratios = aspect_ratios
        self.cell_anchors = [
            self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(self.sizes, aspect_ratios)
        ]

    # This comment comes from torchvision.
    # TODO: https://github.com/pytorch/pytorch/issues/26792
    # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
    # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
    # This method assumes aspect ratio = height / width for an anchor.
    def generate_anchors(
        self,
        scales: Sequence,
        aspect_ratios: Sequence,
        dtype: torch.dtype = torch.float32,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """
        Compute cell anchor shapes at multiple sizes and aspect ratios for the current feature map.

        Args:
            scales: a sequence which represents several anchor sizes for the current feature map.
            aspect_ratios: a sequence which represents several aspect_ratios for the current feature map.
                For 2D images, it is a Sequence of float aspect_ratios[j],
                anchor width and height w:h = 1:aspect_ratios[j].
                For 3D images, it is a Sequence of 2 value Sequence aspect_ratios[j,0] and aspect_ratios[j,1],
                anchor width, height, and depth w:h:d = 1:aspect_ratios[j,0]:aspect_ratios[j,1].
            dtype: target data type of the output Tensor.
            device: target device to put the output Tensor data.

        Returns:
            For each s in scales, returns [s, s*aspect_ratios[j]] for 2D images,
            and [s, s*aspect_ratios[j,0], s*aspect_ratios[j,1]] for 3D images.
        """
        scales_t = torch.as_tensor(scales, dtype=dtype, device=device)  # sized (N,)
        aspect_ratios_t = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)  # sized (M,) or (M,2)
        if (self.spatial_dims >= 3) and (len(aspect_ratios_t.shape) != 2):
            raise ValueError(
                f"In {self.spatial_dims}-D images, aspect_ratios for each level should be "
                f"{self.spatial_dims - 1}-D. But got aspect_ratios with shape {aspect_ratios_t.shape}."
            )
        if (self.spatial_dims >= 3) and (aspect_ratios_t.shape[1] != self.spatial_dims - 1):
            raise ValueError(
                f"In {self.spatial_dims}-D images, aspect_ratios for each level should have "
                f"shape (_, {self.spatial_dims - 1}). But got aspect_ratios with shape {aspect_ratios_t.shape}."
            )

        # if 2d, w:h = 1:aspect_ratios
        if self.spatial_dims == 2:
            area_scale = torch.sqrt(aspect_ratios_t)
            w_ratios = 1 / area_scale
            h_ratios = area_scale
        # if 3d, w:h:d = 1:aspect_ratios[:,0]:aspect_ratios[:,1]
        elif self.spatial_dims == 3:
            area_scale = torch.pow(aspect_ratios_t[:, 0] * aspect_ratios_t[:, 1], 1 / 3.0)
            w_ratios = 1 / area_scale
            h_ratios = aspect_ratios_t[:, 0] / area_scale
            d_ratios = aspect_ratios_t[:, 1] / area_scale

        ws = (w_ratios[:, None] * scales_t[None, :]).view(-1)
        hs = (h_ratios[:, None] * scales_t[None, :]).view(-1)
        if self.spatial_dims == 2:
            base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2.0
        elif self.spatial_dims == 3:
            ds = (d_ratios[:, None] * scales_t[None, :]).view(-1)
            base_anchors = torch.stack([-ws, -hs, -ds, ws, hs, ds], dim=1) / 2.0

        return base_anchors.round()
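    # A hand-checked sketch of what generate_anchors produces (added for
    # illustration, not from the original source): in 2D, scales=(10,) with
    # aspect_ratios=(1.0,) gives area_scale = sqrt(1.0) = 1, so ws = hs = 10
    # and the single zero-centered cell anchor is [-5., -5., 5., 5.].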
    def set_cell_anchors(self, dtype: torch.dtype, device: torch.device) -> None:
        """
        Convert each element in self.cell_anchors to ``dtype`` and send to ``device``.
        """
        self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]
    def num_anchors_per_location(self):
        """
        Return the number of anchor shapes for each feature map.
        """
        return [c.shape[0] for c in self.cell_anchors]
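    # For example (hand-derived, not from the original source): with
    # sizes=((10, 12), (20,)) and 3 aspect ratios per level, this returns
    # [6, 3], i.e. len(sizes[i]) * len(aspect_ratios[i]) shapes per level.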
    def grid_anchors(self, grid_sizes: list[list[int]], strides: list[list[Tensor]]) -> list[Tensor]:
        """
        Every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:spatial_dims)
        corresponds to a feature map.
        It outputs g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.

        Args:
            grid_sizes: spatial size of the feature maps
            strides: strides of the feature maps with respect to the original image

        Example:

            .. code-block:: python

                grid_sizes = [[100,100], [50,50]]
                strides = [[torch.tensor(2), torch.tensor(2)], [torch.tensor(4), torch.tensor(4)]]
        """
        anchors = []
        cell_anchors = self.cell_anchors
        if cell_anchors is None:
            raise AssertionError
        if not (len(grid_sizes) == len(strides) == len(cell_anchors)):
            raise ValueError(
                "Anchors should be Tuple[Tuple[int]] because each feature "
                "map could potentially have different sizes and aspect ratios. "
                "There needs to be a match between the number of "
                "feature maps passed and the number of sizes / aspect ratios specified."
            )

        for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
            # for each feature map
            device = base_anchors.device

            # compute anchor centers with respect to the image.
            # shifts_centers is [x_center, y_center] or [x_center, y_center, z_center]
            shifts_centers = [
                torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis]
                for axis in range(self.spatial_dims)
            ]

            # to support torchscript, cannot directly use torch.meshgrid(shifts_centers).
            shifts_centers = list(torch.meshgrid(shifts_centers[: self.spatial_dims], indexing="ij"))
            for axis in range(self.spatial_dims):
                # each element of shifts_centers is sized (HW,) or (HWD,)
                shifts_centers[axis] = shifts_centers[axis].reshape(-1)

            # Expand to [x_center, y_center, x_center, y_center],
            # or [x_center, y_center, z_center, x_center, y_center, z_center]
            if self.indexing == "xy":
                # Cartesian ('xy') indexing swaps axis 0 and 1.
                shifts_centers[1], shifts_centers[0] = shifts_centers[0], shifts_centers[1]
            shifts = torch.stack(shifts_centers * 2, dim=1)  # sized (HW,4) or (HWD,6)

            # For every (base anchor, output anchor) pair,
            # offset each zero-centered base anchor by the center of the output anchor.
            anchors.append(
                (shifts.view(-1, 1, self.spatial_dims * 2) + base_anchors.view(1, -1, self.spatial_dims * 2)).reshape(
                    -1, self.spatial_dims * 2
                )  # each element sized (AHWD,4) or (AHWD,6)
            )

        return anchors
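    # A worked sketch of grid_anchors (hand-derived, 2D, single level, 'ij'
    # indexing, not from the original source): with one cell anchor
    # [-5., -5., 5., 5.], grid_sizes=[[2, 2]] and
    # strides=[[torch.tensor(4), torch.tensor(4)]], the shift centers are
    # {0, 4} x {0, 4}, so the four output anchors are the cell anchor offset
    # by (0, 0), (0, 4), (4, 0) and (4, 4).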
    def forward(self, images: Tensor, feature_maps: list[Tensor]) -> list[Tensor]:
        """
        Generate anchor boxes for each image.

        Args:
            images: sized (B, C, W, H) or (B, C, W, H, D)
            feature_maps: for FPN level i, feature_maps[i] is sized
                (B, C_i, W_i, H_i) or (B, C_i, W_i, H_i, D_i).
                This input argument does not have to be the actual feature maps.
                Any list variable with the same (C_i, W_i, H_i) or (C_i, W_i, H_i, D_i) as feature maps works.

        Return:
            A list with length of B. Each element represents the anchors for this image.
            The B elements are identical.

        Example:

            .. code-block:: python

                images = torch.zeros((3,1,128,128,128))
                feature_maps = [torch.zeros((3,6,64,64,32)), torch.zeros((3,6,32,32,16))]
                anchor_generator(images, feature_maps)
        """
        grid_sizes = [list(feature_map.shape[-self.spatial_dims :]) for feature_map in feature_maps]
        image_size = images.shape[-self.spatial_dims :]
        batchsize = images.shape[0]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        strides = [
            [
                torch.tensor(image_size[axis] // g[axis], dtype=torch.int64, device=device)
                for axis in range(self.spatial_dims)
            ]
            for g in grid_sizes
        ]
        self.set_cell_anchors(dtype, device)
        anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
        anchors_per_image = torch.cat(list(anchors_over_all_feature_maps))
        return [anchors_per_image] * batchsize
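    # Note (added for illustration): forward derives strides as
    # image_size // feature_map_size per axis, so a (128, 128, 128) image with
    # a (64, 64, 32) feature map yields per-axis strides (2, 2, 4).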
class AnchorGeneratorWithAnchorShape(AnchorGenerator):
    """
    Module that generates anchors for a set of feature maps and image sizes,
    inherited from :py:class:`~monai.apps.detection.utils.anchor_utils.AnchorGenerator`

    The module supports computing anchors at multiple base anchor shapes
    per feature map.

    ``feature_map_scales`` should have the same number of elements as the number of feature maps.

    base_anchor_shapes can have an arbitrary number of elements.
    For 2D images, each element represents anchor width and height [w,h].
    For 3D images, each element represents anchor width, height, and depth [w,h,d].

    AnchorGenerator will output a set of ``len(base_anchor_shapes)`` anchors
    per spatial location for feature map ``i``.

    Args:
        feature_map_scales: scale of anchors for each feature map, i.e., each output level of
            the feature pyramid network (FPN). ``len(feature_map_scales)`` is the number of feature maps.
            ``scale[i]*base_anchor_shapes`` represents the anchor shapes for feature map ``i``.
        base_anchor_shapes: a sequence which represents several anchor shapes for one feature map.
            For N-D images, it is a Sequence of N value Sequence.
        indexing: choose from {'ij', 'xy'}, optional,
            Matrix ('ij', default) or Cartesian ('xy') indexing of output.
            Matrix ('ij', default) indexing keeps the original axes unchanged.
            Cartesian ('xy') indexing swaps axis 0 and 1, which is the setting inside torchvision.
            See also indexing in https://pytorch.org/docs/stable/generated/torch.meshgrid.html

    Example:

        .. code-block:: python

            # 2D example inputs for 2-level feature maps
            feature_map_scales = (1, 2)
            base_anchor_shapes = ((10, 10), (6, 12), (12, 6))
            anchor_generator = AnchorGeneratorWithAnchorShape(feature_map_scales, base_anchor_shapes)

            # 3D example inputs for 2-level feature maps
            feature_map_scales = (1, 2)
            base_anchor_shapes = ((10, 10, 10), (12, 12, 8), (10, 10, 6), (16, 16, 10))
            anchor_generator = AnchorGeneratorWithAnchorShape(feature_map_scales, base_anchor_shapes)
    """

    __annotations__ = {"cell_anchors": List[torch.Tensor]}

    def __init__(
        self,
        feature_map_scales: Sequence[int] | Sequence[float] = (1, 2, 4, 8),
        base_anchor_shapes: Sequence[Sequence[int]] | Sequence[Sequence[float]] = (
            (32, 32, 32),
            (48, 20, 20),
            (20, 48, 20),
            (20, 20, 48),
        ),
        indexing: str = "ij",
    ) -> None:
        nn.Module.__init__(self)

        spatial_dims = len(base_anchor_shapes[0])
        spatial_dims = look_up_option(spatial_dims, [2, 3])
        self.spatial_dims = spatial_dims

        self.indexing = look_up_option(indexing, ["ij", "xy"])

        base_anchor_shapes_t = torch.Tensor(base_anchor_shapes)
        self.cell_anchors = [self.generate_anchors_using_shape(s * base_anchor_shapes_t) for s in feature_map_scales]
    @staticmethod
    def generate_anchors_using_shape(
        anchor_shapes: torch.Tensor, dtype: torch.dtype = torch.float32, device: torch.device | None = None
    ) -> torch.Tensor:
        """
        Compute cell anchor shapes at multiple sizes and aspect ratios for the current feature map.

        Args:
            anchor_shapes: [w, h] or [w, h, d], sized (N, spatial_dims),
                represents N anchor shapes for the current feature map.
            dtype: target data type of the output Tensor.
            device: target device to put the output Tensor data.

        Returns:
            For 2D images, returns [-w/2, -h/2, w/2, h/2];
            For 3D images, returns [-w/2, -h/2, -d/2, w/2, h/2, d/2]
        """
        half_anchor_shapes = anchor_shapes / 2.0
        base_anchors = torch.cat([-half_anchor_shapes, half_anchor_shapes], dim=1)
        return base_anchors.round().to(dtype=dtype, device=device)
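

if __name__ == "__main__":
    # A minimal, self-contained sketch (not part of the original module) that
    # exercises both generators on a toy 3D batch; the shapes below are
    # hand-computed under these assumed inputs.
    sizes = ((16, 32), (32, 64))
    base_aspect_ratios = ((1.0, 1.0), (1.0, 0.5))
    aspect_ratios = (base_aspect_ratios, base_aspect_ratios)
    anchor_generator = AnchorGenerator(sizes, aspect_ratios, indexing="ij")

    images = torch.zeros((2, 1, 64, 64, 64))
    feature_maps = [torch.zeros((2, 8, 32, 32, 32)), torch.zeros((2, 8, 16, 16, 16))]
    anchors = anchor_generator(images, feature_maps)
    # one (N, 6) box tensor per image; N = 4 anchors/location * (32**3 + 16**3)
    print(anchors[0].shape)  # torch.Size([147456, 6])

    shape_generator = AnchorGeneratorWithAnchorShape(
        feature_map_scales=(1, 2), base_anchor_shapes=((8, 8, 8), (12, 8, 6))
    )
    shape_anchors = shape_generator(images, feature_maps)
    # 2 anchor shapes per location at every level: 2 * (32**3 + 16**3)
    print(shape_anchors[0].shape)  # torch.Size([73728, 6])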