# Source code for monai.apps.reconstruction.networks.nets.complex_unet

```
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
import torch.nn as nn
from torch import Tensor
from monai.apps.reconstruction.networks.nets.utils import (
complex_normalize,
divisible_pad_t,
inverse_divisible_pad_t,
reshape_channel_complex_to_last_dim,
reshape_complex_to_channel_dim,
)
from monai.networks.nets.basic_unet import BasicUNet
class ComplexUnet(nn.Module):
    """
    U-Net variant that handles complex-valued input/output.

    It can be used as a model to learn sensitivity maps in multi-coil MRI data. It is
    built on :py:class:`monai.networks.nets.BasicUNet` by default, but the user
    can supply their own convolutional model as well.
    ComplexUnet also applies default normalization to the input, which makes it more stable to train.

    The data being a (complex) 2-channel tensor is a requirement for using this model.

    Modified and adopted from: https://github.com/facebookresearch/fastMRI

    Args:
        spatial_dims: number of spatial dimensions.
        features: six integers as numbers of features. Denotes the number of channels in each layer.
        act: activation type and arguments. Defaults to LeakyReLU.
        norm: feature normalization type and arguments. Defaults to instance norm.
        bias: whether to have a bias term in convolution blocks. Defaults to True.
        dropout: dropout ratio. Defaults to 0.0.
        upsample: upsampling mode, available options are
            ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
        pad_factor: an integer that each padded dimension will be divisible by.
            For example, 16 means each dimension will be divisible by 16 after padding.
        conv_net: the learning model used inside the ComplexUnet. The default
            is :py:class:`monai.networks.nets.basic_unet`. The only requirement on the model is to
            have 2 as input and output number of channels.

    Raises:
        ValueError: if ``conv_net`` is given and its first parameter is missing, has fewer
            than 2 dimensions, or has a size other than 2 along dimension 1.
    """

    def __init__(
        self,
        spatial_dims: int = 2,
        features: Sequence[int] = (32, 32, 64, 128, 256, 32),
        act: str | tuple = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}),
        norm: str | tuple = ("instance", {"affine": True}),
        bias: bool = True,
        dropout: float | tuple = 0.0,
        upsample: str = "deconv",
        pad_factor: int = 16,
        conv_net: nn.Module | None = None,
    ):
        super().__init__()
        self.unet: nn.Module
        if conv_net is None:
            self.unet = BasicUNet(
                spatial_dims=spatial_dims,
                in_channels=2,
                out_channels=2,
                features=features,
                act=act,
                norm=norm,
                bias=bias,
                dropout=dropout,
                upsample=upsample,
            )
        else:
            # Assume the first layer is convolutional, so dimension 1 of its first
            # parameter is in_channels. Only the first parameter is inspected; there is
            # no need to materialize the shapes of all of them.
            first_param = next(iter(conv_net.parameters()), None)
            if first_param is None or first_param.dim() < 2:
                # Previously this surfaced as an opaque IndexError.
                raise ValueError(
                    "conv_net should start with a convolutional layer so that in_channels can be checked, "
                    "but its first parameter is missing or has fewer than 2 dimensions."
                )
            in_channels = first_param.shape[1]
            if in_channels != 2:
                raise ValueError(f"in_channels should be 2 but it's {in_channels}.")
            self.unet = conv_net

        self.pad_factor = pad_factor

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: input of shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data

        Returns:
            output of shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data
        """
        # suppose the input is 2D, the comment in front of each operator below shows the shape after that operator
        x = reshape_complex_to_channel_dim(x)  # x will be of shape (B,C*2,H,W)
        x, mean, std = complex_normalize(x)  # x will be of shape (B,C*2,H,W)

        # pad input
        x, padding_sizes = divisible_pad_t(
            x, k=self.pad_factor
        )  # x will be of shape (B,C*2,H',W') where H' and W' are for after padding

        x = self.unet(x)

        # inverse padding
        x = inverse_divisible_pad_t(x, padding_sizes)  # x will be of shape (B,C*2,H,W)

        # undo the normalization using the statistics returned by complex_normalize
        x = x * std + mean
        x = reshape_channel_complex_to_last_dim(x)  # x will be of shape (B,C,H,W,2)
        return x
```