# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Callable
from typing import Union
import numpy as np
import torch
import torch.nn as nn
from monai.networks.blocks.upsample import UpSample
from monai.networks.layers.factories import Act, Conv, Norm, split_args
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.utils import UpsampleMode, has_option
__all__ = ["SegResNetDS"]
def scales_for_resolution(resolution: tuple | list, n_stages: int | None = None):
    """
    Build the per-level downsampling schedule for a (possibly anisotropic) input resolution.

    For every axis the function counts how many times the axis spacing can be doubled
    before it reaches the coarsest spacing. During those leading levels the axis is
    downsampled by 2, afterwards it is kept (scale 1) until all axes have caught up.

    .. code-block:: python

        scales_for_resolution(resolution=[1, 1, 5], n_stages=5)
        # [(2, 2, 1), (2, 2, 1), (2, 2, 2), (2, 2, 2), (2, 2, 2)]

    Args:
        resolution: input image resolution (in mm), one value per spatial axis.
        n_stages: optionally the number of stages of the network. When larger than the
            number of anisotropic levels, the schedule is padded with isotropic ``2`` scales;
            otherwise the schedule is truncated to ``n_stages`` entries
            (``None`` keeps only the anisotropic levels).

    Returns:
        A list with one tuple of per-axis scale factors (1 or 2) per level.

    Raises:
        ValueError: if any resolution component is not strictly positive.
    """
    ndim = len(resolution)
    spacing = np.array(resolution)
    if not all(spacing > 0):
        raise ValueError("Resolution must be positive")

    # number of leading levels during which each axis is still downsampled
    halvings = np.floor(np.log2(np.max(spacing) / spacing)).astype(np.int32)
    n_aniso = max(halvings)

    schedule = []
    for level in range(n_aniso):
        # an axis stops being downsampled once `level` has reached its halving count
        schedule.append(tuple(np.where(2**level >= 2**halvings, 1, 2)))

    if n_stages and n_stages > n_aniso:
        isotropic = (2,) * ndim
        return schedule + [isotropic] * (n_stages - n_aniso)
    return schedule[:n_stages]
def aniso_kernel(scale: tuple | list):
    """
    Derive convolution ``kernel_size``, ``padding`` and ``stride`` for one scale level.

    Axes that are downsampled (scale > 1) get a kernel of 3 with padding 1; axes that
    are kept at their size get a kernel of 1 with padding 0.

    Args:
        scale: per-axis scale factors of the current level.

    Returns:
        A ``(kernel_size, padding, stride)`` triple, where ``kernel_size`` and ``padding``
        are lists and ``stride`` is the ``scale`` argument itself.
    """
    kernel_size = []
    padding = []
    for factor in scale:
        k = 3 if factor > 1 else 1
        kernel_size.append(k)
        padding.append(k // 2)
    return kernel_size, padding, scale
class SegResBlock(nn.Module):
    """
    Residual block used by SegResNet, see `3D MRI brain tumor segmentation using autoencoder regularization
    <https://arxiv.org/pdf/1810.11654.pdf>`_.

    The block applies ``norm -> act -> conv`` twice and adds the block input to the result.
    The number of channels and the spatial size are preserved.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        norm: tuple | str,
        kernel_size: tuple | int = 3,
        act: tuple | str = "relu",
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2 or 3.
            in_channels: number of input (and output) channels.
            norm: feature normalization type and arguments.
            kernel_size: convolution kernel size, a single int or one value per axis. Defaults to 3.
            act: activation type and arguments. Defaults to ``RELU``.
        """
        super().__init__()

        # half-kernel padding so that a stride-1 convolution keeps the spatial size
        padding = (
            tuple(k // 2 for k in kernel_size)
            if isinstance(kernel_size, (tuple, list))
            else kernel_size // 2  # type: ignore
        )

        conv_type = Conv[Conv.CONV, spatial_dims]
        conv_kwargs = {
            "in_channels": in_channels,
            "out_channels": in_channels,
            "kernel_size": kernel_size,
            "stride": 1,
            "padding": padding,
            "bias": False,
        }

        # first norm -> act -> conv stage
        self.norm1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
        self.act1 = get_act_layer(act)
        self.conv1 = conv_type(**conv_kwargs)

        # second norm -> act -> conv stage
        self.norm2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
        self.act2 = get_act_layer(act)
        self.conv2 = conv_type(**conv_kwargs)

    def forward(self, x):
        """Run both norm/act/conv stages and add the block input as a residual."""
        residual = x

        x = self.norm1(x)
        x = self.act1(x)
        x = self.conv1(x)

        x = self.norm2(x)
        x = self.act2(x)
        x = self.conv2(x)

        x += residual
        return x
class SegResEncoder(nn.Module):
    """
    SegResEncoder based on the encoder structure in `3D MRI brain tumor segmentation using autoencoder regularization
    <https://arxiv.org/pdf/1810.11654.pdf>`_.

    The encoder is an initial convolution followed by one level per entry of ``blocks_down``.
    Each level is a stack of residual blocks followed by a strided convolution that doubles
    the number of channels (the last level has no downsampling). The forward pass returns the
    features produced by the residual blocks of every level, from the finest to the coarsest.

    Args:
        spatial_dims: spatial dimension of the input data. Defaults to 3.
        init_filters: number of output channels for initial convolution layer. Defaults to 32.
        in_channels: number of input channels for the network. Defaults to 1.
        act: activation type and arguments. Defaults to ``RELU``.
        norm: feature normalization type and arguments. Defaults to ``BATCH``.
        blocks_down: number of residual blocks in each level. Defaults to ``[1,2,2,4]``.
        head_module: optional callable module to apply to the list of per-level features.
        anisotropic_scales: optional list of per-axis scales, one entry for each level.
    """

    def __init__(
        self,
        spatial_dims: int = 3,
        init_filters: int = 32,
        in_channels: int = 1,
        act: tuple | str = "relu",
        norm: tuple | str = "batch",
        blocks_down: tuple = (1, 2, 2, 4),
        head_module: nn.Module | None = None,
        anisotropic_scales: tuple | None = None,
    ):
        super().__init__()

        if spatial_dims not in (1, 2, 3):
            raise ValueError("`spatial_dims` can only be 1, 2 or 3.")

        # normalization gets affine trainable parameters unless the caller specified otherwise
        norm = split_args(norm)
        if has_option(Norm[norm[0], spatial_dims], "affine"):
            norm[1].setdefault("affine", True)  # type: ignore

        # activation runs in place unless the caller specified otherwise
        act = split_args(act)
        if has_option(Act[act[0]], "inplace"):
            act[1].setdefault("inplace", True)  # type: ignore

        conv_type = Conv[Conv.CONV, spatial_dims]
        width = init_filters  # channel count of the current level

        # the initial convolution uses the kernel shape of the first level and never downsamples
        if anisotropic_scales:
            stem_kernel, stem_padding, _ = aniso_kernel(anisotropic_scales[0])
        else:
            stem_kernel, stem_padding = 3, 1
        self.conv_init = conv_type(
            in_channels=in_channels,
            out_channels=width,
            kernel_size=stem_kernel,
            padding=stem_padding,
            stride=1,
            bias=False,
        )

        self.layers = nn.ModuleList()
        last_level = len(blocks_down) - 1
        for idx, n_blocks in enumerate(blocks_down):
            if anisotropic_scales:
                kernel_size, padding, stride = aniso_kernel(anisotropic_scales[idx])
            else:
                kernel_size, padding, stride = 3, 1, 2

            stage = nn.ModuleDict()
            res_blocks = [
                SegResBlock(spatial_dims=spatial_dims, in_channels=width, kernel_size=kernel_size, norm=norm, act=act)
                for _ in range(n_blocks)
            ]
            stage["blocks"] = nn.Sequential(*res_blocks)

            if idx == last_level:
                # the coarsest level is not downsampled any further
                stage["downsample"] = nn.Identity()
            else:
                stage["downsample"] = conv_type(
                    in_channels=width,
                    out_channels=2 * width,
                    bias=False,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                )

            self.layers.append(stage)
            width *= 2

        self.head_module = head_module

        # keep the (normalized) configuration available on the instance
        self.spatial_dims = spatial_dims
        self.in_channels = in_channels
        self.init_filters = init_filters
        self.blocks_down = blocks_down
        self.norm = norm
        self.act = act

    def _forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Collect the residual-block features of every level, then apply the optional head."""
        features = []
        x = self.conv_init(x)

        for stage in self.layers:
            x = stage["blocks"](x)
            features.append(x)  # features are taken before the downsampling step
            x = stage["downsample"](x)

        if self.head_module is not None:
            features = self.head_module(features)

        return features

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Return the per-level encoder features for input ``x``."""
        return self._forward(x)
[docs]
class SegResNetDS(nn.Module):
    """
    SegResNetDS based on `3D MRI brain tumor segmentation using autoencoder regularization
    <https://arxiv.org/pdf/1810.11654.pdf>`_.
    It is similar to https://docs.monai.io/en/stable/networks.html#segresnet, with several
    improvements including deep supervision and non-isotropic kernel support.

    Args:
        spatial_dims: spatial dimension of the input data. Defaults to 3.
        init_filters: number of output channels for initial convolution layer. Defaults to 32.
        in_channels: number of input channels for the network. Defaults to 1.
        out_channels: number of output channels for the network. Defaults to 2.
        act: activation type and arguments. Defaults to ``RELU``.
        norm: feature normalization type and arguments. Defaults to ``BATCH``.
        blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``.
        blocks_up: number of upsample blocks (optional). One entry per decoder level,
            i.e. ``len(blocks_down) - 1`` entries; defaults to one block per level.
        dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level.
            At dsdepth==1,only a single output is returned. Values below 1 are treated as 1.
        preprocess: optional callable function to apply before the model's forward pass
        upsample_mode: upsampling mode passed to the ``UpSample`` blocks of the decoder. Defaults to ``"deconv"``.
        resolution: optional input image resolution. When provided, the network will first use non-isotropic kernels to bring
            image spacing into an approximately isotropic space.
            Otherwise, by default, the kernel size and downsampling is always isotropic.
    """

    def __init__(
        self,
        spatial_dims: int = 3,
        init_filters: int = 32,
        in_channels: int = 1,
        out_channels: int = 2,
        act: tuple | str = "relu",
        norm: tuple | str = "batch",
        blocks_down: tuple = (1, 2, 2, 4),
        blocks_up: tuple | None = None,
        dsdepth: int = 1,
        preprocess: nn.Module | Callable | None = None,
        upsample_mode: UpsampleMode | str = "deconv",
        resolution: tuple | None = None,
    ):
        super().__init__()

        if spatial_dims not in (1, 2, 3):
            raise ValueError("`spatial_dims` can only be 1, 2 or 3.")

        self.spatial_dims = spatial_dims
        self.init_filters = init_filters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.act = act
        self.norm = norm
        self.blocks_down = blocks_down
        self.dsdepth = max(dsdepth, 1)  # at least the final output is always produced
        self.resolution = resolution
        self.preprocess = preprocess

        if resolution is not None:
            if not isinstance(resolution, (list, tuple)):
                raise TypeError("resolution must be a tuple")
            elif not all(r > 0 for r in resolution):
                raise ValueError("resolution must be positive")

        # ensure normalization had affine trainable parameters (if not specified)
        norm = split_args(norm)
        if has_option(Norm[norm[0], spatial_dims], "affine"):
            norm[1].setdefault("affine", True)  # type: ignore

        # ensure activation is inplace (if not specified)
        act = split_args(act)
        if has_option(Act[act[0]], "inplace"):
            act[1].setdefault("inplace", True)  # type: ignore

        anisotropic_scales = None
        if resolution:
            anisotropic_scales = scales_for_resolution(resolution, n_stages=len(blocks_down))
        self.anisotropic_scales = anisotropic_scales

        self.encoder = SegResEncoder(
            spatial_dims=spatial_dims,
            init_filters=init_filters,
            in_channels=in_channels,
            act=act,
            norm=norm,
            blocks_down=blocks_down,
            anisotropic_scales=anisotropic_scales,
        )

        n_up = len(blocks_down) - 1
        if blocks_up is None:
            blocks_up = (1,) * n_up  # assume 1 upsample block per level
        self.blocks_up = blocks_up

        filters = init_filters * 2**n_up
        self.up_layers = nn.ModuleList()

        for i in range(n_up):
            filters = filters // 2
            # Decoder level i undoes the encoder downsampling of level (n_up - i - 1).
            # The index is derived from the real decoder depth `n_up` (not from
            # len(blocks_up)) so that the scales always mirror the encoder.
            kernel_size, _, stride = (
                aniso_kernel(anisotropic_scales[n_up - i - 1]) if anisotropic_scales else (3, 1, 2)
            )

            level = nn.ModuleDict()
            level["upsample"] = UpSample(
                mode=upsample_mode,
                spatial_dims=spatial_dims,
                in_channels=2 * filters,
                out_channels=filters,
                kernel_size=kernel_size,
                scale_factor=stride,
                bias=False,
                align_corners=False,
            )
            blocks = [
                SegResBlock(spatial_dims=spatial_dims, in_channels=filters, kernel_size=kernel_size, norm=norm, act=act)
                for _ in range(blocks_up[i])
            ]
            level["blocks"] = nn.Sequential(*blocks)

            # Deep supervision heads: the last `self.dsdepth` decoder levels get a head.
            # This must be the same (clamped) condition that `_forward` uses to decide
            # which levels produce an output, otherwise the final level could end up
            # with an Identity head and return features instead of `out_channels` maps.
            if n_up - i <= self.dsdepth:
                level["head"] = Conv[Conv.CONV, spatial_dims](
                    in_channels=filters, out_channels=out_channels, kernel_size=1, bias=True
                )
            else:
                level["head"] = nn.Identity()

            self.up_layers.append(level)

        if n_up == 0:  # in a corner case of flat structure (no downsampling), attach a single head
            level = nn.ModuleDict(
                {
                    "upsample": nn.Identity(),
                    "blocks": nn.Identity(),
                    "head": Conv[Conv.CONV, spatial_dims](
                        in_channels=filters, out_channels=out_channels, kernel_size=1, bias=True
                    ),
                }
            )
            self.up_layers.append(level)

    def shape_factor(self):
        """
        Calculate the factors (divisors) that the input image shape must be divisible by.

        Returns:
            A list with one integer divisor per spatial axis.
        """
        if self.anisotropic_scales is None:
            d = [2 ** (len(self.blocks_down) - 1)] * self.spatial_dims
        else:
            # the last scale entry belongs to the coarsest level, which is never downsampled
            applied_scales = self.anisotropic_scales[:-1]
            if len(applied_scales) == 0:
                # no downsampling at all: any input size is acceptable
                d = [1] * self.spatial_dims
            else:
                d = [int(f) for f in np.prod(np.array(applied_scales), axis=0)]
        return d

    def is_valid_shape(self, x):
        """
        Calculate if the input shape is divisible by the minimum factors for the current network configuration.

        Args:
            x: input tensor; dimensions from index 2 onwards are checked against ``shape_factor()``.
        """
        a = [i % j == 0 for i, j in zip(x.shape[2:], self.shape_factor())]
        return all(a)

    def _forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tensor]]:
        if self.preprocess is not None:
            x = self.preprocess(x)

        if not self.is_valid_shape(x):
            raise ValueError(f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}")

        # encoder features, reordered from the coarsest to the finest level
        x_down = self.encoder(x)
        x_down.reverse()
        x = x_down.pop(0)

        if len(x_down) == 0:
            # flat structure (no skip connections): add a broadcastable zero instead
            x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)]

        outputs: list[torch.Tensor] = []

        i = 0
        for level in self.up_layers:
            x = level["upsample"](x)
            x += x_down.pop(0)  # skip connection from the matching encoder level
            x = level["blocks"](x)

            # only the last `dsdepth` levels contribute an output
            if len(self.up_layers) - i <= self.dsdepth:
                outputs.append(level["head"](x))
            i = i + 1

        # order the outputs from the final (finest) level to the coarsest supervised level
        outputs.reverse()

        # in eval() mode, always return a single final output
        if not self.training or len(outputs) == 1:
            return outputs[0]

        # return a list of DS outputs
        return outputs

    def forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tensor]]:
        """
        Run the network.

        Returns:
            The final output tensor in eval mode or when only one output is produced;
            otherwise the list of deep supervision outputs, final output first.

        Raises:
            ValueError: if the input spatial shape is not divisible by ``shape_factor()``.
        """
        return self._forward(x)