Source code for monai.transforms.utils_pytorch_numpy_unification

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
from typing import TypeVar

import numpy as np
import torch

from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
from monai.utils.misc import is_module_ver_at_least
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type

# Names exported by a star-import of this module.
# NOTE(review): `linalg_inv` and `sum` are defined further down in this module but are
# not listed here -- confirm whether that omission is intentional.
__all__ = [
    "allclose",
    "moveaxis",
    "in1d",
    "clip",
    "percentile",
    "where",
    "argwhere",
    "argsort",
    "nonzero",
    "floor_divide",
    "unravel_index",
    "unravel_indices",
    "ravel",
    "any_np_pt",
    "maximum",
    "concatenate",
    "cumsum",
    "isfinite",
    "searchsorted",
    "repeat",
    "isnan",
    "ascontiguousarray",
    "stack",
    "mode",
    "unique",
    "max",
    "min",
    "median",
    "mean",
    "std",
]


def allclose(a: NdarrayTensor, b: NdarrayOrTensor, rtol=1e-5, atol=1e-8, equal_nan=False) -> bool:
    """`np.allclose` with equivalent implementation for torch.

    `b` is first converted with `convert_to_dst_type`, using `a` as the destination, and the
    comparison is then dispatched on the type of `a`.

    Args:
        a: reference array/tensor; its type selects the numpy or torch implementation.
        b: data compared against `a`.
        rtol: relative tolerance forwarded to the backend.
        atol: absolute tolerance forwarded to the backend.
        equal_nan: forwarded to the backend's `equal_nan` argument.

    Returns:
        The result of `np.allclose` or `torch.allclose`.
    """
    converted = convert_to_dst_type(b, a, wrap_sequence=True)
    b = converted[0]
    tolerances = {"rtol": rtol, "atol": atol, "equal_nan": equal_nan}
    if not isinstance(a, np.ndarray):
        return torch.allclose(a, b, **tolerances)  # type: ignore
    return np.allclose(a, b, **tolerances)
def moveaxis(x: NdarrayOrTensor, src: int | Sequence[int], dst: int | Sequence[int]) -> NdarrayOrTensor:
    """`moveaxis` for pytorch and numpy.

    Args:
        x: array/tensor whose axes are moved.
        src: original position(s) of the axes to move.
        dst: destination position(s) for each moved axis.

    Returns:
        Result of `torch.movedim` for tensors, otherwise of `np.moveaxis`.
    """
    if not isinstance(x, torch.Tensor):
        return np.moveaxis(x, src, dst)
    return torch.movedim(x, src, dst)  # type: ignore
def in1d(x, y):
    """`np.in1d` with equivalent implementation for torch.

    Tests, for every element of `x`, whether it is present in `y`. The result is always
    a flattened (1-D) boolean array/tensor.

    Args:
        x: array/tensor whose elements are tested.
        y: values to test against.

    Returns:
        1-D boolean array/tensor with one entry per element of `x`.
    """
    if isinstance(x, np.ndarray):
        # `np.in1d` is deprecated in NumPy 2.x; `np.isin` keeps the shape of `x`, so the
        # result is flattened to keep the historical 1-D output.
        return np.isin(x, y).ravel()
    return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1)
def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor:
    """`np.clip` with equivalent implementation for torch.

    Args:
        a: array/tensor to clip.
        a_min: lower bound passed to the backend.
        a_max: upper bound passed to the backend.

    Returns:
        Result of `np.clip` for numpy arrays, otherwise of `torch.clamp`.
    """
    if isinstance(a, np.ndarray):
        return np.clip(a, a_min, a_max)
    return torch.clamp(a, a_min, a_max)
def percentile(
    x: NdarrayOrTensor, q, dim: int | None = None, keepdim: bool = False, **kwargs
) -> NdarrayOrTensor | float | int:
    """`np.percentile` with equivalent implementation for torch.

    Pytorch uses `quantile`. For more details please refer to:
    https://pytorch.org/docs/stable/generated/torch.quantile.html.
    https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.

    Args:
        x: input data.
        q: percentile to compute (should be in range 0 <= q <= 100).
        dim: the dim along which the percentiles are computed. default is to compute the percentile
            along a flattened version of the array.
        keepdim: whether the output data has dim retained or not.
        kwargs: if `x` is numpy array, additional args for `np.percentile`, more details:
            https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.

    Returns:
        The computed percentile(s), converted back to the type of `x` when numpy is used.

    Raises:
        ValueError: when any value of `q` lies outside [0, 100].
    """
    q_np = convert_data_type(q, output_type=np.ndarray, wrap_sequence=True)[0]
    out_of_range = (q_np < 0) | (q_np > 100)
    if out_of_range.any():
        raise ValueError(f"q values must be in [0, 100], got values: {q}.")

    # Tensors with more than one million elements are routed through numpy (pytorch#64947).
    via_numpy = isinstance(x, np.ndarray) or (isinstance(x, torch.Tensor) and torch.numel(x) > 1_000_000)
    if not via_numpy:
        q = convert_to_dst_type(q_np / 100.0, x)[0]
        return torch.quantile(x, q, dim=dim, keepdim=keepdim)

    x_np = convert_data_type(x, output_type=np.ndarray)[0]
    np_result = np.percentile(x_np, q_np, axis=dim, keepdims=keepdim, **kwargs)
    return convert_to_dst_type(np_result, x)[0]
def where(condition: NdarrayOrTensor, x=None, y=None) -> NdarrayOrTensor:
    """`np.where` with equivalent implementation for torch.

    Note that `torch.where` may convert y.dtype to x.dtype.

    Args:
        condition: array/tensor used as the selection mask; its type selects the backend.
        x: values chosen where `condition` holds. When `None`, only `condition` is passed
            to the backend.
        y: values chosen elsewhere; for tensors it is cast to the dtype of `x`.

    Returns:
        Result of the backend's `where` call.
    """
    indices_only = x is None
    if isinstance(condition, np.ndarray):
        return np.where(condition) if indices_only else np.where(condition, x, y)  # type: ignore
    if indices_only:
        return torch.where(condition)  # type: ignore
    device = condition.device
    x = torch.as_tensor(x, device=device)
    y = torch.as_tensor(y, device=device, dtype=x.dtype)
    return torch.where(condition, x, y)
def argwhere(a: NdarrayTensor) -> NdarrayTensor:
    """`np.argwhere` with equivalent implementation for torch.

    Args:
        a: input data.

    Returns:
        Indices of elements that are non-zero. Indices are grouped by element.
        This array will have shape (N, a.ndim) where N is the number of non-zero items.
    """
    backend = np if isinstance(a, np.ndarray) else torch
    return backend.argwhere(a)  # type: ignore
def argsort(a: NdarrayTensor, axis: int | None = -1) -> NdarrayTensor:
    """`np.argsort` with equivalent implementation for torch.

    Args:
        a: the array/tensor to sort.
        axis: axis along which to sort. If `None`, the flattened input is sorted
            (for both numpy arrays and tensors).

    Returns:
        Array/Tensor of indices that sort a along the specified axis.
    """
    if isinstance(a, np.ndarray):
        return np.argsort(a, axis=axis)  # type: ignore
    if axis is None:
        # `torch.argsort` needs an integer `dim`; flatten first to match numpy's `axis=None`.
        return torch.argsort(a.reshape(-1), dim=0)  # type: ignore
    return torch.argsort(a, dim=axis)  # type: ignore
def nonzero(x: NdarrayOrTensor) -> NdarrayOrTensor:
    """`np.nonzero` with equivalent implementation for torch.

    Args:
        x: array/tensor.

    Returns:
        For numpy input, the first element of the tuple returned by `np.nonzero`;
        for tensors, the flattened output of `torch.nonzero`.
    """
    if not isinstance(x, np.ndarray):
        return torch.nonzero(x).flatten()
    first_axis_indices = np.nonzero(x)[0]
    return first_axis_indices
def floor_divide(a: NdarrayOrTensor, b) -> NdarrayOrTensor:
    """`np.floor_divide` with equivalent implementation for torch.

    As of pt1.8, use `torch.div(..., rounding_mode="floor")`, and
    before that, use `torch.floor_divide`.

    Args:
        a: first array/tensor
        b: scalar to divide by

    Returns:
        Element-wise floor division between two arrays/tensors.
    """
    if not isinstance(a, torch.Tensor):
        return np.floor_divide(a, b)
    if not is_module_ver_at_least(torch, (1, 8, 0)):
        return torch.floor_divide(a, b)
    return torch.div(a, b, rounding_mode="floor")
def unravel_index(idx, shape) -> NdarrayOrTensor:
    """`np.unravel_index` with equivalent implementation for torch.

    Args:
        idx: index to unravel.
        shape: shape of array/tensor.

    Returns:
        Index unravelled for given shape
    """
    if not isinstance(idx, torch.Tensor):
        return np.asarray(np.unravel_index(idx, shape))

    # Peel coordinates off from the last dimension to the first, prepending each one so
    # the list ends up in the same order as `shape`.
    remaining = idx
    coords = []
    for size in reversed(shape):
        coords.insert(0, remaining % size)
        remaining = floor_divide(remaining, size)
    return torch.stack(coords)
def unravel_indices(idx, shape) -> NdarrayOrTensor:
    """Computing unravel coordinates from indices.

    Args:
        idx: a sequence of indices to unravel.
        shape: shape of array/tensor.

    Returns:
        Stacked indices unravelled for given shape
    """
    # The type of the first index selects the stacking backend for the whole sequence.
    use_torch = isinstance(idx[0], torch.Tensor)
    unravelled = [unravel_index(single, shape) for single in idx]
    if use_torch:
        return torch.stack(unravelled)  # type: ignore
    return np.stack(unravelled)  # type: ignore
def ravel(x: NdarrayOrTensor) -> NdarrayOrTensor:
    """`np.ravel` with equivalent implementation for torch.

    Args:
        x: array/tensor to ravel.

    Returns:
        Return a contiguous flattened array/tensor.
    """
    if not isinstance(x, torch.Tensor):
        return np.ravel(x)
    # `ravel` is new in torch 1.8.0; fall back to flatten + contiguous when it is missing.
    if hasattr(torch, "ravel"):
        return x.ravel()
    return x.flatten().contiguous()
def any_np_pt(x: NdarrayOrTensor, axis: int | Sequence[int]) -> NdarrayOrTensor:
    """`np.any` with equivalent implementation for torch.

    For pytorch, convert to boolean for compatibility with older versions.

    Args:
        x: input array/tensor.
        axis: axis (or axes) to perform `any` over.

    Returns:
        Array/tensor with the requested axes reduced by logical OR.

    Raises:
        IndexError: if `x` is a tensor and an axis is outside `[-x.ndim, x.ndim - 1]`.
    """
    if isinstance(x, np.ndarray):
        return np.any(x, axis)  # type: ignore

    # pytorch can't handle multiple dimensions to `any` so loop across them
    axes = list(axis) if isinstance(axis, Sequence) else [axis]
    ndim = x.ndim
    if ndim > 0:
        for ax in axes:
            if not -ndim <= ax < ndim:
                raise IndexError(
                    f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {ax})"
                )
        # Each reduction removes a dimension, which shifts the index of every later
        # dimension. Normalise to non-negative axes and reduce from the highest down so
        # the remaining axes still refer to the dimensions the caller asked for.
        axes = sorted({ax % ndim for ax in axes}, reverse=True)

    for ax in axes:
        try:
            x = torch.any(x, ax)
        except RuntimeError:
            # older versions of pytorch require the input to be cast to boolean
            x = torch.any(x.bool(), ax)
    return x
def maximum(a: NdarrayOrTensor, b: NdarrayOrTensor) -> NdarrayOrTensor:
    """`np.maximum` with equivalent implementation for torch.

    `torch.maximum` is used only when both inputs are tensors; every other combination
    is handed to `np.maximum`.

    Args:
        a: first array/tensor.
        b: second array/tensor.

    Returns:
        Element-wise maximum between two arrays/tensors.
    """
    both_tensors = isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)
    if not both_tensors:
        return np.maximum(a, b)
    return torch.maximum(a, b)
def concatenate(to_cat: Sequence[NdarrayOrTensor], axis: int = 0, out=None) -> NdarrayOrTensor:
    """`np.concatenate` with equivalent implementation for torch (`torch.cat`).

    The type of the first element of `to_cat` selects the backend.

    Args:
        to_cat: sequence of arrays/tensors to join.
        axis: axis along which the inputs are joined.
        out: optional output destination forwarded to the backend.
    """
    first = to_cat[0]
    if not isinstance(first, np.ndarray):
        return torch.cat(to_cat, dim=axis, out=out)  # type: ignore
    return np.concatenate(to_cat, axis, out)  # type: ignore
def cumsum(a: NdarrayOrTensor, axis=None, **kwargs) -> NdarrayOrTensor:
    """
    `np.cumsum` with equivalent implementation for torch.

    Args:
        a: input data to compute cumsum.
        axis: expected axis to compute cumsum. If `None`, the cumulative sum is computed
            over the flattened input (for both numpy arrays and tensors).
        kwargs: if `a` is PyTorch Tensor, additional args for `torch.cumsum`, more details:
            https://pytorch.org/docs/stable/generated/torch.cumsum.html.

    Returns:
        Cumulative sum of `a` along `axis`.
    """
    if isinstance(a, np.ndarray):
        return np.cumsum(a, axis)  # type: ignore
    if axis is None:
        # numpy flattens when no axis is given; mirror that for tensors (slicing with
        # `a[:]` keeps the original shape and fails on 0-d tensors).
        return torch.cumsum(a.reshape(-1), 0, **kwargs)
    return torch.cumsum(a, dim=axis, **kwargs)
def isfinite(x: NdarrayOrTensor) -> NdarrayOrTensor:
    """`np.isfinite` with equivalent implementation for torch.

    Args:
        x: array/tensor; anything that is not a tensor is handed to `np.isfinite`.
    """
    if isinstance(x, torch.Tensor):
        return torch.isfinite(x)
    return np.isfinite(x)  # type: ignore
def searchsorted(a: NdarrayTensor, v: NdarrayOrTensor, right=False, sorter=None, **kwargs) -> NdarrayTensor:
    """
    `np.searchsorted` with equivalent implementation for torch.

    Args:
        a: numpy array or tensor, containing monotonically increasing sequence on the innermost dimension.
        v: containing the search values.
        right: if False, return the first suitable location that is found, if True, return the last such index.
        sorter: if `a` is numpy array, optional array of integer indices that sort array `a` into ascending order.
        kwargs: if `a` is PyTorch Tensor, additional args for `torch.searchsorted`, more details:
            https://pytorch.org/docs/stable/generated/torch.searchsorted.html.

    Returns:
        Insertion indices as returned by the selected backend.
    """
    if not isinstance(a, np.ndarray):
        return torch.searchsorted(a, v, right=right, **kwargs)  # type: ignore
    # numpy expresses the same choice through its `side` argument.
    side = "right" if right else "left"
    return np.searchsorted(a, v, side, sorter)  # type: ignore
def repeat(a: NdarrayOrTensor, repeats: int, axis: int | None = None, **kwargs) -> NdarrayOrTensor:
    """
    `np.repeat` with equivalent implementation for torch (`repeat_interleave`).

    Args:
        a: input data to repeat.
        repeats: number of repetitions for each element, repeats is broadcast to fit the shape of the given axis.
        axis: axis along which to repeat values.
        kwargs: if `a` is PyTorch Tensor, additional args for `torch.repeat_interleave`, more details:
            https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html.

    Returns:
        Array/tensor with the elements of `a` repeated.
    """
    if not isinstance(a, np.ndarray):
        return torch.repeat_interleave(a, repeats, dim=axis, **kwargs)
    return np.repeat(a, repeats, axis)
def isnan(x: NdarrayOrTensor) -> NdarrayOrTensor:
    """`np.isnan` with equivalent implementation for torch.

    Args:
        x: array/tensor.

    Returns:
        Boolean array/tensor marking the NaN entries of `x`.
    """
    if not isinstance(x, np.ndarray):
        return torch.isnan(x)
    return np.isnan(x)  # type: ignore
T = TypeVar("T")


def ascontiguousarray(x: NdarrayTensor | T, **kwargs) -> NdarrayOrTensor | T:
    """`np.ascontiguousarray` with equivalent implementation for torch (`contiguous`).

    Zero-dimensional numpy arrays and inputs that are neither numpy arrays nor tensors
    are returned unchanged.

    Args:
        x: array/tensor.
        kwargs: if `x` is PyTorch Tensor, additional args for `torch.contiguous`, more details:
            https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html.
    """
    if isinstance(x, torch.Tensor):
        return x.contiguous(**kwargs)
    if not isinstance(x, np.ndarray):
        return x
    return x if x.ndim == 0 else np.ascontiguousarray(x)
def stack(x: Sequence[NdarrayTensor], dim: int) -> NdarrayTensor:
    """`np.stack` with equivalent implementation for torch.

    The type of the first element of `x` selects the backend.

    Args:
        x: array/tensor.
        dim: dimension along which to perform the stack (referred to as `axis` by numpy).
    """
    head = x[0]
    if not isinstance(head, np.ndarray):
        return torch.stack(x, dim)  # type: ignore
    return np.stack(x, dim)  # type: ignore
def mode(x: NdarrayTensor, dim: int = -1, to_long: bool = True) -> NdarrayTensor:
    """`torch.mode` with equivalent implementation for numpy.

    The input is converted to a tensor, `torch.mode` is applied, and the resulting values
    are converted back with `convert_to_dst_type` using `x` as the destination.

    Args:
        x: array/tensor.
        dim: dimension along which to perform `mode` (referred to as `axis` by numpy).
        to_long: convert input to long before performing mode.
    """
    target_dtype = None
    if to_long:
        target_dtype = torch.int64
    as_tensor = convert_data_type(x, torch.Tensor, dtype=target_dtype)[0]
    modal_values = torch.mode(as_tensor, dim).values
    return convert_to_dst_type(modal_values, x)[0]
def unique(x: NdarrayTensor, **kwargs) -> NdarrayTensor:
    """`torch.unique` with equivalent implementation for numpy.

    Args:
        x: array/tensor. Numpy arrays and lists go to `np.unique`, anything else to `torch.unique`.
        kwargs: additional arguments forwarded unchanged to the selected backend.
    """
    if isinstance(x, (np.ndarray, list)):
        return np.unique(x, **kwargs)  # type: ignore
    return torch.unique(x, **kwargs)  # type: ignore
def linalg_inv(x: NdarrayTensor) -> NdarrayTensor:
    """`torch.linalg.inv` with equivalent implementation for numpy.

    Args:
        x: array/tensor.

    Returns:
        Inverse of `x`, computed with `np.linalg.inv` for non-tensor input and with
        `torch.inverse` (or `torch.linalg.inv` when `torch.inverse` is unavailable) for tensors.
    """
    if not isinstance(x, torch.Tensor):
        return np.linalg.inv(x)  # type: ignore
    if hasattr(torch, "inverse"):  # pytorch 1.7.0
        return torch.inverse(x)  # type: ignore
    return torch.linalg.inv(x)  # type: ignore
def max(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor:
    """`torch.max` with equivalent implementation for numpy

    Args:
        x: array/tensor. Numpy arrays and lists are handled by `np.max`, anything else by torch.
        dim: dimension(s) to reduce (referred to as `axis` by numpy). `None` reduces over all elements.
        kwargs: additional arguments forwarded to the selected backend function.

    Returns:
        the maximum of x. Only the values are returned, never the indices.
    """
    if isinstance(x, (np.ndarray, list)):
        return np.max(x, **kwargs) if dim is None else np.max(x, axis=dim, **kwargs)  # type: ignore
    if dim is None:
        return torch.max(x, **kwargs)  # type: ignore
    if isinstance(dim, tuple):
        # `torch.max` only accepts a single dim; `torch.amax` reduces over several at once.
        return torch.amax(x, dim=tuple(int(d) for d in dim), **kwargs)  # type: ignore
    ret = torch.max(x, int(dim), **kwargs)
    # `torch.max` with a dim yields (values, indices); keep the values to match numpy.
    return ret[0] if isinstance(ret, tuple) else ret  # type: ignore
def mean(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor:
    """`torch.mean` with equivalent implementation for numpy

    Args:
        x: array/tensor. Numpy arrays and lists are handled by `np.mean`, anything else by torch.
        dim: dimension(s) to reduce (referred to as `axis` by numpy). `None` reduces over all elements.
        kwargs: additional arguments forwarded to the selected backend function.

    Returns:
        the mean of x
    """
    if isinstance(x, (np.ndarray, list)):
        return np.mean(x, **kwargs) if dim is None else np.mean(x, axis=dim, **kwargs)  # type: ignore
    if dim is None:
        return torch.mean(x, **kwargs)  # type: ignore
    # `torch.mean` accepts several dims at once, so a tuple is passed through instead of
    # being forced through `int()`.
    reduce_dims = tuple(int(d) for d in dim) if isinstance(dim, tuple) else int(dim)
    return torch.mean(x, reduce_dims, **kwargs)  # type: ignore
def median(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor:
    """`torch.median` with equivalent implementation for numpy

    Note that, per the PyTorch documentation, `torch.median` returns the lower of the two
    middle values for an even number of elements, whereas `np.median` averages them.

    Args:
        x: array/tensor. Numpy arrays and lists are handled by `np.median`, anything else by torch.
        dim: dimension to reduce (referred to as `axis` by numpy). `None` reduces over all elements.
        kwargs: additional arguments forwarded to the selected backend function.

    Returns:
        the median of x. Only the values are returned, never the indices.
    """
    if isinstance(x, (np.ndarray, list)):
        return np.median(x, **kwargs) if dim is None else np.median(x, axis=dim, **kwargs)  # type: ignore
    if dim is None:
        return torch.median(x, **kwargs)  # type: ignore
    ret = torch.median(x, int(dim), **kwargs)
    # `torch.median` with a dim yields (values, indices); keep the values to match numpy.
    return ret[0] if isinstance(ret, tuple) else ret  # type: ignore
def min(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor:
    """`torch.min` with equivalent implementation for numpy

    Args:
        x: array/tensor. Numpy arrays and lists are handled by `np.min`, anything else by torch.
        dim: dimension(s) to reduce (referred to as `axis` by numpy). `None` reduces over all elements.
        kwargs: additional arguments forwarded to the selected backend function.

    Returns:
        the minimum of x. Only the values are returned, never the indices.
    """
    if isinstance(x, (np.ndarray, list)):
        return np.min(x, **kwargs) if dim is None else np.min(x, axis=dim, **kwargs)  # type: ignore
    if dim is None:
        return torch.min(x, **kwargs)  # type: ignore
    if isinstance(dim, tuple):
        # `torch.min` only accepts a single dim; `torch.amin` reduces over several at once.
        return torch.amin(x, dim=tuple(int(d) for d in dim), **kwargs)  # type: ignore
    ret = torch.min(x, int(dim), **kwargs)
    # `torch.min` with a dim yields (values, indices); keep the values to match numpy.
    return ret[0] if isinstance(ret, tuple) else ret  # type: ignore
def std(x: NdarrayTensor, dim: int | tuple | None = None, unbiased: bool = False) -> NdarrayTensor:
    """`torch.std` with equivalent implementation for numpy

    Args:
        x: array/tensor. Numpy arrays and lists are handled by `np.std`, anything else by torch.
        dim: dimension(s) to reduce (referred to as `axis` by numpy). `None` reduces over all elements.
        unbiased: whether to use the unbiased estimator (divide by N - 1 instead of N).
            Honoured by both the numpy and the torch implementation.

    Returns:
        the standard deviation of x.
    """
    if isinstance(x, (np.ndarray, list)):
        # numpy expresses the unbiased estimator through `ddof` (delta degrees of freedom).
        ddof = 1 if unbiased else 0
        return np.std(x, ddof=ddof) if dim is None else np.std(x, axis=dim, ddof=ddof)  # type: ignore
    if dim is None:
        return torch.std(x, unbiased=unbiased)  # type: ignore
    # `torch.std` accepts several dims at once, so a tuple is passed through instead of
    # being forced through `int()`.
    reduce_dims = tuple(int(d) for d in dim) if isinstance(dim, tuple) else int(dim)
    return torch.std(x, dim=reduce_dims, unbiased=unbiased)  # type: ignore
def sum(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor:
    """`torch.sum` with equivalent implementation for numpy

    Args:
        x: array/tensor. Numpy arrays and lists are handled by `np.sum`, anything else by torch.
        dim: dimension(s) to reduce (referred to as `axis` by numpy). `None` reduces over all elements.
        kwargs: additional arguments forwarded to the selected backend function.

    Returns:
        the sum of x.
    """
    if isinstance(x, (np.ndarray, list)):
        return np.sum(x, **kwargs) if dim is None else np.sum(x, axis=dim, **kwargs)  # type: ignore
    if dim is None:
        return torch.sum(x, **kwargs)  # type: ignore
    # `torch.sum` accepts several dims at once, so a tuple is passed through instead of
    # being forced through `int()`.
    reduce_dims = tuple(int(d) for d in dim) if isinstance(dim, tuple) else int(dim)
    return torch.sum(x, reduce_dims, **kwargs)  # type: ignore