# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script contains utility functions for developing new networks/blocks in PyTorch.
"""

from __future__ import annotations

import math

from torch import Tensor
from torch.nn import functional as F

from monai.apps.reconstruction.complex_utils import complex_conj_t, complex_mul_t
from monai.networks.blocks.fft_utils_t import fftn_centered_t, ifftn_centered_t


def reshape_complex_to_channel_dim(x: Tensor) -> Tensor:
"""
Swaps the complex dimension with the channel dimension so that the network treats real/imaginary
parts as two separate channels.
Args:
x: input of shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data
Returns:
output of shape (B,C*2,H,W) for 2D data or (B,C*2,H,W,D) for 3D data
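
    Example:
        A minimal shape check; the sizes below are illustrative only.

        .. code-block:: python

            import torch

            x = torch.ones([3, 2, 50, 70, 2])  # (B,C,H,W,2)
            # the following line should print torch.Size([3, 4, 50, 70])
            print(reshape_complex_to_channel_dim(x).shape)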
"""
if x.shape[-1] != 2:
raise ValueError(f"last dim must be 2, but x.shape[-1] is {x.shape[-1]}.")
if len(x.shape) == 5: # this is 2D
b, c, h, w, two = x.shape
return x.permute(0, 4, 1, 2, 3).contiguous().view(b, 2 * c, h, w)
elif len(x.shape) == 6: # this is 3D
b, c, h, w, d, two = x.shape
return x.permute(0, 5, 1, 2, 3, 4).contiguous().view(b, 2 * c, h, w, d)
else:
raise ValueError(f"only 2D (B,C,H,W,2) and 3D (B,C,H,W,D,2) data are supported but x has shape {x.shape}")


def reshape_channel_complex_to_last_dim(x: Tensor) -> Tensor:
"""
Swaps the complex dimension with the channel dimension so that the network output has 2 as its last dimension
Args:
x: input of shape (B,C*2,H,W) for 2D data or (B,C*2,H,W,D) for 3D data
Returns:
output of shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data
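
    Example:
        A minimal shape check; the sizes below are illustrative only.

        .. code-block:: python

            import torch

            x = torch.ones([3, 4, 50, 70])  # (B,C*2,H,W)
            # the following line should print torch.Size([3, 2, 50, 70, 2])
            print(reshape_channel_complex_to_last_dim(x).shape)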
"""
if x.shape[1] % 2 != 0:
raise ValueError(f"channel dimension should be even but ({x.shape[1]}) is odd.")
if len(x.shape) == 4: # this is 2D
b, c2, h, w = x.shape # c2 means c*2
c = c2 // 2
return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1)
elif len(x.shape) == 5: # this is 3D
b, c2, h, w, d = x.shape # c2 means c*2
c = c2 // 2
return x.view(b, 2, c, h, w, d).permute(0, 2, 3, 4, 5, 1)
else:
raise ValueError(f"only 2D (B,C*2,H,W) and 3D (B,C*2,H,W,D) data are supported but x has shape {x.shape}")


def reshape_channel_to_batch_dim(x: Tensor) -> tuple[Tensor, int]:
"""
Combines batch and channel dimensions.
Args:
x: input of shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data
Returns:
A tuple containing:
(1) output of shape (B*C,1,...)
(2) batch size
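
    Example:
        A minimal shape check; the sizes below are illustrative only.

        .. code-block:: python

            import torch

            x = torch.ones([3, 2, 50, 70, 2])  # (B,C,H,W,2)
            y, b = reshape_channel_to_batch_dim(x)
            # the following line should print torch.Size([6, 1, 50, 70, 2]) 3
            print(y.shape, b)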
"""
if len(x.shape) == 5: # this is 2D
b, c, h, w, two = x.shape
return x.contiguous().view(b * c, 1, h, w, two), b
elif len(x.shape) == 6: # this is 3D
b, c, h, w, d, two = x.shape
return x.contiguous().view(b * c, 1, h, w, d, two), b
else:
raise ValueError(f"only 2D (B,C,H,W,2) and 3D (B,C,H,W,D,2) data are supported but x has shape {x.shape}")


def reshape_batch_channel_to_channel_dim(x: Tensor, batch_size: int) -> Tensor:
"""
Detaches batch and channel dimensions.
Args:
x: input of shape (B*C,1,H,W,2) for 2D data or (B*C,1,H,W,D,2) for 3D data
batch_size: batch size
Returns:
output of shape (B,C,...)
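
    Example:
        A minimal round-trip sketch with ``reshape_channel_to_batch_dim``; the sizes are illustrative only.

        .. code-block:: python

            import torch

            x = torch.ones([3, 2, 50, 70, 2])  # (B,C,H,W,2)
            y, b = reshape_channel_to_batch_dim(x)
            # the following line should print torch.Size([3, 2, 50, 70, 2])
            print(reshape_batch_channel_to_channel_dim(y, batch_size=b).shape)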
"""
if len(x.shape) == 5: # this is 2D
bc, one, h, w, two = x.shape # bc represents B*C
c = bc // batch_size
return x.view(batch_size, c, h, w, two)
elif len(x.shape) == 6: # this is 3D
bc, one, h, w, d, two = x.shape # bc represents B*C
c = bc // batch_size
return x.view(batch_size, c, h, w, d, two)
else:
raise ValueError(f"only 2D (B*C,1,H,W,2) and 3D (B*C,1,H,W,D,2) data are supported but x has shape {x.shape}")


def complex_normalize(x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
"""
Performs layer mean-std normalization for complex data. Normalization is done for each batch member
along each part (part refers to real and imaginary parts), separately.
Args:
x: input of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data
Returns:
A tuple containing
(1) normalized output of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data
(2) mean
(3) std
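
    Example:
        A minimal sketch; the sizes are illustrative, and the channel dimension is assumed to be even
        (real and imaginary halves):

        .. code-block:: python

            import torch

            x = torch.randn(3, 2, 50, 70)  # (B,C,H,W) with C even
            x_norm, mean, std = complex_normalize(x)
            # the following line should print torch.Size([3, 2, 50, 70]) torch.Size([3, 2, 1, 1]) torch.Size([3, 2, 1, 1])
            print(x_norm.shape, mean.shape, std.shape)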
"""
if len(x.shape) == 4: # this is 2D
b, c, h, w = x.shape
x = x.contiguous().view(b, 2, c // 2 * h * w)
mean = x.mean(dim=2).view(b, 2, 1, 1, 1).expand(b, 2, c // 2, 1, 1).contiguous().view(b, c, 1, 1)
std = x.std(dim=2, unbiased=False).view(b, 2, 1, 1, 1).expand(b, 2, c // 2, 1, 1).contiguous().view(b, c, 1, 1)
x = x.view(b, c, h, w)
return (x - mean) / std, mean, std
elif len(x.shape) == 5: # this is 3D
b, c, h, w, d = x.shape
x = x.contiguous().view(b, 2, c // 2 * h * w * d)
mean = x.mean(dim=2).view(b, 2, 1, 1, 1, 1).expand(b, 2, c // 2, 1, 1, 1).contiguous().view(b, c, 1, 1, 1)
std = (
x.std(dim=2, unbiased=False)
.view(b, 2, 1, 1, 1, 1)
.expand(b, 2, c // 2, 1, 1, 1)
.contiguous()
.view(b, c, 1, 1, 1)
)
x = x.view(b, c, h, w, d)
return (x - mean) / std, mean, std
else:
raise ValueError(f"only 2D (B,C,H,W) and 3D (B,C,H,W,D) data are supported but x has shape {x.shape}")


def divisible_pad_t(
x: Tensor, k: int = 16
) -> tuple[Tensor, tuple[tuple[int, int], tuple[int, int], tuple[int, int], int, int, int]]:
"""
Pad input to feed into the network (torch script compatible)
Args:
x: input of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data
k: padding factor. each padded dimension will be divisible by k.
Returns:
A tuple containing
(1) padded input
(2) pad sizes (in order to reverse padding if needed)
Example:
.. code-block:: python
import torch
# 2D data
x = torch.ones([3,2,50,70])
x_pad,padding_sizes = divisible_pad_t(x, k=16)
# the following line should print (3, 2, 64, 80)
print(x_pad.shape)
# 3D data
x = torch.ones([3,2,50,70,80])
x_pad,padding_sizes = divisible_pad_t(x, k=16)
# the following line should print (3, 2, 64, 80, 80)
print(x_pad.shape)
"""
if len(x.shape) == 4: # this is 2D
b, c, h, w = x.shape
        w_mult = ((w - 1) | (k - 1)) + 1  # rounds w up to the next multiple of k (bit trick; assumes k is a power of two)
h_mult = ((h - 1) | (k - 1)) + 1
w_pad = floor_ceil((w_mult - w) / 2)
h_pad = floor_ceil((h_mult - h) / 2)
x = F.pad(x, w_pad + h_pad)
# dummy values for the 3rd spatial dimension
d_mult = -1
d_pad = (-1, -1)
pad_sizes = (h_pad, w_pad, d_pad, h_mult, w_mult, d_mult)
elif len(x.shape) == 5: # this is 3D
b, c, h, w, d = x.shape
w_mult = ((w - 1) | (k - 1)) + 1
h_mult = ((h - 1) | (k - 1)) + 1
d_mult = ((d - 1) | (k - 1)) + 1
w_pad = floor_ceil((w_mult - w) / 2)
h_pad = floor_ceil((h_mult - h) / 2)
d_pad = floor_ceil((d_mult - d) / 2)
x = F.pad(x, d_pad + w_pad + h_pad)
pad_sizes = (h_pad, w_pad, d_pad, h_mult, w_mult, d_mult)
else:
raise ValueError(f"only 2D (B,C,H,W) and 3D (B,C,H,W,D) data are supported but x has shape {x.shape}")
return x, pad_sizes


def inverse_divisible_pad_t(
x: Tensor, pad_sizes: tuple[tuple[int, int], tuple[int, int], tuple[int, int], int, int, int]
) -> Tensor:
"""
De-pad network output to match its original shape
Args:
x: input of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data
pad_sizes: padding values
Returns:
de-padded input
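
    Example:
        A minimal round-trip sketch with ``divisible_pad_t``; the sizes are illustrative only.

        .. code-block:: python

            import torch

            x = torch.ones([3, 2, 50, 70])
            x_pad, pad_sizes = divisible_pad_t(x, k=16)
            # the following line should print torch.Size([3, 2, 50, 70])
            print(inverse_divisible_pad_t(x_pad, pad_sizes).shape)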
"""
h_pad, w_pad, d_pad, h_mult, w_mult, d_mult = pad_sizes
if len(x.shape) == 4: # this is 2D
return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]]
elif len(x.shape) == 5: # this is 3D
return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1], d_pad[0] : d_mult - d_pad[1]]
else:
raise ValueError(f"only 2D (B,C,H,W) and 3D (B,C,H,W,D) data are supported but x has shape {x.shape}")


def floor_ceil(n: float) -> tuple[int, int]:
"""
Returns floor and ceil of the input
Args:
n: input number
Returns:
A tuple containing:
(1) floor(n)
(2) ceil(n)
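
    Example:
        A quick illustration:

        .. code-block:: python

            # the following line should print (2, 3)
            print(floor_ceil(2.5))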
"""
return math.floor(n), math.ceil(n)


def sensitivity_map_reduce(kspace: Tensor, sens_maps: Tensor, spatial_dims: int = 2) -> Tensor:
"""
Reduces coil measurements to a corresponding image based on the given sens_maps. Let's say there
are C coil measurements inside kspace, then this function multiplies the conjugate of each coil sensitivity map with the
corresponding coil image. The result of this process will be C images. Summing those images together gives the
resulting "reduced image."
Args:
kspace: 2D kspace (B,C,H,W,2) with the last dimension being 2 (for real/imaginary parts) and C denoting the
coil dimension. 3D data will have the shape (B,C,H,W,D,2).
sens_maps: sensitivity maps of the same shape as input x.
spatial_dims: is 2 for 2D data and is 3 for 3D data
Returns:
reduction of x to (B,1,H,W,2) for 2D data or (B,1,H,W,D,2) for 3D data.
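
    Example:
        A minimal shape sketch; the constant-valued tensors below stand in for real measurements and
        the sizes are illustrative only.

        .. code-block:: python

            import torch

            kspace = torch.ones([3, 4, 50, 70, 2])  # (B,C,H,W,2) with 4 coils
            sens_maps = torch.ones([3, 4, 50, 70, 2])
            # the following line should print torch.Size([3, 1, 50, 70, 2])
            print(sensitivity_map_reduce(kspace, sens_maps, spatial_dims=2).shape)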
"""
img = ifftn_centered_t(kspace, spatial_dims=spatial_dims, is_complex=True) # inverse fourier transform
return complex_mul_t(img, complex_conj_t(sens_maps)).sum(dim=1, keepdim=True)


def sensitivity_map_expand(img: Tensor, sens_maps: Tensor, spatial_dims: int = 2) -> Tensor:
"""
Expands an image to its corresponding coil images based on the given sens_maps. Let's say there
are C coils. This function multiples image img with each coil sensitivity map in sens_maps and stacks
the resulting C coil images along the channel dimension which is reserved for coils.
Args:
img: 2D image (B,1,H,W,2) with the last dimension being 2 (for real/imaginary parts). 3D data will have
the shape (B,1,H,W,D,2).
sens_maps: Sensitivity maps for combining coil images. The shape is (B,C,H,W,2) for 2D data
or (B,C,H,W,D,2) for 3D data (C denotes the coil dimension).
spatial_dims: is 2 for 2D data and is 3 for 3D data
Returns:
Expansion of x to (B,C,H,W,2) for 2D data and (B,C,H,W,D,2) for 3D data. The output is transferred
to the frequency domain to yield coil measurements.
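
    Example:
        A minimal shape sketch; the sizes are illustrative, and img is assumed to broadcast across the
        coil dimension with standard elementwise semantics.

        .. code-block:: python

            import torch

            img = torch.ones([3, 1, 50, 70, 2])  # single combined image
            sens_maps = torch.ones([3, 4, 50, 70, 2])  # 4 coils
            # the following line should print torch.Size([3, 4, 50, 70, 2])
            print(sensitivity_map_expand(img, sens_maps, spatial_dims=2).shape)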
"""
return fftn_centered_t(complex_mul_t(img, sens_maps), spatial_dims=spatial_dims, is_complex=True)