# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
import torch
import torch.nn as nn
from torch import Tensor
from monai.apps.reconstruction.mri_utils import root_sum_of_squares_t
from monai.apps.reconstruction.networks.nets.complex_unet import ComplexUnet
from monai.apps.reconstruction.networks.nets.utils import (
reshape_batch_channel_to_channel_dim,
reshape_channel_to_batch_dim,
)
from monai.networks.blocks.fft_utils_t import ifftn_centered_t


class CoilSensitivityModel(nn.Module):
"""
This class uses a convolutional model to learn coil sensitivity maps for multi-coil MRI reconstruction.
The convolutional model is :py:class:`monai.apps.reconstruction.networks.nets.complex_unet` by default
    but can be specified by the user as well. Learning is done on the center of the under-sampled
    kspace, i.e., the Auto Calibration Region (ACR), which is fully sampled.
    This model requires the input data to be a (complex) two-channel tensor, i.e., the last
    dimension has size 2 and holds the real and imaginary parts.

    Modified and adapted from: https://github.com/facebookresearch/fastMRI

Args:
spatial_dims: number of spatial dimensions.
        features: six integers as numbers of features, denoting the number of channels in each layer.
act: activation type and arguments. Defaults to LeakyReLU.
norm: feature normalization type and arguments. Defaults to instance norm.
bias: whether to have a bias term in convolution blocks. Defaults to True.
dropout: dropout ratio. Defaults to 0.0.
upsample: upsampling mode, available options are
``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
        coil_dim: coil dimension in the data.
        conv_net: the learning model used to estimate the coil sensitivity maps. Defaults to
            :py:class:`monai.apps.reconstruction.networks.nets.complex_unet`. The only
            requirement on the model is that it takes 2-channel (complex) input and produces
            2-channel output.
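
    Example:
        A minimal usage sketch with random 2D multi-coil data; the shapes and the mask layout
        below are illustrative assumptions, not values prescribed by this class::

            import torch

            model = CoilSensitivityModel(spatial_dims=2)  # uses a ComplexUnet by default
            masked_kspace = torch.rand(1, 4, 64, 64, 2)  # (B,C,H,W,2) two-channel complex kspace
            mask = torch.zeros(1, 1, 1, 64, 1)
            mask[..., 28:36, :] = 1.0  # mark the fully-sampled center (ACR)
            sens_maps = model(masked_kspace, mask)  # expected shape: (1, 4, 64, 64, 2)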
"""
def __init__(
self,
spatial_dims: int = 2,
features: Sequence[int] = (32, 32, 64, 128, 256, 32),
act: str | tuple = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}),
norm: str | tuple = ("instance", {"affine": True}),
bias: bool = True,
dropout: float | tuple = 0.0,
upsample: str = "deconv",
coil_dim: int = 1,
conv_net: nn.Module | None = None,
):
super().__init__()
if conv_net is None:
self.conv_net = ComplexUnet(
spatial_dims=spatial_dims,
features=features,
act=act,
norm=norm,
bias=bias,
dropout=dropout,
upsample=upsample,
)
else:
# assume the first layer is convolutional and
# check whether in_channels == 2
params = [p.shape for p in conv_net.parameters()]
if params[0][1] != 2:
raise ValueError(f"in_channels should be 2 but it's {params[0][1]}.")
self.conv_net = conv_net # type: ignore
self.spatial_dims = spatial_dims
self.coil_dim = coil_dim

    def get_fully_sampled_region(self, mask: Tensor) -> tuple[int, int]:
"""
        Extracts the left and right boundaries of the fully-sampled center region of the kspace.
        Note that when a kspace is under-sampled, a part of its center remains fully sampled.
        This part is called the Auto Calibration Region (ACR) and is used for sensitivity map computation.
Args:
mask: the under-sampling mask of shape (..., S, 1) where S denotes the sampling dimension
Returns:
A tuple containing
(1) left index of the region
(2) right index of the region
        Note:
            Suppose the mask is of shape (1,1,20,1). If this function returns 8, 12 as the left and
            right indices, the fully-sampled center region covers indices 8 through 11, i.e., it has
            size ``right - left = 4``.
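
        Example:
            A minimal sketch with a dummy sampling mask, assuming ``model`` is a
            ``CoilSensitivityModel`` instance (sizes are illustrative assumptions)::

                mask = torch.zeros(1, 1, 20, 1)
                mask[..., 8:12, :] = 1.0  # fully-sampled center (ACR)
                left, right = model.get_fully_sampled_region(mask)
                # left == 8, right == 12: the ACR covers indices 8..11, size right - left == 4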
"""
left = right = mask.shape[-2] // 2
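        # scan outward from the center of the sampling dimension until the first unsampled
        # position is found on each side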
while mask[..., right, :]:
right += 1
while mask[..., left, :]:
left -= 1
return left + 1, right

    def forward(self, masked_kspace: Tensor, mask: Tensor) -> Tensor:
"""
Args:
masked_kspace: the under-sampled kspace (which is the input measurement). Its shape
is (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data.
mask: the under-sampling mask with shape (1,1,1,W,1) for 2D data or (1,1,1,1,D,1) for 3D data.
Returns:
predicted coil sensitivity maps with shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data.
"""
left, right = self.get_fully_sampled_region(mask)
num_low_freqs = right - left # size of the fully-sampled center
# take out the fully-sampled region and set the rest of the data to zero
x = torch.zeros_like(masked_kspace)
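        # re-extract a window of the same size as the ACR, centered in the sampling dimension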
start = (mask.shape[-2] - num_low_freqs + 1) // 2 # this marks the start of center extraction
x[..., start : start + num_low_freqs, :] = masked_kspace[..., start : start + num_low_freqs, :]
# apply inverse fourier to the extracted fully-sampled data
x = ifftn_centered_t(x, spatial_dims=self.spatial_dims, is_complex=True)
x, b = reshape_channel_to_batch_dim(x) # shape of x will be (B*C,1,...)
x = self.conv_net(x)
x = reshape_batch_channel_to_channel_dim(x, b) # shape will be (B,C,...)
        # normalize the maps so that each spatial location has unit root-sum-of-squares across coils
x = x / root_sum_of_squares_t(x, spatial_dim=self.coil_dim).unsqueeze(self.coil_dim)
return x