# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from abc import abstractmethod
from collections.abc import Sequence
import numpy as np
from torch import Tensor
from monai.apps.reconstruction.complex_utils import complex_abs, convert_to_tensor_complex
from monai.apps.reconstruction.mri_utils import root_sum_of_squares
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.fft_utils import ifftn_centered
from monai.transforms.transform import RandomizableTransform
from monai.utils.enums import TransformBackends
from monai.utils.type_conversion import convert_to_tensor
class KspaceMask(RandomizableTransform):
    """
    A basic class for under-sampling mask setup. It provides common
    features for under-sampling mask generators.
    For example, RandomKspaceMask and EquispacedKspaceMask (two mask
    transform classes defined right after this class) both inherit
    KspaceMask to properly set up properties like the acceleration factor.
    """

    def __init__(
        self,
        center_fractions: Sequence[float],
        accelerations: Sequence[float],
        spatial_dims: int = 2,
        is_complex: bool = True,
    ):
        """
        Args:
            center_fractions: Fraction of low-frequency columns to be retained.
                If multiple values are provided, then one of these numbers
                is chosen uniformly each time.
            accelerations: Amount of under-sampling. This should have the
                same length as center_fractions. If multiple values are
                provided, then one of these is chosen uniformly each time.
            spatial_dims: Number of spatial dims (e.g., it's 2 for a 2D data;
                it's also 2 for pseudo-3D datasets like the fastMRI dataset).
                The last spatial dim is selected for sampling. For the fastMRI
                dataset, k-space has the form (...,num_slices,num_coils,H,W)
                and sampling is done along W. For a general 3D data with the
                shape (...,num_coils,H,W,D), sampling is done along D.
            is_complex: if True, then the last dimension will be reserved for
                real/imaginary parts.

        Raises:
            ValueError: if ``center_fractions`` and ``accelerations`` differ in
                length, or if they are empty (no pair could ever be chosen).
        """
        # A single-line literal: the previous backslash continuation inside the
        # string embedded the next line's indentation into the message.
        if len(center_fractions) != len(accelerations):
            raise ValueError("Number of center fractions should match number of accelerations")
        # Fail fast here instead of with an opaque ``randint(0, 0)`` error on
        # the first call to ``randomize_choose_acceleration``.
        if len(accelerations) == 0:
            raise ValueError("At least one (center_fraction, acceleration) pair must be provided")
        self.center_fractions = center_fractions
        self.accelerations = accelerations
        self.spatial_dims = spatial_dims
        self.is_complex = is_complex

    @abstractmethod
    def __call__(self, kspace: NdarrayOrTensor) -> Sequence[Tensor]:
        """
        This is an extra instance to allow for defining new mask generators.
        For creating other mask transforms, define a new class and simply
        override __call__. See an example of this in
        :py:class:`monai.apps.reconstruction.transforms.array.RandomKspaceMask`.

        Args:
            kspace: The input k-space data. The shape is (...,num_coils,H,W,2)
                for complex 2D inputs and (...,num_coils,H,W,D) for real 3D
                data.
        """
        raise NotImplementedError

    def randomize_choose_acceleration(self) -> Sequence[float]:
        """
        If multiple values are provided for center_fractions and
        accelerations, this function selects one value uniformly
        for each training/test sample.

        Returns:
            A tuple containing
                (1) center_fraction: chosen fraction of center kspace
                lines to exclude from under-sampling
                (2) acceleration: chosen acceleration factor
        """
        # One shared index keeps each (center_fraction, acceleration) pair together.
        choice = self.R.randint(0, len(self.accelerations))
        center_fraction = self.center_fractions[choice]
        acceleration = self.accelerations[choice]
        return center_fraction, acceleration
class RandomKspaceMask(KspaceMask):
    """
    This k-space mask transform under-samples the k-space according to a
    random sampling pattern. Precisely, it uniformly selects a subset of
    columns from the input k-space data. If the k-space data has N columns,
    the mask picks out:

    1. N_low_freqs = (N * center_fraction) columns in the center
       corresponding to low-frequencies
    2. The other columns are selected uniformly at random with a probability
       equal to:
       prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs).
       This ensures that the expected number of columns selected is equal to
       (N / acceleration)

    It is possible to use multiple center_fractions and accelerations,
    in which case one possible (center_fraction, acceleration) is chosen
    uniformly at random each time the transform is called.

    Example:
        If accelerations = [4, 8] and center_fractions = [0.08, 0.04],
        then there is a 50% probability that 4-fold acceleration with 8%
        center fraction is selected and a 50% probability that 8-fold
        acceleration with 4% center fraction is selected.

    Modified and adopted from:
        https://github.com/facebookresearch/fastMRI/tree/master/fastmri
    """

    backend = [TransformBackends.TORCH]

    def __call__(self, kspace: NdarrayOrTensor) -> Sequence[Tensor]:
        """
        Args:
            kspace: The input k-space data. The shape is (...,num_coils,H,W,2)
                for complex 2D inputs and (...,num_coils,H,W,D) for real 3D
                data. The last spatial dim is selected for sampling. For the
                fastMRI dataset, k-space has the form
                (...,num_slices,num_coils,H,W) and sampling is done along W.
                For a general 3D data with the shape (...,num_coils,H,W,D),
                sampling is done along D.

        Returns:
            A tuple containing
                (1) the under-sampled kspace
                (2) the root-sum-of-squares combination (over the dim just
                before the spatial dims, assumed to be coils) of the magnitude
                of the centered inverse Fourier transform of the under-sampled
                kspace
        """
        kspace_t = convert_to_tensor_complex(kspace)
        spatial_size = kspace_t.shape
        # With is_complex, the last dim holds real/imaginary parts, so the
        # sampled (column) dim is the second-to-last one.
        col_dim = -2 if self.is_complex else -1
        num_cols = spatial_size[col_dim]

        center_fraction, acceleration = self.randomize_choose_acceleration()

        # Create the mask
        num_low_freqs = int(round(num_cols * center_fraction))
        num_high_freqs = num_cols - num_low_freqs
        if num_high_freqs > 0:
            prob = (num_cols / acceleration - num_low_freqs) / num_high_freqs
        else:
            # Every column is a low-frequency column (e.g. center_fraction=1.0);
            # avoid ZeroDivisionError, the center fill below selects everything.
            prob = 0.0
        # Draw the random numbers unconditionally so the RNG stream is consumed
        # identically regardless of the chosen parameters.
        mask = self.R.uniform(size=num_cols) < prob
        pad = (num_cols - num_low_freqs + 1) // 2
        mask[pad : pad + num_low_freqs] = True

        # Reshape the mask so it broadcasts over every dim except the sampled one
        mask_shape = [1] * len(spatial_size)
        mask_shape[col_dim] = num_cols
        # Put the mask on the data's device; otherwise the product below fails
        # for GPU inputs (a CPU mask cannot be multiplied with a CUDA tensor).
        mask = convert_to_tensor(mask.reshape(*mask_shape).astype(np.float32), device=kspace_t.device)

        # under-sample the kspace
        masked_kspace: Tensor = convert_to_tensor(mask * kspace_t)
        self.mask = mask

        # compute inverse fourier of the masked kspace
        masked_kspace_ifft: Tensor = convert_to_tensor(
            complex_abs(ifftn_centered(masked_kspace, spatial_dims=self.spatial_dims, is_complex=self.is_complex))
        )
        # combine coil images (it is assumed that the coil dimension is
        # the first dimension before spatial dimensions)
        masked_kspace_ifft_rss: Tensor = convert_to_tensor(
            root_sum_of_squares(masked_kspace_ifft, spatial_dim=-self.spatial_dims - 1)
        )
        return masked_kspace, masked_kspace_ifft_rss
class EquispacedKspaceMask(KspaceMask):
    """
    This k-space mask transform under-samples the k-space according to an
    equi-distant sampling pattern. Precisely, it selects an equi-distant
    subset of columns from the input k-space data. If the k-space data has N
    columns, the mask picks out:

    1. N_low_freqs = (N * center_fraction) columns in the center corresponding
       to low-frequencies
    2. The other columns are selected with equal spacing at a proportion that
       reaches the desired acceleration rate taking into consideration the number
       of low frequencies. This ensures that the expected number of columns
       selected is equal to (N / acceleration). If the center columns alone
       already reach (N / acceleration), no additional columns are selected.

    It is possible to use multiple center_fractions and accelerations, in
    which case one possible (center_fraction, acceleration) is chosen
    uniformly at random each time this transform is called.

    Example:
        If accelerations = [4, 8] and center_fractions = [0.08, 0.04],
        then there is a 50% probability that 4-fold acceleration with 8%
        center fraction is selected and a 50% probability that 8-fold
        acceleration with 4% center fraction is selected.

    Modified and adopted from:
        https://github.com/facebookresearch/fastMRI/tree/master/fastmri
    """

    backend = [TransformBackends.TORCH]

    def __call__(self, kspace: NdarrayOrTensor) -> Sequence[Tensor]:
        """
        Args:
            kspace: The input k-space data. The shape is (...,num_coils,H,W,2)
                for complex 2D inputs and (...,num_coils,H,W,D) for real 3D
                data. The last spatial dim is selected for sampling. For the
                fastMRI multi-coil dataset, k-space has the form
                (...,num_slices,num_coils,H,W) and sampling is done along W.
                For a general 3D data with the shape (...,num_coils,H,W,D),
                sampling is done along D.

        Returns:
            A tuple containing
                (1) the under-sampled kspace
                (2) the root-sum-of-squares combination (over the dim just
                before the spatial dims, assumed to be coils) of the magnitude
                of the centered inverse Fourier transform of the under-sampled
                kspace
        """
        kspace_t = convert_to_tensor_complex(kspace)
        spatial_size = kspace_t.shape
        # With is_complex, the last dim holds real/imaginary parts, so the
        # sampled (column) dim is the second-to-last one.
        col_dim = -2 if self.is_complex else -1
        num_cols = spatial_size[col_dim]

        center_fraction, acceleration = self.randomize_choose_acceleration()
        num_low_freqs = int(round(num_cols * center_fraction))

        # Create the mask: always keep the central low-frequency columns
        mask = np.zeros(num_cols, dtype=np.float32)
        pad = (num_cols - num_low_freqs + 1) // 2
        mask[pad : pad + num_low_freqs] = True

        # The equispaced step below is only well defined while the center alone
        # does not already reach num_cols / acceleration columns: at equality the
        # formula divides by zero, beyond it the step is negative (randint/arange
        # would raise). In those cases the center-only mask already meets the target.
        if num_low_freqs * acceleration < num_cols:
            # Determine acceleration rate by adjusting for the
            # number of low frequencies
            adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols)
            offset = self.R.randint(0, round(adjusted_accel))
            accel_samples = np.arange(offset, num_cols - 1, adjusted_accel)
            accel_samples = np.around(accel_samples).astype(np.uint)
            mask[accel_samples] = True

        # Reshape the mask so it broadcasts over every dim except the sampled one
        mask_shape = [1] * len(spatial_size)
        mask_shape[col_dim] = num_cols
        # Put the mask on the data's device; otherwise the product below fails
        # for GPU inputs (a CPU mask cannot be multiplied with a CUDA tensor).
        mask = convert_to_tensor(mask.reshape(*mask_shape).astype(np.float32), device=kspace_t.device)

        # under-sample the kspace
        masked_kspace: Tensor = convert_to_tensor(mask * kspace_t)
        self.mask = mask

        # compute inverse fourier of the masked kspace
        masked_kspace_ifft: Tensor = convert_to_tensor(
            complex_abs(ifftn_centered(masked_kspace, spatial_dims=self.spatial_dims, is_complex=self.is_complex))
        )
        # combine coil images (it is assumed that the coil dimension is
        # the first dimension before spatial dimensions)
        masked_kspace_ifft_rss: Tensor = convert_to_tensor(
            root_sum_of_squares(masked_kspace_ifft, spatial_dim=-self.spatial_dims - 1)
        )
        return masked_kspace, masked_kspace_ifft_rss