# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A collection of "functional" transforms for spatial operations.
"""
from __future__ import annotations
import math
import warnings
from enum import Enum
import numpy as np
import torch
import monai
from monai.config import USE_COMPILED
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd
from monai.networks.layers import AffineTransform
from monai.transforms.croppad.array import ResizeWithPadOrCrop
from monai.transforms.intensity.array import GaussianSmooth
from monai.transforms.inverse import TraceableTransform
from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
from monai.transforms.utils_pytorch_numpy_unification import allclose
from monai.utils import (
LazyAttr,
TraceKeys,
convert_to_dst_type,
convert_to_numpy,
convert_to_tensor,
ensure_tuple,
ensure_tuple_rep,
fall_back_tuple,
optional_import,
)
nib, has_nib = optional_import("nibabel")
cupy, _ = optional_import("cupy")
cupy_ndi, _ = optional_import("cupyx.scipy.ndimage")
np_ndi, _ = optional_import("scipy.ndimage")
__all__ = ["spatial_resample", "orientation", "flip", "resize", "rotate", "zoom", "rotate90", "affine_func"]
def _maybe_new_metatensor(img, dtype=None, device=None):
    """
    Convert ``img`` into a tensor that does not carry over the input's existing metadata.

    A ``MetaTensor`` input is first reduced to its plain tensor via ``as_tensor()``. The data is then passed
    through ``convert_to_tensor`` with ``track_meta=get_track_meta()``, so whether the result tracks metadata
    follows the global setting.

    Args:
        img: input data; ``MetaTensor`` instances are unwrapped before conversion.
        dtype: optional target data type, forwarded to ``convert_to_tensor``.
        device: optional target device, forwarded to ``convert_to_tensor``.
    """
    plain = img.as_tensor() if isinstance(img, MetaTensor) else img
    return convert_to_tensor(plain, dtype=dtype, device=device, track_meta=get_track_meta(), wrap_sequence=True)
[docs]
def spatial_resample(
    img, dst_affine, spatial_size, mode, padding_mode, align_corners, dtype_pt, lazy, transform_info
) -> torch.Tensor:
    """
    Functional implementation of resampling the input image to the specified ``dst_affine`` matrix and ``spatial_size``.
    This function operates eagerly or lazily according to ``lazy``.

    Args:
        img: data to be resampled, assuming `img` is channel-first.
        dst_affine: target affine matrix, if None, use the input affine matrix, effectively no resampling.
        spatial_size: output spatial size, if the component is ``-1``, use the corresponding input spatial size.
            When it is neither ``None`` nor the integer ``-1``, its length (capped at 3) determines the spatial rank.
        mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
            Interpolation mode to calculate output values.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
            and the value represents the order of the spline interpolation.
            See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
        padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
            Padding mode for outside grid values.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            When `mode` is an integer, using numpy/cupy backends, this argument accepts
            {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
            See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
        align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        dtype_pt: data `dtype` for resampling computation.
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.

    Returns:
        In lazy mode, the unmodified data with the tracked metadata attached (or the tracked metadata itself
        when the data is not a ``MetaTensor``). Otherwise the data converted to ``float32``, resampled unless
        the affine and spatial size are detected as unchanged.

    Raises:
        ValueError: when ``dst_affine`` is not a ``torch.Tensor`` after conversion, or when solving for the
            source-to-destination transform fails.
    """
    # read the (possibly pending) shape and affine before `img` is converted
    original_spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
    src_affine: torch.Tensor = img.peek_pending_affine() if isinstance(img, MetaTensor) else torch.eye(4)
    img = convert_to_tensor(data=img, track_meta=get_track_meta())
    # ensure spatial rank is <= 3
    spatial_rank = min(len(img.shape) - 1, src_affine.shape[0] - 1, 3)
    if (not isinstance(spatial_size, int) or spatial_size != -1) and spatial_size is not None:
        spatial_rank = min(len(ensure_tuple(spatial_size)), 3)  # infer spatial rank based on spatial_size
    # bring both affines to the same (spatial_rank + 1) square size; dst falls back to src when not given
    src_affine = to_affine_nd(spatial_rank, src_affine).to(torch.float64)
    dst_affine = to_affine_nd(spatial_rank, dst_affine) if dst_affine is not None else src_affine
    dst_affine = convert_to_dst_type(dst_affine, src_affine)[0]
    if not isinstance(dst_affine, torch.Tensor):
        raise ValueError(f"dst_affine should be a torch.Tensor, got {type(dst_affine)}")

    in_spatial_size = torch.tensor(original_spatial_shape[:spatial_rank])
    if isinstance(spatial_size, int) and (spatial_size == -1):  # using the input spatial size
        spatial_size = in_spatial_size
    elif spatial_size is None and spatial_rank > 1:  # auto spatial size
        spatial_size, _ = compute_shape_offset(in_spatial_size, src_affine, dst_affine)  # type: ignore
    # components failing the `x >= 0` check are replaced with the corresponding input size
    spatial_size = torch.tensor(
        fall_back_tuple(ensure_tuple(spatial_size)[:spatial_rank], in_spatial_size, lambda x: x >= 0)
    )
    # parameters recorded alongside the tracked transform
    extra_info = {
        "dtype": str(dtype_pt)[6:],  # remove "torch": torch.float32 -> float32
        "mode": mode.value if isinstance(mode, Enum) else mode,
        "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode,
        "align_corners": align_corners if align_corners is not None else TraceKeys.NONE,
        "src_affine": src_affine,
    }
    try:
        _s = convert_to_numpy(src_affine)
        _d = convert_to_numpy(dst_affine)
        # solve `_s @ xform = _d`; for spatial_rank < 2 an identity matrix is used instead
        xform = np.eye(spatial_rank + 1) if spatial_rank < 2 else np.linalg.solve(_s, _d)
    except (np.linalg.LinAlgError, RuntimeError) as e:
        raise ValueError(f"src affine is not invertible {_s}, {_d}.") from e
    xform = convert_to_tensor(to_affine_nd(spatial_rank, xform)).to(device=img.device, dtype=torch.float64)
    # "unchanged" means the spatial size is the same and either the two affines match
    # or the relative transform is the identity (both within AFFINE_TOL)
    affine_unchanged = (
        allclose(src_affine, dst_affine, atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size)
    ) or (allclose(xform, np.eye(len(xform)), atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size))
    meta_info = TraceableTransform.track_transform_meta(
        img,
        sp_size=spatial_size,
        # no affine is passed only for an eager call that leaves the image unchanged
        affine=None if affine_unchanged and not lazy else xform,
        extra_info=extra_info,
        orig_size=original_spatial_shape,
        transform_info=transform_info,
        lazy=lazy,
    )
    if lazy:
        # lazy mode: return the data untouched, only the tracked metadata is attached
        out = _maybe_new_metatensor(img)
        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info  # type: ignore
    if affine_unchanged:
        # no significant change, return the original image data (converted to float32)
        out = _maybe_new_metatensor(img, dtype=torch.float32)
        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out  # type: ignore
    # drop current meta first
    img = img.as_tensor() if isinstance(img, MetaTensor) else img
    im_size = list(img.shape)
    chns, in_sp_size, additional_dims = im_size[0], im_size[1 : spatial_rank + 1], im_size[spatial_rank + 1 :]

    if additional_dims:
        # reshape to (-1, *in_sp_size) so that only `spatial_rank` spatial dims are resampled;
        # the original layout is restored after resampling
        xform_shape = [-1] + in_sp_size
        img = img.reshape(xform_shape)
    img = img.to(dtype_pt)
    if isinstance(mode, int) or USE_COMPILED:
        # integer (spline order) modes and the compiled backend go through `monai.transforms.Affine`;
        # the transform is first composed with a translation of (d - 1) / 2 per output axis
        dst_xform = create_translate(spatial_rank, [float(d - 1) / 2 for d in spatial_size])
        xform = xform @ convert_to_dst_type(dst_xform, xform)[0]
        affine_xform = monai.transforms.Affine(
            affine=xform,
            spatial_size=spatial_size,
            normalized=True,
            image_only=True,
            dtype=dtype_pt,
            align_corners=align_corners,
        )
        # tracing is switched off for this inner call; metadata was already tracked above
        with affine_xform.trace_transform(False):
            img = affine_xform(img, mode=mode, padding_mode=padding_mode)
    else:
        _, _m, _p, _ = resolves_modes(mode, padding_mode)
        affine_xform = AffineTransform(  # type: ignore
            normalized=False, mode=_m, padding_mode=_p, align_corners=align_corners, reverse_indexing=True
        )
        # a leading batch dim is added for the layer and removed again afterwards
        img = affine_xform(img.unsqueeze(0), theta=xform.to(img), spatial_size=spatial_size).squeeze(0)  # type: ignore
    if additional_dims:
        # restore the channel and trailing dims that were folded before resampling
        full_shape = (chns, *spatial_size, *additional_dims)
        img = img.reshape(full_shape)
    out = _maybe_new_metatensor(img, dtype=torch.float32)
    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out  # type: ignore
[docs]
def orientation(img, original_affine, spatial_ornt, lazy, transform_info) -> torch.Tensor:
    """
    Functional implementation of changing the input image's orientation into the specified based on `spatial_ornt`.
    This function operates eagerly or lazily according to ``lazy``.

    Args:
        img: data to be changed, assuming `img` is channel-first.
        original_affine: original affine of the input image, recorded in the tracked transform info.
        spatial_ornt: orientations of the spatial axes,
            see also https://nipy.org/nibabel/reference/nibabel.orientations.html
            The array is not modified by this function.
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.

    Returns:
        In lazy mode, the unmodified data with the tracked metadata attached (or the tracked metadata itself
        when the data is not a ``MetaTensor``). Otherwise the flipped and permuted data.
    """
    spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
    xform = nib.orientations.inv_ornt_aff(spatial_ornt, spatial_shape)
    img = convert_to_tensor(img, track_meta=get_track_meta())

    # work on a private copy: the in-place axis shift below must not leak into the caller's array
    spatial_ornt = np.array(spatial_ornt, copy=True)
    spatial_ornt[:, 0] += 1  # skip channel dim
    spatial_ornt = np.concatenate([np.array([[0, 1]]), spatial_ornt])
    # axes (channel dim included) whose direction entry is -1 need to be flipped
    axes = [ax for ax, direction in enumerate(spatial_ornt[:, 1]) if direction == -1]
    full_transpose = np.arange(len(spatial_shape) + 1)  # channel-first array
    full_transpose[: len(spatial_ornt)] = np.argsort(spatial_ornt[:, 0])
    extra_info = {"original_affine": original_affine}

    # spatial shape after the permutation (channel dim excluded)
    shape_np = convert_to_numpy(spatial_shape, wrap_sequence=True)
    shape_np = shape_np[[i - 1 for i in full_transpose if i > 0]]
    meta_info = TraceableTransform.track_transform_meta(
        img,
        sp_size=shape_np,
        affine=xform,
        extra_info=extra_info,
        orig_size=spatial_shape,
        transform_info=transform_info,
        lazy=lazy,
    )
    out = _maybe_new_metatensor(img)
    if lazy:
        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info  # type: ignore
    if axes:
        out = torch.flip(out, dims=axes)
    # only permute when the axis order actually changes
    if not np.all(full_transpose == np.arange(len(out.shape))):
        out = out.permute(full_transpose.tolist())
    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out  # type: ignore
[docs]
def flip(img, sp_axes, lazy, transform_info):
    """
    Functional implementation of flip.
    This function operates eagerly or lazily according to ``lazy``.

    Args:
        img: data to be changed, assuming `img` is channel-first.
        sp_axes: spatial axes along which to flip over.
            If None, will flip over all of the axes of the input array.
            If axis is negative it counts from the last to the first axis.
            If axis is a tuple of ints, flipping is performed on all of the axes
            specified in the tuple.
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.

    Returns:
        In lazy mode, the unmodified data with the tracked metadata attached (or the tracked metadata itself
        when the data is not a ``MetaTensor``). Otherwise the flipped data.
    """
    is_meta = isinstance(img, MetaTensor)
    pending_shape = img.peek_pending_shape() if is_meta else img.shape[1:]
    sp_size = convert_to_numpy(pending_shape, wrap_sequence=True).tolist()
    # the spatial axes are recorded exactly as they were given
    extra_info = {"axes": sp_axes}
    # the mapped axes include the channel dim
    axes = monai.transforms.utils.map_spatial_axes(img.ndim, sp_axes)
    rank = img.peek_pending_rank() if is_meta else torch.tensor(3.0, dtype=torch.double)

    xform = torch.eye(int(rank) + 1, dtype=torch.double)
    for ch_axis in axes:
        idx = ch_axis - 1  # index of the axis without the channel dim
        # negate the axis and shift by (size - 1)
        xform[idx, idx] = xform[idx, idx] * -1
        xform[idx, -1] = sp_size[idx] - 1

    meta_info = TraceableTransform.track_transform_meta(
        img, sp_size=sp_size, affine=xform, extra_info=extra_info, transform_info=transform_info, lazy=lazy
    )
    out = _maybe_new_metatensor(img)
    if not lazy:
        out = torch.flip(out, axes)
        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
[docs]
def resize(
    img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info
):
    """
    Functional implementation of resize.
    This function operates eagerly or lazily according to ``lazy``.

    Args:
        img: data to be changed, assuming `img` is channel-first.
        out_size: expected shape of spatial dimensions after resize operation (a sequence of ints).
        mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``,
            ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
            The interpolation mode.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
        align_corners: This only has an effect when mode is
            'linear', 'bilinear', 'bicubic' or 'trilinear'.
        dtype: data type for resampling computation. If None, use the data type of input data.
        input_ndim: number of spatial dimensions; the difference to the actual spatial rank is recorded
            as ``"new_dim"`` in the tracked transform info.
        anti_aliasing: whether to apply a Gaussian filter to smooth the image prior
            to downsampling. It is crucial to filter when downsampling
            the image to avoid aliasing artifacts. See also ``skimage.transform.resize``
        anti_aliasing_sigma: {float, tuple of floats}, optional
            Standard deviation for Gaussian filtering used when anti-aliasing.
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.

    Returns:
        In lazy mode, the unmodified data with the tracked metadata attached (or the tracked metadata itself
        when the data is not a ``MetaTensor``). Otherwise the resized data as ``float32``.
    """
    img = convert_to_tensor(img, track_meta=get_track_meta())
    orig_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
    extra_info = {
        "mode": mode,
        "align_corners": align_corners if align_corners is not None else TraceKeys.NONE,
        "dtype": str(dtype)[6:],  # dtype as string; remove "torch": torch.float32 -> float32
        "new_dim": len(orig_size) - input_ndim,
    }
    meta_info = TraceableTransform.track_transform_meta(
        img,
        sp_size=out_size,
        affine=scale_affine(orig_size, out_size),
        extra_info=extra_info,
        orig_size=orig_size,
        transform_info=transform_info,
        lazy=lazy,
    )
    if lazy:
        if anti_aliasing:
            warnings.warn("anti-aliasing is not compatible with lazy evaluation.")
        out = _maybe_new_metatensor(img)
        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
    # compare as tuples so that a list or array `out_size` is also recognised as "size unchanged"
    if tuple(int(s) for s in orig_size) == tuple(out_size):
        out = _maybe_new_metatensor(img, dtype=torch.float32)
        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
    out = _maybe_new_metatensor(img)
    img_ = convert_to_tensor(out, dtype=dtype, track_meta=False)  # convert to a regular tensor
    # smoothing is only applied when at least one axis is downsampled
    if anti_aliasing and any(x < y for x, y in zip(out_size, img_.shape[1:])):
        factors = torch.div(torch.Tensor(list(img_.shape[1:])), torch.Tensor(out_size))
        if anti_aliasing_sigma is None:
            # if sigma is not given, use the default sigma in skimage.transform.resize
            anti_aliasing_sigma = torch.maximum(torch.zeros(factors.shape), (factors - 1) / 2).tolist()
        else:
            # if sigma is given, use the given value for downsampling axis (zero it for the other axes)
            anti_aliasing_sigma = [
                sigma * int(factors[axis] > 1)
                for axis, sigma in enumerate(ensure_tuple_rep(anti_aliasing_sigma, len(out_size)))
            ]
        anti_aliasing_filter = GaussianSmooth(sigma=anti_aliasing_sigma)
        img_ = convert_to_tensor(anti_aliasing_filter(img_), track_meta=False)

    _, _m, _, _ = resolves_modes(mode, torch_interpolate_spatial_nd=len(img_.shape) - 1)
    # a leading batch dim is added for `interpolate` and removed again afterwards
    resized = torch.nn.functional.interpolate(
        input=img_.unsqueeze(0), size=out_size, mode=_m, align_corners=align_corners
    )
    out, *_ = convert_to_dst_type(resized.squeeze(0), out, dtype=torch.float32)
    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
[docs]
def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info):
    """
    Functional implementation of rotate.
    This function operates eagerly or lazily according to ``lazy``.

    Args:
        img: data to be changed, assuming `img` is channel-first.
        angle: Rotation angle(s) in radians. Should be a float for 2D, three floats for 3D.
        output_shape: output shape of the rotated data. If None, it is computed from the extent of the
            rotated corner points of the input.
        mode: {``"bilinear"``, ``"nearest"``}
            Interpolation mode to calculate output values.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
            Padding mode for outside grid values.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        align_corners: See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        dtype: data type for resampling computation.
            If None, use the data type of input data. To be compatible with other modules,
            the output data type is always ``float32``.
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.

    Returns:
        In lazy mode, the unmodified data with the tracked metadata attached (or the tracked metadata itself
        when the data is not a ``MetaTensor``). Otherwise the rotated data as ``float32``.

    Raises:
        ValueError: when the number of spatial dimensions is not 2 or 3.
    """
    im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
    input_ndim = len(im_shape)
    if input_ndim not in (2, 3):
        raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].")
    _angle = ensure_tuple_rep(angle, 1 if input_ndim == 2 else 3)
    transform = create_rotate(input_ndim, _angle)
    if output_shape is None:
        # rotate the corner points of the input and take the per-axis extent as the output shape
        corners = np.asarray(np.meshgrid(*[(0, dim) for dim in im_shape], indexing="ij")).reshape((len(im_shape), -1))
        corners = transform[:-1, :-1] @ corners  # type: ignore
        # `np.ptp` is used because the `ndarray.ptp` method is not available in NumPy >= 2.0
        output_shape = np.asarray(np.ptp(corners, axis=1) + 0.5, dtype=int)
    else:
        output_shape = np.asarray(output_shape, dtype=int)
    # rotate about the image centre: shift by the output centre, rotate, shift back by the input centre
    shift = create_translate(input_ndim, ((np.array(im_shape) - 1) / 2).tolist())
    shift_1 = create_translate(input_ndim, (-(output_shape - 1) / 2).tolist())
    transform = shift @ transform @ shift_1
    extra_info = {
        "rot_mat": transform,
        "mode": mode,
        "padding_mode": padding_mode,
        "align_corners": align_corners if align_corners is not None else TraceKeys.NONE,
        "dtype": str(dtype)[6:],  # dtype as string; remove "torch": torch.float32 -> float32
    }
    meta_info = TraceableTransform.track_transform_meta(
        img,
        sp_size=output_shape,
        affine=transform,
        extra_info=extra_info,
        orig_size=im_shape,
        transform_info=transform_info,
        lazy=lazy,
    )
    out = _maybe_new_metatensor(img)
    if lazy:
        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
    _, _m, _p, _ = resolves_modes(mode, padding_mode)
    xform = AffineTransform(
        normalized=False, mode=_m, padding_mode=_p, align_corners=align_corners, reverse_indexing=True
    )
    img_t = out.to(dtype)
    transform_t, *_ = convert_to_dst_type(transform, img_t)
    # a leading batch dim is added for the layer and removed again afterwards
    output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=tuple(int(i) for i in output_shape))
    output = output.float().squeeze(0)
    out, *_ = convert_to_dst_type(output, dst=out, dtype=torch.float32)
    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
[docs]
def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, lazy, transform_info):
    """
    Functional implementation of zoom.
    This function operates eagerly or lazily according to ``lazy``.

    Args:
        img: data to be changed, assuming `img` is channel-first.
        scale_factor: The zoom factor along the spatial axes.
            It is zipped with the spatial shape and passed through ``list(...)`` below, so a sequence
            with one value for each spatial axis is expected here.
            NOTE(review): a bare float is not iterable and would fail in this function; presumably the
            caller expands a scalar beforehand -- confirm.
        keep_size: Whether keep original size (padding/slicing if needed).
        mode: {``"bilinear"``, ``"nearest"``}
            Interpolation mode to calculate output values.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
            Padding mode for outside grid values; it is passed as ``mode`` to ``ResizeWithPadOrCrop``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        align_corners: See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        dtype: data type for resampling computation.
            If None, use the data type of input data. To be compatible with other modules,
            the output data type is always ``float32``.
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.

    Returns:
        In lazy mode, the unmodified data with the tracked metadata attached (or the tracked metadata itself
        when the data is not a ``MetaTensor``). Otherwise the zoomed (and, if needed, padded/cropped) data.
    """
    im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
    # zoomed spatial size: floor(size * factor) per axis
    output_size = [int(math.floor(float(i) * z)) for i, z in zip(im_shape, scale_factor)]
    xform = scale_affine(im_shape, output_size)
    extra_info = {
        "mode": mode,
        "align_corners": align_corners if align_corners is not None else TraceKeys.NONE,
        "dtype": str(dtype)[6:],  # dtype as string; remove "torch": torch.float32 -> float32
        "do_padcrop": False,
        "padcrop": {},
    }
    if keep_size:
        # pad/crop is only needed when zooming actually changes the spatial size
        do_pad_crop = not np.allclose(output_size, im_shape)
        if do_pad_crop and lazy:  # update for lazy evaluation
            # run the pad/crop lazily on an empty placeholder that carries the zoomed shape and affine,
            # then read back the combined pending affine and the pending pad/crop operation
            _pad_crop = ResizeWithPadOrCrop(spatial_size=im_shape, mode=padding_mode)
            _pad_crop.lazy = True
            _tmp_img = MetaTensor([], affine=torch.eye(len(output_size) + 1))
            _tmp_img.push_pending_operation({LazyAttr.SHAPE: list(output_size), LazyAttr.AFFINE: xform})
            lazy_cropped = _pad_crop(_tmp_img)
            if isinstance(lazy_cropped, MetaTensor):
                xform = lazy_cropped.peek_pending_affine()
                extra_info["padcrop"] = lazy_cropped.pending_operations[-1]
        extra_info["do_padcrop"] = do_pad_crop
        # with keep_size the tracked output size is the original spatial size
        output_size = [int(i) for i in im_shape]
    meta_info = TraceableTransform.track_transform_meta(
        img,
        sp_size=output_size,
        affine=xform,
        extra_info=extra_info,
        orig_size=im_shape,
        transform_info=transform_info,
        lazy=lazy,
    )
    out = _maybe_new_metatensor(img)
    if lazy:
        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
    img_t = out.to(dtype)
    _, _m, _, _ = resolves_modes(mode, torch_interpolate_spatial_nd=len(img_t.shape) - 1)
    # a leading batch dim is added for `interpolate` and removed again afterwards
    zoomed: NdarrayOrTensor = torch.nn.functional.interpolate(
        recompute_scale_factor=True,
        input=img_t.unsqueeze(0),
        scale_factor=list(scale_factor),
        mode=_m,
        align_corners=align_corners,
    ).squeeze(0)
    out, *_ = convert_to_dst_type(zoomed, dst=out, dtype=torch.float32)
    if isinstance(out, MetaTensor):
        out = out.copy_meta_from(meta_info)
    # pad/crop whenever the zoomed size differs from the tracked output size
    do_pad_crop = not np.allclose(output_size, zoomed.shape[1:])
    if do_pad_crop:
        _pad_crop = ResizeWithPadOrCrop(spatial_size=img_t.shape[1:], mode=padding_mode)
        out = _pad_crop(out)
    if get_track_meta() and do_pad_crop:
        # remove the last applied operation (the pad/crop) and store it inside the zoom record instead
        padcrop_xform = out.applied_operations.pop()
        out.applied_operations[-1]["extra_info"]["do_padcrop"] = True
        out.applied_operations[-1]["extra_info"]["padcrop"] = padcrop_xform
    return out
[docs]
def rotate90(img, axes, k, lazy, transform_info):
    """
    Functional implementation of rotate90.
    This function operates eagerly or lazily according to ``lazy``.

    Args:
        img: data to be changed, assuming `img` is channel-first.
        axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.
            The values are used with the channel dim included (each is reduced by one to get the spatial axis).
        k: number of times to rotate by 90 degrees.
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.

    Returns:
        In lazy mode, the unmodified data with the tracked metadata attached (or the tracked metadata itself
        when the data is not a ``MetaTensor``). Otherwise the rotated data.
    """
    is_meta = isinstance(img, MetaTensor)
    # the recorded axes are spatial, i.e. without the channel dim
    extra_info = {"axes": [d - 1 for d in axes], "k": k}
    ori_shape = img.peek_pending_shape() if is_meta else img.shape[1:]

    # an odd number of quarter turns swaps the sizes of the two rotated axes
    sp_shape = list(ori_shape)
    if k in (1, 3):
        first, second = axes[0] - 1, axes[1] - 1
        sp_shape[first], sp_shape[second] = ori_shape[second], ori_shape[first]

    rank = img.peek_pending_rank() if is_meta else torch.tensor(3.0, dtype=torch.double)
    r = int(rank)
    sp_r = len(ori_shape)

    # start with a shift by the negated centre of the rotated shape
    to_origin = create_translate(sp_r, [-float(d - 1) / 2 for d in sp_shape])
    xform = to_affine_nd(r, to_origin)

    # rotation direction depends on the order of the two axes
    s = -1.0 if int(axes[0]) - int(axes[1]) in (-1, 2) else 1.0
    if sp_r == 2:
        rot90 = to_affine_nd(r, create_rotate(sp_r, [s * np.pi / 2]))
    else:
        # the rotation angle is set for the axis that is not part of the rotation plane
        idx = {1, 2, 3} - set(axes)
        angle: list[float] = [0, 0, 0]
        angle[idx.pop() - 1] = s * np.pi / 2
        rot90 = to_affine_nd(r, create_rotate(sp_r, angle))

    # apply the quarter turn k times
    for _ in range(k):
        xform = rot90 @ xform
    # finish with a shift by the centre of the original shape
    from_origin = create_translate(sp_r, [float(d - 1) / 2 for d in ori_shape])
    xform = to_affine_nd(r, from_origin) @ xform

    meta_info = TraceableTransform.track_transform_meta(
        img,
        sp_size=sp_shape,
        affine=xform,
        extra_info=extra_info,
        orig_size=ori_shape,
        transform_info=transform_info,
        lazy=lazy,
    )
    out = _maybe_new_metatensor(img)
    if not lazy:
        out = torch.rot90(out, k, axes)
        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
[docs]
def affine_func(
    img, affine, grid, resampler, sp_size, mode, padding_mode, do_resampling, image_only, lazy, transform_info
):
    """
    Functional implementation of affine.
    This function operates eagerly or lazily according to ``lazy``.

    Args:
        img: data to be changed, assuming `img` is channel-first.
        affine: the affine transformation to be applied, it can be a 3x3 or 4x4 matrix. This should be defined
            for the voxel space spatial centers (``float(size - 1)/2``).
        grid: used in non-lazy mode to pre-compute the grid to do the resampling.
        resampler: the resampler function, see also: :py:class:`monai.transforms.Resample`.
            Its ``align_corners`` and ``device`` attributes are read here.
        sp_size: output image spatial size.
        mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
            Interpolation mode to calculate output values.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
            and the value represents the order of the spline interpolation.
            See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
        padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
            Padding mode for outside grid values.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            When `mode` is an integer, using numpy/cupy backends, this argument accepts
            {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
            See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
        do_resampling: whether to do the resampling, this is a flag for the use case of updating metadata but
            skipping the actual (potentially heavy) resampling operation.
        image_only: if True return only the image volume, otherwise return (image, affine).
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.

    Returns:
        The resulting image, or ``(image, affine)`` when ``image_only`` is False, where ``affine`` is the
        matrix returned by ``Affine.compute_w_affine``. In lazy mode the image is the unmodified data with the
        tracked metadata attached (or the tracked metadata itself when the data is not a ``MetaTensor``).
    """
    # resampler should carry the align_corners and type info
    is_meta = isinstance(img, MetaTensor)
    img_size = img.peek_pending_shape() if is_meta else img.shape[1:]
    rank = img.peek_pending_rank() if is_meta else torch.tensor(3.0, dtype=torch.double)

    # the affine is recorded as given, before it is replaced by the result of `compute_w_affine`
    extra_info = {
        "affine": affine,
        "mode": mode,
        "padding_mode": padding_mode,
        "do_resampling": do_resampling,
        "align_corners": resampler.align_corners,
    }
    affine = monai.transforms.Affine.compute_w_affine(rank, affine, img_size, sp_size)
    meta_info = TraceableTransform.track_transform_meta(
        img,
        sp_size=sp_size,
        affine=affine,
        extra_info=extra_info,
        orig_size=img_size,
        transform_info=transform_info,
        lazy=lazy,
    )

    if lazy:
        # lazy mode: no resampling, only the tracked metadata is attached
        out = _maybe_new_metatensor(img)
        out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
    else:
        if do_resampling:
            resampled = resampler(img=img, grid=grid, mode=mode, padding_mode=padding_mode)
            out = _maybe_new_metatensor(resampled)
        else:
            # metadata-only update: keep the data, converted to float32 on the resampler's device
            out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device)
        if isinstance(out, MetaTensor):
            out = out.copy_meta_from(meta_info)

    if image_only:
        return out
    return out, affine