Source code for monai.apps.detection.transforms.box_ops

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
from copy import deepcopy

import numpy as np
import torch

from monai.config.type_definitions import DtypeLike, NdarrayOrTensor, NdarrayTensor
from monai.data.box_utils import COMPUTE_DTYPE, TO_REMOVE, get_spatial_dims
from monai.transforms import Resize
from monai.transforms.utils import create_scale
from monai.utils import look_up_option
from monai.utils.misc import ensure_tuple, ensure_tuple_rep
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type


def _apply_affine_to_points(points: torch.Tensor, affine: torch.Tensor, include_shift: bool = True) -> torch.Tensor:
    """
    This internal function applies affine matrices to the point coordinates

    Args:
        points: point coordinates, Nx2 or Nx3 torch tensor or ndarray, representing [x, y] or [x, y, z]
        affine: affine matrix to be applied to the point coordinates, sized (spatial_dims+1,spatial_dims+1)
        include_shift: default True, whether the function applies translation (shift) in the affine transform

    Returns:
        transformed point coordinates, with same data type as ``points``, does not share memory with ``points``
    """

    spatial_dims = get_spatial_dims(points=points)

    # compute new points
    if include_shift:
        # append 1 to form Nx(spatial_dims+1) vector, then transpose
        points_affine = torch.cat(
            [points, torch.ones(points.shape[0], 1, device=points.device, dtype=points.dtype)], dim=1
        ).transpose(0, 1)
        # apply affine
        points_affine = torch.matmul(affine, points_affine)
        # remove appended 1 and transpose back
        points_affine = points_affine[:spatial_dims, :].transpose(0, 1)
    else:
        points_affine = points.transpose(0, 1)
        points_affine = torch.matmul(affine[:spatial_dims, :spatial_dims], points_affine)
        points_affine = points_affine.transpose(0, 1)

    return points_affine
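
# A minimal usage sketch of the helper above (illustrative only, not part of the
# original module; values assumed): applying a 2D scale-and-shift affine to two points.
#
#   >>> pts = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
#   >>> affine = torch.tensor([[2.0, 0.0, 1.0], [0.0, 1.0, -1.0], [0.0, 0.0, 1.0]])
#   >>> _apply_affine_to_points(pts, affine, include_shift=True)
#   tensor([[3., 1.],
#           [7., 3.]])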


def apply_affine_to_boxes(boxes: NdarrayTensor, affine: NdarrayOrTensor) -> NdarrayTensor:
    """
    This function applies affine matrices to the boxes

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode
        affine: affine matrix to be applied to the box coordinates, sized (spatial_dims+1,spatial_dims+1)

    Returns:
        affine transformed boxes, with same data type as ``boxes``, does not share memory with ``boxes``
    """

    # convert numpy to tensor if needed
    boxes_t, *_ = convert_data_type(boxes, torch.Tensor)

    # some operations do not support torch.float16
    # convert to float32
    boxes_t = boxes_t.to(dtype=COMPUTE_DTYPE)
    affine_t, *_ = convert_to_dst_type(src=affine, dst=boxes_t)

    spatial_dims = get_spatial_dims(boxes=boxes_t)

    # affine transform left top and bottom right points
    # might be flipped, thus lt may not be left top any more
    lt: torch.Tensor = _apply_affine_to_points(boxes_t[:, :spatial_dims], affine_t, include_shift=True)
    rb: torch.Tensor = _apply_affine_to_points(boxes_t[:, spatial_dims:], affine_t, include_shift=True)

    # make sure lt_new is left top, and rb_new is bottom right
    lt_new, _ = torch.min(torch.stack([lt, rb], dim=2), dim=2)
    rb_new, _ = torch.max(torch.stack([lt, rb], dim=2), dim=2)

    boxes_t_affine = torch.cat([lt_new, rb_new], dim=1)

    # convert tensor back to numpy if needed
    boxes_affine: NdarrayOrTensor
    boxes_affine, *_ = convert_to_dst_type(src=boxes_t_affine, dst=boxes)
    return boxes_affine  # type: ignore[return-value]
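
# Usage sketch (illustrative only, not part of the original module; values assumed):
# scaling x by 2 and shifting by (+5, -3) with a homogeneous 3x3 affine.
#
#   >>> boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
#   >>> affine = torch.tensor([[2.0, 0.0, 5.0], [0.0, 1.0, -3.0], [0.0, 0.0, 1.0]])
#   >>> apply_affine_to_boxes(boxes, affine)
#   tensor([[25., 17., 65., 37.]])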


def zoom_boxes(boxes: NdarrayTensor, zoom: Sequence[float] | float) -> NdarrayTensor:
    """
    Zoom boxes

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode
        zoom: The zoom factor along the spatial axes.
            If a float, zoom is the same for each spatial axis.
            If a sequence, zoom should contain one value for each spatial axis.

    Returns:
        zoomed boxes, with same data type as ``boxes``, does not share memory with ``boxes``

    Example:
        .. code-block:: python

            boxes = torch.ones(1,4)
            zoom_boxes(boxes, zoom=[0.5, 2.2])  # will return tensor([[0.5, 2.2, 0.5, 2.2]])
    """
    spatial_dims = get_spatial_dims(boxes=boxes)

    # generate affine transform corresponding to ``zoom``
    affine = create_scale(spatial_dims=spatial_dims, scaling_factor=zoom)

    return apply_affine_to_boxes(boxes=boxes, affine=affine)


def resize_boxes(
    boxes: NdarrayOrTensor, src_spatial_size: Sequence[int] | int, dst_spatial_size: Sequence[int] | int
) -> NdarrayOrTensor:
    """
    Resize boxes when the corresponding image is resized

    Args:
        boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        src_spatial_size: source image spatial size.
        dst_spatial_size: target image spatial size.

    Returns:
        resized boxes, with same data type as ``boxes``, does not share memory with ``boxes``

    Example:
        .. code-block:: python

            boxes = torch.ones(1,4)
            src_spatial_size = [100, 100]
            dst_spatial_size = [128, 256]
            resize_boxes(boxes, src_spatial_size, dst_spatial_size)  # will return tensor([[1.28, 2.56, 1.28, 2.56]])
    """
    spatial_dims: int = get_spatial_dims(boxes=boxes)

    src_spatial_size = ensure_tuple_rep(src_spatial_size, spatial_dims)
    dst_spatial_size = ensure_tuple_rep(dst_spatial_size, spatial_dims)

    zoom = [dst_spatial_size[axis] / float(src_spatial_size[axis]) for axis in range(spatial_dims)]

    return zoom_boxes(boxes=boxes, zoom=zoom)


def flip_boxes(
    boxes: NdarrayTensor, spatial_size: Sequence[int] | int, flip_axes: Sequence[int] | int | None = None
) -> NdarrayTensor:
    """
    Flip boxes when the corresponding image is flipped

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        spatial_size: image spatial size.
        flip_axes: spatial axes along which to flip over. Default is None.
            The default ``axis=None`` will flip over all of the axes of the input array.
            If axis is negative it counts from the last to the first axis.
            If axis is a tuple of ints, flipping is performed on all of the axes
            specified in the tuple.

    Returns:
        flipped boxes, with same data type as ``boxes``, does not share memory with ``boxes``
    """
    spatial_dims: int = get_spatial_dims(boxes=boxes)
    spatial_size = ensure_tuple_rep(spatial_size, spatial_dims)
    if flip_axes is None:
        flip_axes = tuple(range(0, spatial_dims))
    flip_axes = ensure_tuple(flip_axes)

    # flip box
    _flip_boxes: NdarrayTensor = boxes.clone() if isinstance(boxes, torch.Tensor) else deepcopy(boxes)  # type: ignore[assignment]
    for axis in flip_axes:
        _flip_boxes[:, axis + spatial_dims] = spatial_size[axis] - boxes[:, axis] - TO_REMOVE
        _flip_boxes[:, axis] = spatial_size[axis] - boxes[:, axis + spatial_dims] - TO_REMOVE

    return _flip_boxes
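
# Usage sketch (illustrative only, not part of the original module; values assumed,
# and the expected output assumes TO_REMOVE == 0.0): flipping one 2D box along the
# first spatial axis of a 100x100 image.
#
#   >>> boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
#   >>> flip_boxes(boxes, spatial_size=(100, 100), flip_axes=0)
#   tensor([[70., 20., 90., 40.]])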


def convert_box_to_mask(
    boxes: NdarrayOrTensor,
    labels: NdarrayOrTensor,
    spatial_size: Sequence[int] | int,
    bg_label: int = -1,
    ellipse_mask: bool = False,
) -> NdarrayOrTensor:
    """
    Convert box to an int16 mask image, which has the same size as the input image.

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``.
        labels: classification foreground(fg) labels corresponding to `boxes`, dtype should be int, sized (N,).
        spatial_size: image spatial size.
        bg_label: background label for the output mask image, make sure it is smaller than any fg labels.
        ellipse_mask: bool.

            - If True, it assumes the object shape is close to ellipse or ellipsoid.
            - If False, it assumes the object shape is close to rectangle or cube and well occupies the bounding box.
            - If the users are going to apply random rotation as data augmentation, we suggest setting ellipse_mask=True.
              See also Kalra et al. "Towards Rotation Invariance in Object Detection", ICCV 2021.

    Return:
        - int16 array, sized (num_box, H, W). Each channel represents a box.
          The foreground region in channel c has intensity of labels[c].
          The background intensity is bg_label.
    """
    spatial_dims: int = get_spatial_dims(boxes=boxes)
    spatial_size = ensure_tuple_rep(spatial_size, spatial_dims)

    # if no box, return empty mask
    if labels.shape[0] == 0:
        boxes_mask_np = np.ones((1,) + spatial_size, dtype=np.int16) * np.int16(bg_label)
        boxes_mask, *_ = convert_to_dst_type(src=boxes_mask_np, dst=boxes, dtype=torch.int16)
        return boxes_mask

    # bg_label should be smaller than labels
    if bg_label >= min(labels):
        raise ValueError(
            f"bg_label should be smaller than any foreground box labels.\n"
            f"min(labels)={min(labels)}, while bg_label={bg_label}"
        )

    if labels.shape[0] != boxes.shape[0]:
        raise ValueError("Number of labels should equal to number of boxes.")

    # allocate memory for boxes_mask_np
    boxes_mask_np = np.ones((labels.shape[0],) + spatial_size, dtype=np.int16) * np.int16(bg_label)

    boxes_np: np.ndarray = convert_data_type(boxes, np.ndarray, dtype=np.int32)[0]
    if np.any(boxes_np[:, spatial_dims:] > np.array(spatial_size)):
        raise ValueError("Some boxes are larger than the image.")

    labels_np, *_ = convert_to_dst_type(src=labels, dst=boxes_np)
    for b in range(boxes_np.shape[0]):
        # generate a foreground mask
        box_size = [boxes_np[b, axis + spatial_dims] - boxes_np[b, axis] for axis in range(spatial_dims)]
        if ellipse_mask:
            # initialize a square/cube mask
            max_box_size = max(box_size)  # max of box w/h/d
            radius = max_box_size / 2.0
            center = (max_box_size - 1) / 2.0
            boxes_only_mask = np.ones([max_box_size] * spatial_dims, dtype=np.int16) * np.int16(bg_label)
            # apply label intensity to generate circle/ball foreground
            ranges = tuple(slice(0, max_box_size) for _ in range(spatial_dims))
            dist_from_center = sum((grid - center) ** 2 for grid in np.ogrid[ranges])
            boxes_only_mask[dist_from_center <= radius**2] = np.int16(labels_np[b])
            # squeeze it to an ellipse/ellipsoid mask
            resizer = Resize(spatial_size=box_size, mode="nearest", anti_aliasing=False)
            boxes_only_mask = resizer(boxes_only_mask[None])[0]  # type: ignore
        else:
            # generate a rect mask
            boxes_only_mask = np.ones(box_size, dtype=np.int16) * np.int16(labels_np[b])
        # apply to global mask
        slicing = [b]
        slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims))  # type:ignore
        boxes_mask_np[tuple(slicing)] = boxes_only_mask
    return convert_to_dst_type(src=boxes_mask_np, dst=boxes, dtype=torch.int16)[0]
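
# Usage sketch (illustrative only, not part of the original module; values assumed):
# one 2D box rasterized into a rectangular int16 mask on a 4x4 grid.
#
#   >>> boxes = torch.tensor([[1.0, 1.0, 3.0, 3.0]])
#   >>> labels = torch.tensor([2])
#   >>> mask = convert_box_to_mask(boxes, labels, spatial_size=(4, 4), bg_label=-1)
#   >>> mask.shape
#   torch.Size([1, 4, 4])
#   >>> # mask[0, 1:3, 1:3] equals 2 (the box label); all other voxels equal bg_label (-1)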


def convert_mask_to_box(
    boxes_mask: NdarrayOrTensor,
    bg_label: int = -1,
    box_dtype: DtypeLike | torch.dtype = torch.float32,
    label_dtype: DtypeLike | torch.dtype = torch.long,
) -> tuple[NdarrayOrTensor, NdarrayOrTensor]:
    """
    Convert an int16 mask image to boxes, where the mask has the same size as the input image.

    Args:
        boxes_mask: int16 array, sized (num_box, H, W). Each channel represents a box.
            The foreground region in channel c has intensity of labels[c].
            The background intensity is bg_label.
        bg_label: background label for the boxes_mask
        box_dtype: output dtype for boxes
        label_dtype: output dtype for labels

    Return:
        - bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``.
        - classification foreground(fg) labels, dtype should be int, sized (N,).
    """
    look_up_option(len(boxes_mask.shape), [3, 4])
    spatial_size = list(boxes_mask.shape[1:])
    spatial_dims = get_spatial_dims(spatial_size=spatial_size)

    boxes_mask_np, *_ = convert_data_type(boxes_mask, np.ndarray)

    boxes_list = []
    labels_list = []
    for b in range(boxes_mask_np.shape[0]):
        fg_indices = np.nonzero(boxes_mask_np[b, ...] - bg_label)
        if fg_indices[0].shape[0] == 0:
            continue
        boxes_b = []
        for fd_i in fg_indices:
            boxes_b.append(min(fd_i))  # top left corner
        for fd_i in fg_indices:
            boxes_b.append(max(fd_i) + 1 - TO_REMOVE)  # bottom right corner
        boxes_list.append(boxes_b)
        if spatial_dims == 2:
            labels_list.append(boxes_mask_np[b, fg_indices[0][0], fg_indices[1][0]])
        if spatial_dims == 3:
            labels_list.append(boxes_mask_np[b, fg_indices[0][0], fg_indices[1][0], fg_indices[2][0]])

    if len(boxes_list) == 0:
        boxes_np, labels_np = np.zeros([0, 2 * spatial_dims]), np.zeros([0])
    else:
        boxes_np, labels_np = np.asarray(boxes_list), np.asarray(labels_list)
    boxes, *_ = convert_to_dst_type(src=boxes_np, dst=boxes_mask, dtype=box_dtype)
    labels, *_ = convert_to_dst_type(src=labels_np, dst=boxes_mask, dtype=label_dtype)
    return boxes, labels
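
# Usage sketch (illustrative only, not part of the original module; values assumed,
# and the expected output assumes TO_REMOVE == 0.0): recovering the box and label
# from a single-channel numpy mask.
#
#   >>> boxes_mask = np.full((1, 4, 4), -1, dtype=np.int16)
#   >>> boxes_mask[0, 1:3, 1:3] = 2
#   >>> convert_mask_to_box(boxes_mask)
#   (array([[1., 1., 3., 3.]], dtype=float32), array([2]))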


def select_labels(
    labels: Sequence[NdarrayOrTensor] | NdarrayOrTensor, keep: NdarrayOrTensor
) -> tuple | NdarrayOrTensor:
    """
    For each element in ``labels``, select the entries specified by ``keep``.

    Args:
        labels: Sequence of array. Each element represents classification labels or scores
            corresponding to ``boxes``, sized (N,).
        keep: the indices to keep, same length as each element in labels.

    Return:
        selected labels, does not share memory with original labels.
    """
    labels_tuple = ensure_tuple(labels, True)

    labels_select_list = []
    keep_t: torch.Tensor = convert_data_type(keep, torch.Tensor)[0]
    for item in labels_tuple:
        labels_t: torch.Tensor = convert_data_type(item, torch.Tensor)[0]
        labels_t = labels_t[keep_t, ...]
        labels_select_list.append(convert_to_dst_type(src=labels_t, dst=item)[0])

    if isinstance(labels, (torch.Tensor, np.ndarray)):
        return labels_select_list[0]  # type: ignore

    return tuple(labels_select_list)
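
# Usage sketch (illustrative only, not part of the original module; values assumed):
# keeping the labels at indices 0 and 2.
#
#   >>> labels = torch.tensor([1, 2, 3])
#   >>> keep = torch.tensor([0, 2])
#   >>> select_labels(labels, keep)
#   tensor([1, 3])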


def swapaxes_boxes(boxes: NdarrayTensor, axis1: int, axis2: int) -> NdarrayTensor:
    """
    Interchange two axes of boxes.

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        axis1: First axis.
        axis2: Second axis.

    Returns:
        boxes with two axes interchanged.
    """
    spatial_dims: int = get_spatial_dims(boxes=boxes)
    if isinstance(boxes, torch.Tensor):
        boxes_swap = boxes.clone()
    else:
        boxes_swap = deepcopy(boxes)  # type: ignore
    boxes_swap[:, [axis1, axis2]] = boxes_swap[:, [axis2, axis1]]
    boxes_swap[:, [spatial_dims + axis1, spatial_dims + axis2]] = boxes_swap[
        :, [spatial_dims + axis2, spatial_dims + axis1]
    ]
    return boxes_swap  # type: ignore[return-value]
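
# Usage sketch (illustrative only, not part of the original module; values assumed):
# interchanging the two spatial axes of a single 2D box.
#
#   >>> boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
#   >>> swapaxes_boxes(boxes, axis1=0, axis2=1)
#   tensor([[20., 10., 40., 30.]])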


def rot90_boxes(
    boxes: NdarrayTensor, spatial_size: Sequence[int] | int, k: int = 1, axes: tuple[int, int] = (0, 1)
) -> NdarrayTensor:
    """
    Rotate boxes by 90 degrees in the plane specified by axes.
    Rotation direction is from the first towards the second axis.

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        spatial_size: image spatial size.
        k: number of times the array is rotated by 90 degrees.
        axes: (2,) array_like
            The array is rotated in the plane defined by the axes. Axes must be different.

    Returns:
        A rotated view of ``boxes``.

    Notes:
        ``rot90_boxes(boxes, spatial_size, k=1, axes=(1,0))`` is the reverse of
        ``rot90_boxes(boxes, spatial_size, k=1, axes=(0,1))``.
        ``rot90_boxes(boxes, spatial_size, k=1, axes=(1,0))`` is equivalent to
        ``rot90_boxes(boxes, spatial_size, k=-1, axes=(0,1))``.
    """
    spatial_dims: int = get_spatial_dims(boxes=boxes)
    spatial_size_ = list(ensure_tuple_rep(spatial_size, spatial_dims))

    axes = ensure_tuple(axes)  # type: ignore

    if len(axes) != 2:
        raise ValueError("len(axes) must be 2.")

    if axes[0] == axes[1] or abs(axes[0] - axes[1]) == spatial_dims:
        raise ValueError("Axes must be different.")

    if axes[0] >= spatial_dims or axes[0] < -spatial_dims or axes[1] >= spatial_dims or axes[1] < -spatial_dims:
        raise ValueError(f"Axes={axes} out of range for array of ndim={spatial_dims}.")

    k %= 4

    if k == 0:
        return boxes
    if k == 2:
        return flip_boxes(flip_boxes(boxes, spatial_size_, axes[0]), spatial_size_, axes[1])

    if k == 1:
        boxes_ = flip_boxes(boxes, spatial_size_, axes[1])
        return swapaxes_boxes(boxes_, axes[0], axes[1])
    else:  # k == 3
        boxes_ = swapaxes_boxes(boxes, axes[0], axes[1])
        spatial_size_[axes[0]], spatial_size_[axes[1]] = spatial_size_[axes[1]], spatial_size_[axes[0]]
        return flip_boxes(boxes_, spatial_size_, axes[1])
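
# Usage sketch (illustrative only, not part of the original module; values assumed,
# and the expected output assumes TO_REMOVE == 0.0): rotating one 2D box by 90 degrees
# in a (50, 100) image, from axis 0 towards axis 1.
#
#   >>> boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
#   >>> rot90_boxes(boxes, spatial_size=(50, 100), k=1, axes=(0, 1))
#   tensor([[60., 10., 80., 30.]])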