# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Sequence, Union
import numpy as np
import torch
from monai.config.type_definitions import NdarrayOrTensor
from monai.utils.misc import is_module_ver_at_least
__all__ = [
"moveaxis",
"in1d",
"clip",
"percentile",
"where",
"nonzero",
"floor_divide",
"unravel_index",
"unravel_indices",
"ravel",
"any_np_pt",
"maximum",
"concatenate",
"cumsum",
"isfinite",
"searchsorted",
"repeat",
"isnan",
]
[docs]def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor:
"""`moveaxis` for pytorch and numpy, using `permute` for pytorch ver < 1.8"""
if isinstance(x, torch.Tensor):
if hasattr(torch, "moveaxis"):
return torch.moveaxis(x, src, dst)
return _moveaxis_with_permute(x, src, dst) # type: ignore
if isinstance(x, np.ndarray):
return np.moveaxis(x, src, dst)
raise RuntimeError()
def _moveaxis_with_permute(x, src, dst):
# get original indices
indices = list(range(x.ndim))
# make src and dst positive
if src < 0:
src = len(indices) + src
if dst < 0:
dst = len(indices) + dst
# remove desired index and insert it in new position
indices.pop(src)
indices.insert(dst, src)
return x.permute(indices)
[docs]def in1d(x, y):
"""`np.in1d` with equivalent implementation for torch."""
if isinstance(x, np.ndarray):
return np.in1d(x, y)
return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1)
[docs]def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor:
"""`np.clip` with equivalent implementation for torch."""
result: NdarrayOrTensor
if isinstance(a, np.ndarray):
result = np.clip(a, a_min, a_max)
else:
result = torch.clamp(a, a_min, a_max)
return result
[docs]def percentile(x: NdarrayOrTensor, q) -> Union[NdarrayOrTensor, float, int]:
"""`np.percentile` with equivalent implementation for torch.
Pytorch uses `quantile`, but this functionality is only available from v1.7.
For earlier methods, we calculate it ourselves. This doesn't do interpolation,
so is the equivalent of ``numpy.percentile(..., interpolation="nearest")``.
Args:
x: input data
q: percentile to compute (should in range 0 <= q <= 100)
Returns:
Resulting value (scalar)
"""
if np.isscalar(q):
if not 0 <= q <= 100:
raise ValueError
elif any(q < 0) or any(q > 100):
raise ValueError
result: Union[NdarrayOrTensor, float, int]
if isinstance(x, np.ndarray):
result = np.percentile(x, q)
else:
q = torch.tensor(q, device=x.device)
if hasattr(torch, "quantile"):
result = torch.quantile(x, q / 100.0)
else:
# Note that ``kthvalue()`` works one-based, i.e., the first sorted value
# corresponds to k=1, not k=0. Thus, we need the `1 +`.
k = 1 + (0.01 * q * (x.numel() - 1)).round().int()
if k.numel() > 1:
r = [x.view(-1).kthvalue(int(_k)).values.item() for _k in k]
result = torch.tensor(r, device=x.device)
else:
result = x.view(-1).kthvalue(int(k)).values.item()
return result
[docs]def where(condition: NdarrayOrTensor, x=None, y=None) -> NdarrayOrTensor:
"""
Note that `torch.where` may convert y.dtype to x.dtype.
"""
result: NdarrayOrTensor
if isinstance(condition, np.ndarray):
if x is not None:
result = np.where(condition, x, y)
else:
result = np.where(condition)
else:
if x is not None:
x = torch.as_tensor(x, device=condition.device)
y = torch.as_tensor(y, device=condition.device, dtype=x.dtype)
result = torch.where(condition, x, y)
else:
result = torch.where(condition) # type: ignore
return result
[docs]def nonzero(x: NdarrayOrTensor):
"""`np.nonzero` with equivalent implementation for torch.
Args:
x: array/tensor
Returns:
Index unravelled for given shape
"""
if isinstance(x, np.ndarray):
return np.nonzero(x)[0]
return torch.nonzero(x).flatten()
[docs]def floor_divide(a: NdarrayOrTensor, b) -> NdarrayOrTensor:
"""`np.floor_divide` with equivalent implementation for torch.
As of pt1.8, use `torch.div(..., rounding_mode="floor")`, and
before that, use `torch.floor_divide`.
Args:
a: first array/tensor
b: scalar to divide by
Returns:
Element-wise floor division between two arrays/tensors.
"""
if isinstance(a, torch.Tensor):
if is_module_ver_at_least(torch, (1, 8, 0)):
return torch.div(a, b, rounding_mode="floor")
return torch.floor_divide(a, b)
return np.floor_divide(a, b)
[docs]def unravel_index(idx, shape):
"""`np.unravel_index` with equivalent implementation for torch.
Args:
idx: index to unravel
shape: shape of array/tensor
Returns:
Index unravelled for given shape
"""
if isinstance(idx, torch.Tensor):
coord = []
for dim in reversed(shape):
coord.append(idx % dim)
idx = floor_divide(idx, dim)
return torch.stack(coord[::-1])
return np.asarray(np.unravel_index(idx, shape))
[docs]def unravel_indices(idx, shape):
"""Computing unravel coordinates from indices.
Args:
idx: a sequence of indices to unravel
shape: shape of array/tensor
Returns:
Stacked indices unravelled for given shape
"""
lib_stack = torch.stack if isinstance(idx[0], torch.Tensor) else np.stack
return lib_stack([unravel_index(i, shape) for i in idx])
[docs]def ravel(x: NdarrayOrTensor):
"""`np.ravel` with equivalent implementation for torch.
Args:
x: array/tensor to ravel
Returns:
Return a contiguous flattened array/tensor.
"""
if isinstance(x, torch.Tensor):
if hasattr(torch, "ravel"):
return x.ravel()
return x.flatten().contiguous()
return np.ravel(x)
[docs]def any_np_pt(x: NdarrayOrTensor, axis: Union[int, Sequence[int]]):
"""`np.any` with equivalent implementation for torch.
For pytorch, convert to boolean for compatibility with older versions.
Args:
x: input array/tensor
axis: axis to perform `any` over
Returns:
Return a contiguous flattened array/tensor.
"""
if isinstance(x, np.ndarray):
return np.any(x, axis)
# pytorch can't handle multiple dimensions to `any` so loop across them
axis = [axis] if not isinstance(axis, Sequence) else axis
for ax in axis:
try:
x = torch.any(x, ax)
except RuntimeError:
# older versions of pytorch require the input to be cast to boolean
x = torch.any(x.bool(), ax)
return x
[docs]def maximum(a: NdarrayOrTensor, b: NdarrayOrTensor) -> NdarrayOrTensor:
"""`np.maximum` with equivalent implementation for torch.
`torch.maximum` only available from pt>1.6, else use `torch.stack` and `torch.max`.
Args:
a: first array/tensor
b: second array/tensor
Returns:
Element-wise maximum between two arrays/tensors.
"""
if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
# is torch and has torch.maximum (pt>1.6)
if hasattr(torch, "maximum"):
return torch.maximum(a, b)
return torch.stack((a, b)).max(dim=0)[0]
return np.maximum(a, b)
[docs]def concatenate(to_cat: Sequence[NdarrayOrTensor], axis: int = 0, out=None) -> NdarrayOrTensor:
"""`np.concatenate` with equivalent implementation for torch (`torch.cat`)."""
if isinstance(to_cat[0], np.ndarray):
return np.concatenate(to_cat, axis, out) # type: ignore
return torch.cat(to_cat, dim=axis, out=out) # type: ignore
[docs]def cumsum(a: NdarrayOrTensor, axis=None):
"""`np.cumsum` with equivalent implementation for torch."""
if isinstance(a, np.ndarray):
return np.cumsum(a, axis)
if axis is None:
return torch.cumsum(a[:], 0)
return torch.cumsum(a, dim=axis)
[docs]def isfinite(x):
"""`np.isfinite` with equivalent implementation for torch."""
if not isinstance(x, torch.Tensor):
return np.isfinite(x)
return torch.isfinite(x)
def searchsorted(a: NdarrayOrTensor, v: NdarrayOrTensor, right=False, sorter=None):
side = "right" if right else "left"
if isinstance(a, np.ndarray):
return np.searchsorted(a, v, side, sorter) # type: ignore
return torch.searchsorted(a, v, right=right) # type: ignore
[docs]def repeat(a: NdarrayOrTensor, repeats: int, axis: Optional[int] = None):
"""`np.repeat` with equivalent implementation for torch (`repeat_interleave`)."""
if isinstance(a, np.ndarray):
return np.repeat(a, repeats, axis)
return torch.repeat_interleave(a, repeats, dim=axis)
[docs]def isnan(x: NdarrayOrTensor):
"""`np.isnan` with equivalent implementation for torch.
Args:
x: array/tensor
"""
if isinstance(x, np.ndarray):
return np.isnan(x)
return torch.isnan(x)