# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
import torch
import torch.nn as nn
import monai
from monai.networks import to_norm_affine
from monai.utils import (
GridSampleMode,
GridSamplePadMode,
convert_to_dst_type,
ensure_tuple,
look_up_option,
optional_import,
)
# Compiled extension module: the functions below call its kernels (_C.grid_pull,
# _C.grid_push, _C.grid_count, _C.grid_grad and their *_backward variants) and
# its enums (_C.BoundType, _C.InterpolationType).
# NOTE(review): presumably optional_import yields a placeholder when the extension
# is not built, so errors surface only when _C is first used — confirm.
# The second value returned by optional_import is not used in the code shown here.
_C, _ = optional_import("monai._C")
# Public names exported by this module (AffineTransform is presumably defined
# further down in this module; it is not among the lines shown here).
__all__ = ["AffineTransform", "grid_pull", "grid_push", "grid_count", "grid_grad"]
class _GridPull(torch.autograd.Function):
    """Autograd wrapper around the compiled ``_C.grid_pull`` kernel and its backward."""

    @staticmethod
    def forward(ctx, input, grid, interpolation, bound, extrapolate):
        # The compiled kernels take their options in (bound, interpolation, extrapolate) order.
        options = (bound, interpolation, extrapolate)
        pulled = _C.grid_pull(input, grid, *options)
        # Keep state for backward only when one of the tensor arguments tracks gradients.
        if input.requires_grad or grid.requires_grad:
            ctx.opt = options
            ctx.save_for_backward(input, grid)
        return pulled

    @staticmethod
    def backward(ctx, grad):
        want_input, want_grid = ctx.needs_input_grad[:2]
        if not want_input and not want_grid:
            return None, None, None, None, None
        grads = _C.grid_pull_backward(grad, *ctx.saved_tensors, *ctx.opt)
        # The indexing below assumes the kernel packs only the requested gradients,
        # input-gradient first, then grid-gradient.
        if want_input:
            return grads[0], (grads[1] if want_grid else None), None, None, None
        # Only the grid gradient was requested, so it sits at position 0.
        return None, grads[0], None, None, None
def grid_pull(
    input: torch.Tensor, grid: torch.Tensor, interpolation="linear", bound="zero", extrapolate: bool = True
) -> torch.Tensor:
    """
    Sample (pull) an image at the locations given by a deformation field.

    `interpolation` may be an int, a string or an InterpolationType.
    Accepted values are::

        - 0 or 'nearest' or InterpolationType.nearest
        - 1 or 'linear' or InterpolationType.linear
        - 2 or 'quadratic' or InterpolationType.quadratic
        - 3 or 'cubic' or InterpolationType.cubic
        - 4 or 'fourth' or InterpolationType.fourth
        - 5 or 'fifth' or InterpolationType.fifth
        - 6 or 'sixth' or InterpolationType.sixth
        - 7 or 'seventh' or InterpolationType.seventh

    A list of values, ordered as [W, H, D], selects a different
    interpolation order for each dimension.

    `bound` may be an int, a string or a BoundType.
    Accepted values are::

        - 0 or 'replicate' or 'nearest' or BoundType.replicate or 'border'
        - 1 or 'dct1' or 'mirror' or BoundType.dct1
        - 2 or 'dct2' or 'reflect' or BoundType.dct2
        - 3 or 'dst1' or 'antimirror' or BoundType.dst1
        - 4 or 'dst2' or 'antireflect' or BoundType.dst2
        - 5 or 'dft' or 'wrap' or BoundType.dft
        - 7 or 'zero' or 'zeros' or BoundType.zero

    A list of values, ordered as [W, H, D], selects a different
    boundary condition for each dimension.
    `sliding` is a special condition that only applies to flow fields
    (with as many channels as dimensions). It cannot be dimension-specific.

    Note that:

        - `dft` corresponds to circular padding
        - `dct2` corresponds to Neumann boundary conditions (symmetric)
        - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)

    See Also:
        - https://en.wikipedia.org/wiki/Discrete_cosine_transform
        - https://en.wikipedia.org/wiki/Discrete_sine_transform
        - ``help(monai._C.BoundType)``
        - ``help(monai._C.InterpolationType)``

    Args:
        input: Input image. `(B, C, Wi, Hi, Di)`.
        grid: Deformation field. `(B, Wo, Ho, Do, 1|2|3)`.
        interpolation (int or list[int] , optional): Interpolation order.
            Defaults to `'linear'`.
        bound (BoundType, or list[BoundType], optional): Boundary conditions.
            Defaults to `'zero'`.
        extrapolate: Extrapolate out-of-bound data.
            Defaults to `True`.

    Returns:
        output (torch.Tensor): Deformed image `(B, C, Wo, Ho, Do)`. When `input` is a
        MetaTensor, the result is converted to match it with ``convert_to_dst_type``.
    """
    # Normalise the options into lists of compiled enum members: strings are looked
    # up by member name, everything else goes through the enum constructor.
    bound_types = [
        _C.BoundType.__members__[item] if isinstance(item, str) else _C.BoundType(item)
        for item in ensure_tuple(bound)
    ]
    interp_types = [
        _C.InterpolationType.__members__[item] if isinstance(item, str) else _C.InterpolationType(item)
        for item in ensure_tuple(interpolation)
    ]
    pulled: torch.Tensor = _GridPull.apply(input, grid, interp_types, bound_types, extrapolate)
    if not isinstance(input, monai.data.MetaTensor):
        return pulled
    return convert_to_dst_type(pulled, dst=input)[0]
class _GridPush(torch.autograd.Function):
    """Autograd wrapper around the compiled ``_C.grid_push`` kernel and its backward."""

    @staticmethod
    def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):
        # Option order expected by the compiled kernels.
        options = (bound, interpolation, extrapolate)
        splatted = _C.grid_push(input, grid, shape, *options)
        # Save what backward needs only when a tensor argument tracks gradients.
        if input.requires_grad or grid.requires_grad:
            ctx.opt = options
            ctx.save_for_backward(input, grid)
        return splatted

    @staticmethod
    def backward(ctx, grad):
        need_input, need_grid = ctx.needs_input_grad[:2]
        if not (need_input or need_grid):
            return (None,) * 6
        grads = _C.grid_push_backward(grad, *ctx.saved_tensors, *ctx.opt)
        # The indexing assumes the kernel returns only the requested gradients,
        # input-gradient first.
        if need_input:
            grad_input = grads[0]
            grad_grid = grads[1] if need_grid else None
        else:
            grad_input = None
            grad_grid = grads[0]
        # shape, interpolation, bound and extrapolate are not differentiable.
        return (grad_input, grad_grid) + (None,) * 4
def grid_push(
    input: torch.Tensor, grid: torch.Tensor, shape=None, interpolation="linear", bound="zero", extrapolate: bool = True
):
    """
    Splat (push) an image with respect to a deformation field; this is the adjoint of pull.

    `interpolation` may be an int, a string or an InterpolationType.
    Accepted values are::

        - 0 or 'nearest' or InterpolationType.nearest
        - 1 or 'linear' or InterpolationType.linear
        - 2 or 'quadratic' or InterpolationType.quadratic
        - 3 or 'cubic' or InterpolationType.cubic
        - 4 or 'fourth' or InterpolationType.fourth
        - 5 or 'fifth' or InterpolationType.fifth
        - 6 or 'sixth' or InterpolationType.sixth
        - 7 or 'seventh' or InterpolationType.seventh

    A list of values, ordered as `[W, H, D]`, selects a different
    interpolation order for each dimension.

    `bound` may be an int, a string or a BoundType.
    Accepted values are::

        - 0 or 'replicate' or 'nearest' or BoundType.replicate
        - 1 or 'dct1' or 'mirror' or BoundType.dct1
        - 2 or 'dct2' or 'reflect' or BoundType.dct2
        - 3 or 'dst1' or 'antimirror' or BoundType.dst1
        - 4 or 'dst2' or 'antireflect' or BoundType.dst2
        - 5 or 'dft' or 'wrap' or BoundType.dft
        - 7 or 'zero' or BoundType.zero

    A list of values, ordered as `[W, H, D]`, selects a different
    boundary condition for each dimension.
    `sliding` is a special condition that only applies to flow fields
    (with as many channels as dimensions). It cannot be dimension-specific.

    Note that:

        - `dft` corresponds to circular padding
        - `dct2` corresponds to Neumann boundary conditions (symmetric)
        - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)

    See Also:
        - https://en.wikipedia.org/wiki/Discrete_cosine_transform
        - https://en.wikipedia.org/wiki/Discrete_sine_transform
        - ``help(monai._C.BoundType)``
        - ``help(monai._C.InterpolationType)``

    Args:
        input: Input image `(B, C, Wi, Hi, Di)`.
        grid: Deformation field `(B, Wi, Hi, Di, 1|2|3)`.
        shape: Shape of the source image. When ``None``, the spatial shape of
            `input` (``input.shape[2:]``) is used.
        interpolation (int or list[int] , optional): Interpolation order.
            Defaults to `'linear'`.
        bound (BoundType, or list[BoundType], optional): Boundary conditions.
            Defaults to `'zero'`.
        extrapolate: Extrapolate out-of-bound data.
            Defaults to `True`.

    Returns:
        output (torch.Tensor): Splatted image `(B, C, Wo, Ho, Do)`. When `input` is a
        MetaTensor, the result is converted to match it with ``convert_to_dst_type``.
    """

    def _to_enum(enum_cls, value):
        # Strings are resolved by member name; ints/enum values via the constructor.
        return enum_cls.__members__[value] if isinstance(value, str) else enum_cls(value)

    bound_types = [_to_enum(_C.BoundType, value) for value in ensure_tuple(bound)]
    interp_types = [_to_enum(_C.InterpolationType, value) for value in ensure_tuple(interpolation)]
    source_shape = tuple(input.shape[2:]) if shape is None else shape
    splatted: torch.Tensor = _GridPush.apply(input, grid, source_shape, interp_types, bound_types, extrapolate)
    if isinstance(input, monai.data.MetaTensor):
        splatted = convert_to_dst_type(splatted, dst=input)[0]
    return splatted
class _GridCount(torch.autograd.Function):
    """Autograd wrapper around the compiled ``_C.grid_count`` kernel and its backward."""

    @staticmethod
    def forward(ctx, grid, shape, interpolation, bound, extrapolate):
        # Option order expected by the compiled kernels.
        options = (bound, interpolation, extrapolate)
        counts = _C.grid_count(grid, shape, *options)
        # The grid is the only tensor argument, so it alone decides whether backward needs state.
        if grid.requires_grad:
            ctx.opt = options
            ctx.save_for_backward(grid)
        return counts

    @staticmethod
    def backward(ctx, grad):
        if not ctx.needs_input_grad[0]:
            return None, None, None, None, None
        grad_grid = _C.grid_count_backward(grad, *ctx.saved_tensors, *ctx.opt)
        # shape, interpolation, bound and extrapolate are not differentiable.
        return grad_grid, None, None, None, None
def grid_count(
    grid: torch.Tensor, shape=None, interpolation="linear", bound="zero", extrapolate: bool = True
) -> torch.Tensor:
    """
    Splatting weights with respect to a deformation field (pull adjoint).

    This function is equivalent to applying grid_push to an image of ones.

    `interpolation` can be an int, a string or an InterpolationType.
    Possible values are::

        - 0 or 'nearest' or InterpolationType.nearest
        - 1 or 'linear' or InterpolationType.linear
        - 2 or 'quadratic' or InterpolationType.quadratic
        - 3 or 'cubic' or InterpolationType.cubic
        - 4 or 'fourth' or InterpolationType.fourth
        - 5 or 'fifth' or InterpolationType.fifth
        - 6 or 'sixth' or InterpolationType.sixth
        - 7 or 'seventh' or InterpolationType.seventh

    A list of values can be provided, in the order [W, H, D],
    to specify dimension-specific interpolation orders.

    `bound` can be an int, a string or a BoundType.
    Possible values are::

        - 0 or 'replicate' or 'nearest' or BoundType.replicate
        - 1 or 'dct1' or 'mirror' or BoundType.dct1
        - 2 or 'dct2' or 'reflect' or BoundType.dct2
        - 3 or 'dst1' or 'antimirror' or BoundType.dst1
        - 4 or 'dst2' or 'antireflect' or BoundType.dst2
        - 5 or 'dft' or 'wrap' or BoundType.dft
        - 7 or 'zero' or BoundType.zero

    A list of values can be provided, in the order [W, H, D],
    to specify dimension-specific boundary conditions.
    `sliding` is a specific condition that only applies to flow fields
    (with as many channels as dimensions). It cannot be dimension-specific.

    Note that:

        - `dft` corresponds to circular padding
        - `dct2` corresponds to Neumann boundary conditions (symmetric)
        - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)

    See Also:
        - https://en.wikipedia.org/wiki/Discrete_cosine_transform
        - https://en.wikipedia.org/wiki/Discrete_sine_transform
        - ``help(monai._C.BoundType)``
        - ``help(monai._C.InterpolationType)``

    Args:
        grid: Deformation field `(B, Wi, Hi, Di, 2|3)`.
        shape: spatial shape of the source image. When ``None``, defaults to the
            spatial shape of `grid`, i.e. ``grid.shape[1:-1]`` = `(Wi, Hi, Di)`.
        interpolation (int or list[int] , optional): Interpolation order.
            Defaults to `'linear'`.
        bound (BoundType, or list[BoundType], optional): Boundary conditions.
            Defaults to `'zero'`.
        extrapolate (bool, optional): Extrapolate out-of-bound data.
            Defaults to `True`.

    Returns:
        output (torch.Tensor): Splat weights `(B, 1, Wo, Ho, Do)`. When `grid` is a
        MetaTensor, the result is converted to match it with ``convert_to_dst_type``.
    """
    # Convert parameters
    bound = [_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b) for b in ensure_tuple(bound)]
    interpolation = [
        _C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i)
        for i in ensure_tuple(interpolation)
    ]

    if shape is None:
        # `grid` is channel-last `(B, *spatial, ndim)`: its spatial dims are shape[1:-1].
        # (shape[2:] would drop the first spatial dim and keep the trailing ndim axis.)
        shape = tuple(grid.shape[1:-1])

    out: torch.Tensor = _GridCount.apply(grid, shape, interpolation, bound, extrapolate)
    # Mirror the sibling functions and preserve the MetaTensor type of the tensor argument.
    # This function has no `input` parameter, so the check must be on `grid`;
    # `input` here would refer to the Python builtin and never match.
    if isinstance(grid, monai.data.MetaTensor):
        out = convert_to_dst_type(out, dst=grid)[0]
    return out
class _GridGrad(torch.autograd.Function):
    """Autograd wrapper around the compiled ``_C.grid_grad`` kernel and its backward."""

    @staticmethod
    def forward(ctx, input, grid, interpolation, bound, extrapolate):
        # Option order expected by the compiled kernels.
        options = (bound, interpolation, extrapolate)
        sampled_grads = _C.grid_grad(input, grid, *options)
        if input.requires_grad or grid.requires_grad:
            ctx.opt = options
            ctx.save_for_backward(input, grid)
        return sampled_grads

    @staticmethod
    def backward(ctx, grad):
        flags = ctx.needs_input_grad
        wants_input, wants_grid = flags[0], flags[1]
        if not (wants_input or wants_grid):
            return None, None, None, None, None
        grads = _C.grid_grad_backward(grad, *ctx.saved_tensors, *ctx.opt)
        # The indexing assumes the kernel packs only the requested gradients,
        # input-gradient first.
        if not wants_input:
            return None, grads[0], None, None, None
        second = grads[1] if wants_grid else None
        return grads[0], second, None, None, None
def grid_grad(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", bound="zero", extrapolate: bool = True):
    """
    Sample the spatial gradients of an image with respect to a deformation field.

    `interpolation` may be an int, a string or an InterpolationType.
    Accepted values are::

        - 0 or 'nearest' or InterpolationType.nearest
        - 1 or 'linear' or InterpolationType.linear
        - 2 or 'quadratic' or InterpolationType.quadratic
        - 3 or 'cubic' or InterpolationType.cubic
        - 4 or 'fourth' or InterpolationType.fourth
        - 5 or 'fifth' or InterpolationType.fifth
        - 6 or 'sixth' or InterpolationType.sixth
        - 7 or 'seventh' or InterpolationType.seventh

    A list of values, ordered as [W, H, D], selects a different
    interpolation order for each dimension.

    `bound` may be an int, a string or a BoundType.
    Accepted values are::

        - 0 or 'replicate' or 'nearest' or BoundType.replicate
        - 1 or 'dct1' or 'mirror' or BoundType.dct1
        - 2 or 'dct2' or 'reflect' or BoundType.dct2
        - 3 or 'dst1' or 'antimirror' or BoundType.dst1
        - 4 or 'dst2' or 'antireflect' or BoundType.dst2
        - 5 or 'dft' or 'wrap' or BoundType.dft
        - 7 or 'zero' or BoundType.zero

    A list of values, ordered as [W, H, D], selects a different
    boundary condition for each dimension.
    `sliding` is a special condition that only applies to flow fields
    (with as many channels as dimensions). It cannot be dimension-specific.

    Note that:

        - `dft` corresponds to circular padding
        - `dct2` corresponds to Neumann boundary conditions (symmetric)
        - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)

    See Also:
        - https://en.wikipedia.org/wiki/Discrete_cosine_transform
        - https://en.wikipedia.org/wiki/Discrete_sine_transform
        - ``help(monai._C.BoundType)``
        - ``help(monai._C.InterpolationType)``

    Args:
        input: Input image. `(B, C, Wi, Hi, Di)`.
        grid: Deformation field. `(B, Wo, Ho, Do, 2|3)`.
        interpolation (int or list[int] , optional): Interpolation order.
            Defaults to `'linear'`.
        bound (BoundType, or list[BoundType], optional): Boundary conditions.
            Defaults to `'zero'`.
        extrapolate: Extrapolate out-of-bound data. Defaults to `True`.

    Returns:
        output (torch.Tensor): Sampled gradients (B, C, Wo, Ho, Do, 1|2|3). When `input`
        is a MetaTensor, the result is converted to match it with ``convert_to_dst_type``.
    """
    # Resolve each option to a compiled enum member (by name for strings, by value otherwise).
    bounds = [
        _C.BoundType.__members__[mode] if isinstance(mode, str) else _C.BoundType(mode)
        for mode in ensure_tuple(bound)
    ]
    orders = [
        _C.InterpolationType.__members__[order] if isinstance(order, str) else _C.InterpolationType(order)
        for order in ensure_tuple(interpolation)
    ]
    gradients: torch.Tensor = _GridGrad.apply(input, grid, orders, bounds, extrapolate)
    return convert_to_dst_type(gradients, dst=input)[0] if isinstance(input, monai.data.MetaTensor) else gradients