# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
from torch import nn
from torch.nn import functional as F
from monai.networks.blocks.regunet_block import (
RegistrationDownSampleBlock,
RegistrationExtractionBlock,
RegistrationResidualConvBlock,
get_conv_block,
get_deconv_block,
)
from monai.networks.utils import meshgrid_ij

__all__ = ["RegUNet", "AffineHead", "GlobalNet", "LocalNet"]


class RegUNet(nn.Module):
"""
    Class that implements an adapted UNet. This class also serves as the parent class of LocalNet and GlobalNet.

    Reference:
        O. Ronneberger, P. Fischer, and T. Brox,
        “U-net: Convolutional networks for biomedical image segmentation,”
        Lecture Notes in Computer Science, 2015, vol. 9351, pp. 234–241.
        https://arxiv.org/abs/1505.04597

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
"""

    def __init__(
self,
spatial_dims: int,
in_channels: int,
num_channel_initial: int,
depth: int,
out_kernel_initializer: str | None = "kaiming_uniform",
out_activation: str | None = None,
out_channels: int = 3,
extract_levels: tuple[int] | None = None,
pooling: bool = True,
concat_skip: bool = False,
encode_kernel_sizes: int | list[int] = 3,
):
"""
Args:
spatial_dims: number of spatial dims
in_channels: number of input channels
num_channel_initial: number of initial channels
depth: input is at level 0, bottom is at level depth.
out_kernel_initializer: kernel initializer for the last layer
out_activation: activation at the last layer
out_channels: number of channels for the output
            extract_levels: which levels of the network to extract features from; the maximum level must be equal to ``depth``
pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv
concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition
            encode_kernel_sizes: kernel size(s) for the encoding convolutions; a single int is applied at every level, otherwise provide ``depth + 1`` values, one per level including the bottom block
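
        Example (a minimal sketch; the argument values below are illustrative, and the
        input spatial sizes are chosen to be divisible by ``2 ** depth``)::

            import torch

            net = RegUNet(spatial_dims=3, in_channels=2, num_channel_initial=16, depth=3)
            out = net(torch.rand(1, 2, 32, 32, 32))  # shape (1, 3, 32, 32, 32)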
"""
super().__init__()
if not extract_levels:
extract_levels = (depth,)
if max(extract_levels) != depth:
            raise AssertionError(f"max(extract_levels) must be equal to depth, got {max(extract_levels)} and depth={depth}.")
# save parameters
self.spatial_dims = spatial_dims
self.in_channels = in_channels
self.num_channel_initial = num_channel_initial
self.depth = depth
self.out_kernel_initializer = out_kernel_initializer
self.out_activation = out_activation
self.out_channels = out_channels
self.extract_levels = extract_levels
self.pooling = pooling
self.concat_skip = concat_skip
if isinstance(encode_kernel_sizes, int):
encode_kernel_sizes = [encode_kernel_sizes] * (self.depth + 1)
if len(encode_kernel_sizes) != self.depth + 1:
            raise AssertionError(f"length of encode_kernel_sizes must be depth + 1 = {self.depth + 1}, got {len(encode_kernel_sizes)}.")
self.encode_kernel_sizes: list[int] = encode_kernel_sizes
self.num_channels = [self.num_channel_initial * (2**d) for d in range(self.depth + 1)]
self.min_extract_level = min(self.extract_levels)
# init layers
# all lists start with d = 0
self.encode_convs: nn.ModuleList
self.encode_pools: nn.ModuleList
self.bottom_block: nn.Sequential
self.decode_deconvs: nn.ModuleList
self.decode_convs: nn.ModuleList
self.output_block: nn.Module
# build layers
self.build_layers()

    def build_layers(self):
self.build_encode_layers()
self.build_decode_layers()

    def build_encode_layers(self):
# encoding / down-sampling
self.encode_convs = nn.ModuleList(
[
self.build_conv_block(
in_channels=self.in_channels if d == 0 else self.num_channels[d - 1],
out_channels=self.num_channels[d],
kernel_size=self.encode_kernel_sizes[d],
)
for d in range(self.depth)
]
)
self.encode_pools = nn.ModuleList(
[self.build_down_sampling_block(channels=self.num_channels[d]) for d in range(self.depth)]
)
self.bottom_block = self.build_bottom_block(
in_channels=self.num_channels[-2], out_channels=self.num_channels[-1]
)

    def build_conv_block(self, in_channels, out_channels, kernel_size):
return nn.Sequential(
get_conv_block(
spatial_dims=self.spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
),
RegistrationResidualConvBlock(
spatial_dims=self.spatial_dims,
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
),
)

    def build_down_sampling_block(self, channels: int):
return RegistrationDownSampleBlock(spatial_dims=self.spatial_dims, channels=channels, pooling=self.pooling)

    def build_bottom_block(self, in_channels: int, out_channels: int):
kernel_size = self.encode_kernel_sizes[self.depth]
return nn.Sequential(
get_conv_block(
spatial_dims=self.spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
),
RegistrationResidualConvBlock(
spatial_dims=self.spatial_dims,
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
),
)

    def build_decode_layers(self):
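        # decoding / up-sampling: iterate from the bottom (level depth - 1) down to
        # min_extract_level, the lowest level consumed by the extraction block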
self.decode_deconvs = nn.ModuleList(
[
self.build_up_sampling_block(in_channels=self.num_channels[d + 1], out_channels=self.num_channels[d])
for d in range(self.depth - 1, self.min_extract_level - 1, -1)
]
)
self.decode_convs = nn.ModuleList(
[
self.build_conv_block(
in_channels=(2 * self.num_channels[d] if self.concat_skip else self.num_channels[d]),
out_channels=self.num_channels[d],
kernel_size=3,
)
for d in range(self.depth - 1, self.min_extract_level - 1, -1)
]
)
# extraction
self.output_block = self.build_output_block()

    def build_up_sampling_block(self, in_channels: int, out_channels: int) -> nn.Module:
return get_deconv_block(spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels)

    def build_output_block(self) -> nn.Module:
return RegistrationExtractionBlock(
spatial_dims=self.spatial_dims,
extract_levels=self.extract_levels,
num_channels=self.num_channels,
out_channels=self.out_channels,
kernel_initializer=self.out_kernel_initializer,
activation=self.out_activation,
)

    def forward(self, x):
"""
Args:
x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])

        Returns:
Tensor in shape (batch, ``out_channels``, insize_1, insize_2, [insize_3]), with the same spatial size as ``x``
"""
image_size = x.shape[2:]
skips = [] # [0, ..., depth - 1]
encoded = x
for encode_conv, encode_pool in zip(self.encode_convs, self.encode_pools):
skip = encode_conv(encoded)
encoded = encode_pool(skip)
skips.append(skip)
decoded = self.bottom_block(encoded)
outs = [decoded]
for i, (decode_deconv, decode_conv) in enumerate(zip(self.decode_deconvs, self.decode_convs)):
decoded = decode_deconv(decoded)
if self.concat_skip:
decoded = torch.cat([decoded, skips[-i - 1]], dim=1)
else:
decoded = decoded + skips[-i - 1]
decoded = decode_conv(decoded)
outs.append(decoded)
out = self.output_block(outs, image_size=image_size)
return out


class AffineHead(nn.Module):
def __init__(
self,
spatial_dims: int,
image_size: list[int],
decode_size: list[int],
in_channels: int,
save_theta: bool = False,
):
"""
Args:
spatial_dims: number of spatial dimensions
image_size: output spatial size
decode_size: input spatial size (two or three integers depending on ``spatial_dims``)
in_channels: number of input channels
save_theta: whether to save the theta matrix estimation
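
        Example (a minimal sketch with illustrative sizes; the head reads the deepest
        feature map of the decoder and returns a dense displacement field)::

            import torch

            head = AffineHead(spatial_dims=3, image_size=[16, 16, 16], decode_size=[4, 4, 4], in_channels=32)
            ddf = head([torch.rand(1, 32, 4, 4, 4)], image_size=[16, 16, 16])  # shape (1, 3, 16, 16, 16)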
"""
super().__init__()
self.spatial_dims = spatial_dims
if spatial_dims == 2:
in_features = in_channels * decode_size[0] * decode_size[1]
out_features = 6
out_init = torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)
elif spatial_dims == 3:
in_features = in_channels * decode_size[0] * decode_size[1] * decode_size[2]
out_features = 12
out_init = torch.tensor([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0], dtype=torch.float)
else:
raise ValueError(f"only support 2D/3D operation, got spatial_dims={spatial_dims}")
self.fc = nn.Linear(in_features=in_features, out_features=out_features)
self.grid = self.get_reference_grid(image_size) # (spatial_dims, ...)
        # init weight/bias so that the initial prediction is the identity transform
        self.fc.weight.data.zero_()
        self.fc.bias.data.copy_(out_init)
self.save_theta = save_theta
self.theta = torch.Tensor()

    @staticmethod
def get_reference_grid(image_size: tuple[int] | list[int]) -> torch.Tensor:
mesh_points = [torch.arange(0, dim) for dim in image_size]
grid = torch.stack(meshgrid_ij(*mesh_points), dim=0) # (spatial_dims, ...)
return grid.to(dtype=torch.float)

    def affine_transform(self, theta: torch.Tensor):
# (spatial_dims, ...) -> (spatial_dims + 1, ...)
grid_padded = torch.cat([self.grid, torch.ones_like(self.grid[:1])])
        # grid_warped[b,p,...] = sum_over_q(grid_padded[q,...] * theta[b,p,q])
if self.spatial_dims == 2:
grid_warped = torch.einsum("qij,bpq->bpij", grid_padded, theta.reshape(-1, 2, 3))
elif self.spatial_dims == 3:
grid_warped = torch.einsum("qijk,bpq->bpijk", grid_padded, theta.reshape(-1, 3, 4))
else:
raise ValueError(f"do not support spatial_dims={self.spatial_dims}")
return grid_warped

    def forward(self, x: list[torch.Tensor], image_size: list[int]) -> torch.Tensor:
f = x[0]
self.grid = self.grid.to(device=f.device)
theta = self.fc(f.reshape(f.shape[0], -1))
if self.save_theta:
self.theta = theta.detach()
out: torch.Tensor = self.affine_transform(theta) - self.grid
return out


class GlobalNet(RegUNet):
"""
    Build GlobalNet for image registration.

    Reference:
        Hu, Yipeng, et al.
        "Label-driven weakly-supervised learning for multimodal deformable image registration,"
        https://arxiv.org/abs/1711.01666
"""

    def __init__(
self,
image_size: list[int],
spatial_dims: int,
in_channels: int,
num_channel_initial: int,
depth: int,
out_kernel_initializer: str | None = "kaiming_uniform",
out_activation: str | None = None,
pooling: bool = True,
concat_skip: bool = False,
encode_kernel_sizes: int | list[int] = 3,
save_theta: bool = False,
):
"""
Args:
image_size: output displacement field spatial size
spatial_dims: number of spatial dims
in_channels: number of input channels
num_channel_initial: number of initial channels
depth: input is at level 0, bottom is at level depth.
out_kernel_initializer: kernel initializer for the last layer
out_activation: activation at the last layer
pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv
concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition
encode_kernel_sizes: kernel size for down-sampling
save_theta: whether to save the theta matrix estimation
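
        Example (a minimal sketch; the values are illustrative, and every spatial size
        must be divisible by ``2 ** depth``)::

            import torch

            net = GlobalNet(image_size=[32, 32, 32], spatial_dims=3, in_channels=2, num_channel_initial=16, depth=3)
            ddf = net(torch.rand(1, 2, 32, 32, 32))  # dense displacement field, shape (1, 3, 32, 32, 32)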
"""
for size in image_size:
if size % (2**depth) != 0:
raise ValueError(
f"given depth {depth}, "
f"all input spatial dimension must be divisible by {2 ** depth}, "
f"got input of size {image_size}"
)
self.image_size = image_size
self.decode_size = [size // (2**depth) for size in image_size]
self.save_theta = save_theta
super().__init__(
spatial_dims=spatial_dims,
in_channels=in_channels,
num_channel_initial=num_channel_initial,
depth=depth,
out_kernel_initializer=out_kernel_initializer,
out_activation=out_activation,
out_channels=spatial_dims,
pooling=pooling,
concat_skip=concat_skip,
encode_kernel_sizes=encode_kernel_sizes,
)

    def build_output_block(self):
return AffineHead(
spatial_dims=self.spatial_dims,
image_size=self.image_size,
decode_size=self.decode_size,
in_channels=self.num_channels[-1],
save_theta=self.save_theta,
)


class AdditiveUpSampleBlock(nn.Module):
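    """
    Up-sample by a factor of two along each spatial dimension, combining two paths:
    a transposed convolution of the input, plus an interpolated copy of the input whose
    channel count is halved by summing the two halves of the channel dimension
    (additive up-sampling).
    """
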
def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
mode: str = "nearest",
align_corners: bool | None = None,
):
super().__init__()
self.deconv = get_deconv_block(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels)
self.mode = mode
self.align_corners = align_corners

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output_size = [size * 2 for size in x.shape[2:]]
        deconved = self.deconv(x)
        # interpolate the input to twice its spatial size, then halve the channel count
        # by splitting the channels into two halves and summing them element-wise
        resized = F.interpolate(x, output_size, mode=self.mode, align_corners=self.align_corners)
        resized = torch.sum(torch.stack(resized.split(split_size=resized.shape[1] // 2, dim=1), dim=-1), dim=-1)
        out: torch.Tensor = deconved + resized
        return out


class LocalNet(RegUNet):
"""
Reimplementation of LocalNet, based on:
`Weakly-supervised convolutional neural networks for multimodal image registration
<https://doi.org/10.1016/j.media.2018.07.002>`_.
`Label-driven weakly-supervised learning for multimodal deformable image registration
<https://arxiv.org/abs/1711.01666>`_.

    Adapted from:
DeepReg (https://github.com/DeepRegNet/DeepReg)
"""

    def __init__(
self,
spatial_dims: int,
in_channels: int,
num_channel_initial: int,
extract_levels: tuple[int],
out_kernel_initializer: str | None = "kaiming_uniform",
out_activation: str | None = None,
out_channels: int = 3,
pooling: bool = True,
use_additive_sampling: bool = True,
concat_skip: bool = False,
mode: str = "nearest",
align_corners: bool | None = None,
):
"""
Args:
spatial_dims: number of spatial dims
in_channels: number of input channels
num_channel_initial: number of initial channels
out_kernel_initializer: kernel initializer for the last layer
out_activation: activation at the last layer
out_channels: number of channels for the output
            extract_levels: which levels of the network to extract features from; ``depth`` is inferred as ``max(extract_levels)``
            pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv
            use_additive_sampling: whether to use the additive up-sampling layer for decoding.
            concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition
            mode: interpolation mode used when ``use_additive_sampling`` is True, default is "nearest".
            align_corners: align_corners argument for the interpolation used when ``use_additive_sampling`` is True, default is None.
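
        Example (a minimal sketch; the values are illustrative, with input spatial sizes
        divisible by ``2 ** max(extract_levels)``)::

            import torch

            net = LocalNet(spatial_dims=3, in_channels=2, num_channel_initial=16, extract_levels=(0, 1, 2))
            ddf = net(torch.rand(1, 2, 32, 32, 32))  # shape (1, 3, 32, 32, 32)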
"""
self.use_additive_upsampling = use_additive_sampling
self.mode = mode
self.align_corners = align_corners
super().__init__(
spatial_dims=spatial_dims,
in_channels=in_channels,
num_channel_initial=num_channel_initial,
extract_levels=extract_levels,
depth=max(extract_levels),
out_kernel_initializer=out_kernel_initializer,
out_activation=out_activation,
out_channels=out_channels,
pooling=pooling,
concat_skip=concat_skip,
encode_kernel_sizes=[7] + [3] * max(extract_levels),
)

    def build_bottom_block(self, in_channels: int, out_channels: int):
kernel_size = self.encode_kernel_sizes[self.depth]
return get_conv_block(
spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size
)

    def build_up_sampling_block(self, in_channels: int, out_channels: int) -> nn.Module:
if self.use_additive_upsampling:
return AdditiveUpSampleBlock(
spatial_dims=self.spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
mode=self.mode,
align_corners=self.align_corners,
)
return get_deconv_block(spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels)