Source code for monai.apps.reconstruction.networks.nets.utils

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script contains utility functions for developing new networks/blocks in PyTorch.
"""

from __future__ import annotations

import math

from torch import Tensor
from torch.nn import functional as F

from monai.apps.reconstruction.complex_utils import complex_conj_t, complex_mul_t
from monai.networks.blocks.fft_utils_t import fftn_centered_t, ifftn_centered_t


def reshape_complex_to_channel_dim(x: Tensor) -> Tensor:
    """
    Swaps the complex dimension with the channel dimension so that the network treats real/imaginary
    parts as two separate channels.

    Args:
        x: input of shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data

    Returns:
        output of shape (B,C*2,H,W) for 2D data or (B,C*2,H,W,D) for 3D data
    """
    if x.shape[-1] != 2:
        raise ValueError(f"last dim must be 2, but x.shape[-1] is {x.shape[-1]}.")

    if len(x.shape) == 5:  # this is 2D
        b, c, h, w, two = x.shape
        return x.permute(0, 4, 1, 2, 3).contiguous().view(b, 2 * c, h, w)

    elif len(x.shape) == 6:  # this is 3D
        b, c, h, w, d, two = x.shape
        return x.permute(0, 5, 1, 2, 3, 4).contiguous().view(b, 2 * c, h, w, d)

    else:
        raise ValueError(f"only 2D (B,C,H,W,2) and 3D (B,C,H,W,D,2) data are supported but x has shape {x.shape}")


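# Usage sketch (toy shapes chosen purely for illustration):
import torch

x = torch.rand(3, 4, 50, 70, 2)  # (B,C,H,W,2): 2D data with 4 channels/coils
y = reshape_complex_to_channel_dim(x)
print(y.shape)  # torch.Size([3, 8, 50, 70]) -- real/imaginary parts are now channels

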
def reshape_channel_complex_to_last_dim(x: Tensor) -> Tensor:
    """
    Swaps the complex dimension with the channel dimension so that the network output has 2 as its
    last dimension.

    Args:
        x: input of shape (B,C*2,H,W) for 2D data or (B,C*2,H,W,D) for 3D data

    Returns:
        output of shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data
    """
    if x.shape[1] % 2 != 0:
        raise ValueError(f"channel dimension should be even but ({x.shape[1]}) is odd.")

    if len(x.shape) == 4:  # this is 2D
        b, c2, h, w = x.shape  # c2 means c*2
        c = c2 // 2
        return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1)

    elif len(x.shape) == 5:  # this is 3D
        b, c2, h, w, d = x.shape  # c2 means c*2
        c = c2 // 2
        return x.view(b, 2, c, h, w, d).permute(0, 2, 3, 4, 5, 1)

    else:
        raise ValueError(f"only 2D (B,C*2,H,W) and 3D (B,C*2,H,W,D) data are supported but x has shape {x.shape}")


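# Usage sketch (assumed toy shapes): this function inverts reshape_complex_to_channel_dim,
# so a round trip recovers the original tensor exactly.
import torch

x = torch.rand(3, 4, 50, 70, 2)  # (B,C,H,W,2)
x_back = reshape_channel_complex_to_last_dim(reshape_complex_to_channel_dim(x))
print(torch.equal(x, x_back))  # True

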
def reshape_channel_to_batch_dim(x: Tensor) -> tuple[Tensor, int]:
    """
    Combines batch and channel dimensions.

    Args:
        x: input of shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data

    Returns:
        A tuple containing:
            (1) output of shape (B*C,1,...)
            (2) batch size
    """
    if len(x.shape) == 5:  # this is 2D
        b, c, h, w, two = x.shape
        return x.contiguous().view(b * c, 1, h, w, two), b

    elif len(x.shape) == 6:  # this is 3D
        b, c, h, w, d, two = x.shape
        return x.contiguous().view(b * c, 1, h, w, d, two), b

    else:
        raise ValueError(f"only 2D (B,C,H,W,2) and 3D (B,C,H,W,D,2) data are supported but x has shape {x.shape}")


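# Usage sketch (assumed toy shapes): folding coils into the batch dimension lets a
# single-channel network process each coil image independently.
import torch

x = torch.rand(3, 4, 50, 70, 2)  # (B,C,H,W,2)
x_merged, b = reshape_channel_to_batch_dim(x)
print(x_merged.shape, b)  # torch.Size([12, 1, 50, 70, 2]) 3

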
def reshape_batch_channel_to_channel_dim(x: Tensor, batch_size: int) -> Tensor:
    """
    Splits the combined batch-channel dimension produced by reshape_channel_to_batch_dim back into
    separate batch and channel dimensions.

    Args:
        x: input of shape (B*C,1,H,W,2) for 2D data or (B*C,1,H,W,D,2) for 3D data
        batch_size: batch size

    Returns:
        output of shape (B,C,...)
    """
    if len(x.shape) == 5:  # this is 2D
        bc, one, h, w, two = x.shape  # bc represents B*C
        c = bc // batch_size
        return x.view(batch_size, c, h, w, two)

    elif len(x.shape) == 6:  # this is 3D
        bc, one, h, w, d, two = x.shape  # bc represents B*C
        c = bc // batch_size
        return x.view(batch_size, c, h, w, d, two)

    else:
        raise ValueError(f"only 2D (B*C,1,H,W,2) and 3D (B*C,1,H,W,D,2) data are supported but x has shape {x.shape}")


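# Usage sketch (assumed toy shapes): the saved batch size restores the (B,C,...) layout.
import torch

x = torch.rand(3, 4, 50, 70, 2)
x_merged, b = reshape_channel_to_batch_dim(x)  # (12, 1, 50, 70, 2)
print(torch.equal(x, reshape_batch_channel_to_channel_dim(x_merged, b)))  # True

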
def complex_normalize(x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    """
    Performs layer mean-std normalization for complex data. Normalization is done for each batch member
    along each part (part refers to real and imaginary parts), separately.

    Args:
        x: input of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data

    Returns:
        A tuple containing:
            (1) normalized output of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data
            (2) mean
            (3) std
    """
    if len(x.shape) == 4:  # this is 2D
        b, c, h, w = x.shape
        x = x.contiguous().view(b, 2, c // 2 * h * w)
        mean = x.mean(dim=2).view(b, 2, 1, 1, 1).expand(b, 2, c // 2, 1, 1).contiguous().view(b, c, 1, 1)
        std = x.std(dim=2, unbiased=False).view(b, 2, 1, 1, 1).expand(b, 2, c // 2, 1, 1).contiguous().view(b, c, 1, 1)
        x = x.view(b, c, h, w)
        return (x - mean) / std, mean, std

    elif len(x.shape) == 5:  # this is 3D
        b, c, h, w, d = x.shape
        x = x.contiguous().view(b, 2, c // 2 * h * w * d)
        mean = x.mean(dim=2).view(b, 2, 1, 1, 1, 1).expand(b, 2, c // 2, 1, 1, 1).contiguous().view(b, c, 1, 1, 1)
        std = (
            x.std(dim=2, unbiased=False)
            .view(b, 2, 1, 1, 1, 1)
            .expand(b, 2, c // 2, 1, 1, 1)
            .contiguous()
            .view(b, c, 1, 1, 1)
        )
        x = x.view(b, c, h, w, d)
        return (x - mean) / std, mean, std

    else:
        raise ValueError(f"only 2D (B,C,H,W) and 3D (B,C,H,W,D) data are supported but x has shape {x.shape}")


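# Usage sketch (assumed toy shapes): the returned statistics can undo the normalization.
import torch

x = torch.rand(3, 2, 50, 70)  # (B,C,H,W) with C=2 (real and imaginary channels)
x_norm, mean, std = complex_normalize(x)
x_back = x_norm * std + mean  # invert the per-part normalization
print(torch.allclose(x, x_back, atol=1e-6))  # True

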
def divisible_pad_t(
    x: Tensor, k: int = 16
) -> tuple[Tensor, tuple[tuple[int, int], tuple[int, int], tuple[int, int], int, int, int]]:
    """
    Pad input to feed into the network (torch script compatible).

    Args:
        x: input of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data
        k: padding factor; each padded spatial dimension will be divisible by k
            (k should be a power of two).

    Returns:
        A tuple containing:
            (1) padded input
            (2) pad sizes (in order to reverse padding if needed)

    Example:
        .. code-block:: python

            import torch

            # 2D data
            x = torch.ones([3,2,50,70])
            x_pad,padding_sizes = divisible_pad_t(x, k=16)
            # the following line should print (3, 2, 64, 80)
            print(x_pad.shape)

            # 3D data
            x = torch.ones([3,2,50,70,80])
            x_pad,padding_sizes = divisible_pad_t(x, k=16)
            # the following line should print (3, 2, 64, 80, 80)
            print(x_pad.shape)
    """
    if len(x.shape) == 4:  # this is 2D
        b, c, h, w = x.shape
        w_mult = ((w - 1) | (k - 1)) + 1  # OR with (k-1) then +1 rounds up to the next multiple of k (k a power of two)
        h_mult = ((h - 1) | (k - 1)) + 1
        w_pad = floor_ceil((w_mult - w) / 2)
        h_pad = floor_ceil((h_mult - h) / 2)
        x = F.pad(x, w_pad + h_pad)
        # dummy values for the 3rd spatial dimension
        d_mult = -1
        d_pad = (-1, -1)
        pad_sizes = (h_pad, w_pad, d_pad, h_mult, w_mult, d_mult)

    elif len(x.shape) == 5:  # this is 3D
        b, c, h, w, d = x.shape
        w_mult = ((w - 1) | (k - 1)) + 1
        h_mult = ((h - 1) | (k - 1)) + 1
        d_mult = ((d - 1) | (k - 1)) + 1
        w_pad = floor_ceil((w_mult - w) / 2)
        h_pad = floor_ceil((h_mult - h) / 2)
        d_pad = floor_ceil((d_mult - d) / 2)
        x = F.pad(x, d_pad + w_pad + h_pad)
        pad_sizes = (h_pad, w_pad, d_pad, h_mult, w_mult, d_mult)

    else:
        raise ValueError(f"only 2D (B,C,H,W) and 3D (B,C,H,W,D) data are supported but x has shape {x.shape}")

    return x, pad_sizes


def inverse_divisible_pad_t(
    x: Tensor, pad_sizes: tuple[tuple[int, int], tuple[int, int], tuple[int, int], int, int, int]
) -> Tensor:
    """
    De-pad network output to match its original shape.

    Args:
        x: input of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data
        pad_sizes: padding values

    Returns:
        de-padded input
    """
    h_pad, w_pad, d_pad, h_mult, w_mult, d_mult = pad_sizes

    if len(x.shape) == 4:  # this is 2D
        return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]]

    elif len(x.shape) == 5:  # this is 3D
        return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1], d_pad[0] : d_mult - d_pad[1]]

    else:
        raise ValueError(f"only 2D (B,C,H,W) and 3D (B,C,H,W,D) data are supported but x has shape {x.shape}")


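# Usage sketch (assumed toy shapes): pad to a multiple of k before the network, then
# crop the output back to the original spatial size with the saved pad sizes.
import torch

x = torch.ones(3, 2, 50, 70)  # (B,C,H,W)
x_pad, pad_sizes = divisible_pad_t(x, k=16)  # padded to (3, 2, 64, 80)
x_back = inverse_divisible_pad_t(x_pad, pad_sizes)
print(x_back.shape)  # torch.Size([3, 2, 50, 70])

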
def floor_ceil(n: float) -> tuple[int, int]:
    """
    Returns floor and ceil of the input.

    Args:
        n: input number

    Returns:
        A tuple containing:
            (1) floor(n)
            (2) ceil(n)
    """
    return math.floor(n), math.ceil(n)


def sensitivity_map_reduce(kspace: Tensor, sens_maps: Tensor, spatial_dims: int = 2) -> Tensor:
    """
    Reduces coil measurements to a corresponding image based on the given sens_maps. Let's say there
    are C coil measurements inside kspace; this function multiplies the conjugate of each coil's
    sensitivity map with the corresponding coil image, resulting in C images. Summing those images
    together gives the "reduced image."

    Args:
        kspace: 2D kspace (B,C,H,W,2) with the last dimension being 2 (for real/imaginary parts)
            and C denoting the coil dimension. 3D data will have the shape (B,C,H,W,D,2).
        sens_maps: sensitivity maps of the same shape as kspace.
        spatial_dims: is 2 for 2D data and is 3 for 3D data

    Returns:
        reduction of kspace to (B,1,H,W,2) for 2D data or (B,1,H,W,D,2) for 3D data.
    """
    img = ifftn_centered_t(kspace, spatial_dims=spatial_dims, is_complex=True)  # inverse Fourier transform
    return complex_mul_t(img, complex_conj_t(sens_maps)).sum(dim=1, keepdim=True)


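# Usage sketch (assumed toy shapes; random data stands in for real measurements):
import torch

kspace = torch.rand(1, 8, 64, 64, 2)     # (B,C,H,W,2): 8-coil 2D k-space
sens_maps = torch.rand(1, 8, 64, 64, 2)  # same shape as kspace
reduced = sensitivity_map_reduce(kspace, sens_maps, spatial_dims=2)
print(reduced.shape)  # torch.Size([1, 1, 64, 64, 2])

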
def sensitivity_map_expand(img: Tensor, sens_maps: Tensor, spatial_dims: int = 2) -> Tensor:
    """
    Expands an image to its corresponding coil images based on the given sens_maps. Let's say there
    are C coils. This function multiplies the image img with each coil sensitivity map in sens_maps and
    stacks the resulting C coil images along the channel dimension, which is reserved for coils.

    Args:
        img: 2D image (B,1,H,W,2) with the last dimension being 2 (for real/imaginary parts).
            3D data will have the shape (B,1,H,W,D,2).
        sens_maps: Sensitivity maps for combining coil images. The shape is (B,C,H,W,2) for 2D data
            or (B,C,H,W,D,2) for 3D data (C denotes the coil dimension).
        spatial_dims: is 2 for 2D data and is 3 for 3D data

    Returns:
        Expansion of img to (B,C,H,W,2) for 2D data and (B,C,H,W,D,2) for 3D data. The output is
        transferred to the frequency domain to yield coil measurements.
    """
    return fftn_centered_t(complex_mul_t(img, sens_maps), spatial_dims=spatial_dims, is_complex=True)
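
# Usage sketch (assumed toy shapes; random data). Note that when the sensitivity maps
# are normalized so that the coil-wise sum of |S_c|^2 equals 1, sensitivity_map_reduce
# undoes this expansion (up to floating-point error), since the centered FFT/iFFT pair
# is an identity.
import torch

img = torch.rand(1, 1, 64, 64, 2)        # (B,1,H,W,2): coil-combined image
sens_maps = torch.rand(1, 8, 64, 64, 2)  # (B,C,H,W,2): per-coil sensitivities
coil_kspace = sensitivity_map_expand(img, sens_maps, spatial_dims=2)
print(coil_kspace.shape)  # torch.Size([1, 8, 64, 64, 2])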