Source code for monai.apps.detection.utils.ATSS_matcher

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# =========================================================================
# Adapted from https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/matcher.py
# which has the following license...
# https://github.com/MIC-DKFZ/nnDetection/blob/main/LICENSE
#
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# =========================================================================
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/detection/_utils.py
# which has the following license...
# https://github.com/pytorch/vision/blob/main/LICENSE
#
# BSD 3-Clause License

# Copyright (c) Soumith Chintala 2016,
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.

# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.

# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
The functions in this script are adapted from nnDetection,
https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/matcher.py
which is adapted from torchvision.

These are the changes compared with nnDetection:
1) comments and docstrings;
2) reformat;
3) add a debug option to ATSSMatcher to help users tune parameters;
4) add a corner case return in ATSSMatcher.compute_matches;
5) add support for float16 cpu.
"""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from typing import TypeVar

import torch
from torch import Tensor

from monai.data.box_utils import COMPUTE_DTYPE, box_iou, boxes_center_distance, centers_in_boxes
from monai.utils.type_conversion import convert_to_tensor

# -INF should be smaller than the lower bound of similarity_fn output.
INF = float("inf")


class Matcher(ABC):
    """
    Base class of Matcher, which matches boxes and anchors to each other

    Args:
        similarity_fn: function for similarity computation between boxes and anchors
    """

    BELOW_LOW_THRESHOLD: int = -1
    BETWEEN_THRESHOLDS: int = -2

    def __init__(self, similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou):  # type: ignore
        self.similarity_fn = similarity_fn

    def __call__(
        self, boxes: torch.Tensor, anchors: torch.Tensor, num_anchors_per_level: Sequence[int], num_anchors_per_loc: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute matches for a single image

        Args:
            boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
            anchors: anchors to match, Mx4 or Mx6, also assumed to be ``StandardMode``.
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            - matrix which contains the similarity from each box to each anchor [N, M]
            - vector which contains the matched box index for all anchors
              (`BELOW_LOW_THRESHOLD` marks background anchors, `BETWEEN_THRESHOLDS` marks anchors to be ignored) [M]

        Note:
            ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`,
            also represented as "xyxy" ([xmin, ymin, xmax, ymax]) for 2D
            and "xyzxyz" ([xmin, ymin, zmin, xmax, ymax, zmax]) for 3D.
        """
        if boxes.numel() == 0:
            # no ground truth
            num_anchors = anchors.shape[0]
            match_quality_matrix = torch.tensor([]).to(anchors)
            matches = torch.empty(num_anchors, dtype=torch.int64).fill_(self.BELOW_LOW_THRESHOLD)
            return match_quality_matrix, matches
        # at least one ground truth
        return self.compute_matches(
            boxes=boxes,
            anchors=anchors,
            num_anchors_per_level=num_anchors_per_level,
            num_anchors_per_loc=num_anchors_per_loc,
        )
    @abstractmethod
    def compute_matches(
        self, boxes: torch.Tensor, anchors: torch.Tensor, num_anchors_per_level: Sequence[int], num_anchors_per_loc: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute matches

        Args:
            boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
            anchors: anchors to match, Mx4 or Mx6, also assumed to be ``StandardMode``.
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            - matrix which contains the similarity from each box to each anchor [N, M]
            - vector which contains the matched box index for all anchors
              (`BELOW_LOW_THRESHOLD` marks background anchors, `BETWEEN_THRESHOLDS` marks anchors to be ignored) [M]
        """
        raise NotImplementedError
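
# Illustrative example (not part of the original module): a minimal sketch of a
# concrete ``Matcher`` subclass. The class name and the fixed 0.5 IoU cutoff are
# hypothetical choices; MONAI's actual subclass is ``ATSSMatcher`` below, which
# replaces the fixed cutoff with an adaptive, per-box threshold.
class _FixedThresholdMatcher(Matcher):
    """Toy matcher: assign each anchor to its best box, with one fixed IoU cutoff."""

    def compute_matches(
        self, boxes: torch.Tensor, anchors: torch.Tensor, num_anchors_per_level: Sequence[int], num_anchors_per_loc: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        match_quality_matrix = self.similarity_fn(boxes, anchors)  # [N, M]
        # best ground truth box per anchor
        matched_vals, matches = match_quality_matrix.max(dim=0)
        # anchors whose best IoU falls below the cutoff are background
        matches[matched_vals < 0.5] = self.BELOW_LOW_THRESHOLD
        return match_quality_matrix, matches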
class ATSSMatcher(Matcher):
    def __init__(
        self,
        num_candidates: int = 4,
        similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou,  # type: ignore
        center_in_gt: bool = True,
        debug: bool = False,
    ):
        """
        Compute matching based on ATSS https://arxiv.org/abs/1912.02424
        `Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection`

        Args:
            num_candidates: number of positions to select candidates from.
                A smaller value results in a higher matcher threshold and fewer matched candidates.
            similarity_fn: function for similarity computation between boxes and anchors
            center_in_gt: if False, matched anchor center points do not need to lie within
                the ground truth box; recommended for small objects.
                If True (default), the matcher is stricter and yields fewer matched candidates.
            debug: if True, will print the matcher threshold in order to
                tune ``num_candidates`` and ``center_in_gt``.
        """
        super().__init__(similarity_fn=similarity_fn)
        self.num_candidates = num_candidates
        self.min_dist = 0.01
        self.center_in_gt = center_in_gt
        self.debug = debug
        logging.info(
            f"Running ATSS Matching with num_candidates={self.num_candidates} and center_in_gt={self.center_in_gt}."
        )
    def compute_matches(
        self, boxes: torch.Tensor, anchors: torch.Tensor, num_anchors_per_level: Sequence[int], num_anchors_per_loc: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute matches according to ATSS for a single image
        Adapted from
        (https://github.com/sfzhang15/ATSS/blob/79dfb28bd1/atss_core/modeling/rpn/atss/loss.py#L180-L184)

        Args:
            boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
            anchors: anchors to match, Mx4 or Mx6, also assumed to be ``StandardMode``.
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            - matrix which contains the similarity from each box to each anchor [N, M]
            - vector which contains the matched box index for all anchors
              (`BELOW_LOW_THRESHOLD` marks background anchors, `BETWEEN_THRESHOLDS` marks anchors to be ignored) [M]

        Note:
            ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`,
            also represented as "xyxy" ([xmin, ymin, xmax, ymax]) for 2D
            and "xyzxyz" ([xmin, ymin, zmin, xmax, ymax, zmax]) for 3D.
        """
        num_gt = boxes.shape[0]
        num_anchors = anchors.shape[0]

        distances_, _, anchors_center = boxes_center_distance(boxes, anchors)  # num_boxes x anchors
        distances = convert_to_tensor(distances_)

        # select candidates based on center distance
        candidate_idx_list = []
        start_idx = 0
        for apl in num_anchors_per_level:
            end_idx = start_idx + apl * num_anchors_per_loc
            # topk: total number of candidates per position
            topk = min(self.num_candidates * num_anchors_per_loc, apl)
            # torch.topk() does not support float16 cpu, need conversion to float32 or float64
            _, idx = distances[:, start_idx:end_idx].to(COMPUTE_DTYPE).topk(topk, dim=1, largest=False)
            # idx: shape [num_boxes x topk]
            candidate_idx_list.append(idx + start_idx)
            start_idx = end_idx
        # [num_boxes x num_candidates] (index of candidate anchors)
        candidate_idx = torch.cat(candidate_idx_list, dim=1)

        match_quality_matrix = self.similarity_fn(boxes, anchors)  # [num_boxes x anchors]
        candidate_ious = match_quality_matrix.gather(1, candidate_idx)  # [num_boxes, n_candidates]

        # corner case: n_candidates <= 1 will make iou_std_per_gt NaN
        if candidate_idx.shape[1] <= 1:
            matches = -1 * torch.ones((num_anchors,), dtype=torch.long, device=boxes.device)
            matches[candidate_idx] = 0
            return match_quality_matrix, matches

        # compute adaptive iou threshold
        iou_mean_per_gt = candidate_ious.mean(dim=1)  # [num_boxes]
        iou_std_per_gt = candidate_ious.std(dim=1)  # [num_boxes]
        iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt  # [num_boxes]
        is_pos = candidate_ious >= iou_thresh_per_gt[:, None]  # [num_boxes x n_candidates]
        if self.debug:
            print(f"Anchor matcher threshold: {iou_thresh_per_gt}")

        if self.center_in_gt:  # can discard all candidates in case of very small objects :/
            # center point of selected anchors needs to lie within the ground truth
            boxes_idx = (
                torch.arange(num_gt, device=boxes.device, dtype=torch.long)[:, None]
                .expand_as(candidate_idx)
                .contiguous()
            )  # [num_boxes x n_candidates]
            is_in_gt_ = centers_in_boxes(
                anchors_center[candidate_idx.view(-1)], boxes[boxes_idx.view(-1)], eps=self.min_dist
            )
            is_in_gt = convert_to_tensor(is_in_gt_)
            is_pos = is_pos & is_in_gt.view_as(is_pos)  # [num_boxes x n_candidates]

        # in case one anchor is assigned to multiple boxes, use the box with the highest IoU
        # TODO: think about a better way to do this
        for ng in range(num_gt):
            candidate_idx[ng, :] += ng * num_anchors
        ious_inf = torch.full_like(match_quality_matrix, -INF).view(-1)
        index = candidate_idx.view(-1)[is_pos.view(-1)]
        ious_inf[index] = match_quality_matrix.view(-1)[index]
        ious_inf = ious_inf.view_as(match_quality_matrix)

        matched_vals, matches = ious_inf.to(COMPUTE_DTYPE).max(dim=0)
        matches[matched_vals == -INF] = self.BELOW_LOW_THRESHOLD
        return match_quality_matrix, matches
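
# Illustrative note (not part of the original module): the adaptive ATSS threshold
# is simply mean + std of each ground truth box's candidate IoUs. For example,
# with candidate IoUs [0.1, 0.3, 0.5, 0.7] the threshold is 0.4 + 0.2582 ≈ 0.6582,
# so only the 0.7 candidate survives as a positive match:
#
#     ious = torch.tensor([0.1, 0.3, 0.5, 0.7])
#     thresh = ious.mean() + ious.std()  # tensor(0.6582); std() is the unbiased estimate
#     is_pos = ious >= thresh            # tensor([False, False, False, True])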
MatcherType = TypeVar("MatcherType", bound=Matcher)
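

# Usage sketch (not part of the original module; boxes, anchors, and settings are
# made-up values for illustration). Runs a single-image ATSS match on random 2D
# anchors in ``StandardMode`` ("xyxy").
if __name__ == "__main__":
    torch.manual_seed(0)
    boxes = torch.tensor([[10.0, 10.0, 50.0, 50.0], [30.0, 30.0, 80.0, 80.0]])  # 2 ground truth boxes
    anchors_min = torch.rand(16, 2) * 60.0  # 16 anchors on a single pyramid level
    anchors = torch.cat([anchors_min, anchors_min + 20.0], dim=1)  # [xmin, ymin, xmax, ymax]
    matcher = ATSSMatcher(num_candidates=4, center_in_gt=False, debug=True)
    match_quality_matrix, matches = matcher(boxes, anchors, num_anchors_per_level=[16], num_anchors_per_loc=1)
    print(match_quality_matrix.shape)  # torch.Size([2, 16]): IoU of each box with each anchor
    print(matches)  # matched box index per anchor; -1 (BELOW_LOW_THRESHOLD) marks background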