# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py
# which has the following license...
# https://github.com/pytorch/vision/blob/main/LICENSE
# BSD 3-Clause License
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
"""
Part of this script is adapted from
https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py
"""
from __future__ import annotations
from collections.abc import Callable
import torch
from torch import Tensor
from monai.data.box_utils import batched_nms, box_iou, clip_boxes_to_image
from monai.transforms.utils_pytorch_numpy_unification import floor_divide
class BoxSelector:
    """
    Box selector which selects the predicted boxes.
    The box selection is performed with the following steps:

    #. For each level, discard boxes with scores less than self.score_thresh.
    #. For each level, keep boxes with top self.topk_candidates_per_level scores.
    #. For the whole image, perform non-maximum suppression (NMS) on boxes, with overlapping threshold nms_thresh.
    #. For the whole image, keep boxes with top self.detections_per_img scores.

    Args:
        box_overlap_metric: callable that computes the pairwise overlap between two sets of boxes,
            used by NMS. Defaults to IoU (``box_iou``).
        apply_sigmoid: whether to apply sigmoid to get scores from classification logits
        score_thresh: no box with scores less than score_thresh will be kept
        topk_candidates_per_level: max number of boxes to keep for each level
        nms_thresh: box overlapping threshold for NMS
        detections_per_img: max number of boxes to keep for each image

    Example:

        .. code-block:: python

            input_param = {
                "apply_sigmoid": True,
                "score_thresh": 0.1,
                "topk_candidates_per_level": 2,
                "nms_thresh": 0.1,
                "detections_per_img": 5,
            }
            box_selector = BoxSelector(**input_param)
            boxes = [torch.randn([3,6]), torch.randn([7,6])]
            logits = [torch.randn([3,3]), torch.randn([7,3])]
            spatial_size = (8,8,8)
            selected_boxes, selected_scores, selected_labels = box_selector.select_boxes_per_image(
                boxes, logits, spatial_size
            )
    """

    def __init__(
        self,
        box_overlap_metric: Callable = box_iou,
        apply_sigmoid: bool = True,
        score_thresh: float = 0.05,
        topk_candidates_per_level: int = 1000,
        nms_thresh: float = 0.5,
        detections_per_img: int = 300,
    ):
        self.box_overlap_metric = box_overlap_metric
        self.apply_sigmoid = apply_sigmoid
        self.score_thresh = score_thresh
        self.topk_candidates_per_level = topk_candidates_per_level
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img

    def select_top_score_idx_per_level(self, logits: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """
        Select indices with highest scores.

        The indices selection is performed with the following steps:

        #. If self.apply_sigmoid, get scores by applying sigmoid to logits. Otherwise, use logits as scores.
        #. Discard indices with scores less than self.score_thresh
        #. Keep indices with top self.topk_candidates_per_level scores

        Args:
            logits: predicted classification logits, Tensor sized (N, num_classes)

        Return:
            - topk_idxs: selected M indices, Tensor sized (M, )
            - selected_scores: selected M scores, Tensor sized (M, )
            - selected_labels: selected M labels, Tensor sized (M, )
        """
        num_classes = logits.shape[-1]
        # apply sigmoid to classification logits if asked
        if self.apply_sigmoid:
            scores = torch.sigmoid(logits.to(torch.float32)).flatten()
        else:
            scores = logits.flatten()

        # remove low scoring boxes
        keep_idxs = scores > self.score_thresh
        scores = scores[keep_idxs]
        flatten_topk_idxs = torch.where(keep_idxs)[0]

        # keep only topk scoring predictions
        num_topk = min(self.topk_candidates_per_level, flatten_topk_idxs.size(0))
        selected_scores, idxs = scores.to(torch.float32).topk(
            num_topk
        )  # half precision not implemented for cpu float16
        flatten_topk_idxs = flatten_topk_idxs[idxs]

        # flattened indices encode (box, class) pairs: recover the class label
        # and the box index from the position in the flattened (N, num_classes) grid
        selected_labels = flatten_topk_idxs % num_classes

        topk_idxs = floor_divide(flatten_topk_idxs, num_classes)
        return topk_idxs, selected_scores, selected_labels  # type: ignore

    def select_boxes_per_image(
        self, boxes_list: list[Tensor], logits_list: list[Tensor], spatial_size: list[int] | tuple[int]
    ) -> tuple[Tensor, Tensor, Tensor]:
        """
        Postprocessing to generate detection result from classification logits and boxes.

        The box selection is performed with the following steps:

        #. For each level, discard boxes with scores less than self.score_thresh.
        #. For each level, keep boxes with top self.topk_candidates_per_level scores.
        #. For the whole image, perform non-maximum suppression (NMS) on boxes, with overlapping threshold nms_thresh.
        #. For the whole image, keep boxes with top self.detections_per_img scores.

        Args:
            boxes_list: list of predicted boxes from a single image,
                each element i is a Tensor sized (N_i, 2*spatial_dims)
            logits_list: list of predicted classification logits from a single image,
                each element i is a Tensor sized (N_i, num_classes)
            spatial_size: spatial size of the image

        Return:
            - selected boxes, Tensor sized (P, 2*spatial_dims)
            - selected_scores, Tensor sized (P, )
            - selected_labels, Tensor sized (P, )

        Raises:
            ValueError: when ``boxes_list`` and ``logits_list`` have different lengths.
        """
        if len(boxes_list) != len(logits_list):
            raise ValueError(
                "len(boxes_list) should equal to len(logits_list). "
                f"Got len(boxes_list)={len(boxes_list)}, len(logits_list)={len(logits_list)}"
            )

        image_boxes = []
        image_scores = []
        image_labels = []

        # remember input dtypes so outputs match them even though scoring is done in float32
        boxes_dtype = boxes_list[0].dtype
        logits_dtype = logits_list[0].dtype

        for boxes_per_level, logits_per_level in zip(boxes_list, logits_list):
            # select topk boxes for each level
            topk_idxs: Tensor
            topk_idxs, scores_per_level, labels_per_level = self.select_top_score_idx_per_level(logits_per_level)
            boxes_per_level = boxes_per_level[topk_idxs]

            # clip to the image and drop boxes that become empty; `keep` tracks the
            # survivors so scores/labels stay aligned with the boxes
            keep: Tensor
            boxes_per_level, keep = clip_boxes_to_image(  # type: ignore
                boxes_per_level, spatial_size, remove_empty=True
            )
            image_boxes.append(boxes_per_level)
            image_scores.append(scores_per_level[keep])
            image_labels.append(labels_per_level[keep])

        image_boxes_t: Tensor = torch.cat(image_boxes, dim=0)
        image_scores_t: Tensor = torch.cat(image_scores, dim=0)
        image_labels_t: Tensor = torch.cat(image_labels, dim=0)

        # non-maximum suppression on detected boxes from all levels
        keep_t: Tensor = batched_nms(  # type: ignore
            image_boxes_t,
            image_scores_t,
            image_labels_t,
            self.nms_thresh,
            box_overlap_metric=self.box_overlap_metric,
            max_proposals=self.detections_per_img,
        )

        selected_boxes = image_boxes_t[keep_t].to(boxes_dtype)
        selected_scores = image_scores_t[keep_t].to(logits_dtype)
        selected_labels = image_labels_t[keep_t]

        return selected_boxes, selected_scores, selected_labels