
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Sequence, Union, cast

import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Function

from monai.networks.layers.convutils import gaussian_1d, same_padding
from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, SkipMode, ensure_tuple_rep, optional_import

_C, _ = optional_import("monai._C")
if not PT_BEFORE_1_7:
    fft, _ = optional_import("torch.fft")

__all__ = ["SkipConnection", "Flatten", "GaussianFilter", "LLTM", "Reshape", "separable_filtering", "HilbertTransform"]


class SkipConnection(nn.Module):
    """
    Combine the forward pass input with the result from the given submodule::

        --+--submodule--o--
          |_____________|

    The available modes are ``"cat"``, ``"add"``, ``"mul"``.
    """

    def __init__(self, submodule, dim: int = 1, mode: Union[str, SkipMode] = "cat") -> None:
        """
        Args:
            submodule: the module that defines the trainable branch.
            dim: the dimension over which the tensors are concatenated.
                Used when mode is ``"cat"``.
            mode: ``"cat"``, ``"add"``, ``"mul"``. Defaults to ``"cat"``.
        """
        super().__init__()
        self.submodule = submodule
        self.dim = dim
        self.mode = SkipMode(mode).value
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.submodule(x)

        if self.mode == "cat":
            return torch.cat([x, y], dim=self.dim)
        if self.mode == "add":
            return torch.add(x, y)
        if self.mode == "mul":
            return torch.mul(x, y)
        raise NotImplementedError(f"Unsupported mode {self.mode}.")
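
# Editor's usage sketch, not part of the original module; the submodule and
# tensor shapes below are illustrative assumptions.
def _skip_connection_example() -> None:
    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    skip = SkipConnection(conv, dim=1, mode="cat")
    out = skip(torch.randn(2, 3, 16, 16))
    # "cat" concatenates the input (3 chns) with the submodule output (8 chns) along dim 1
    assert out.shape == (2, 11, 16, 16)
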
class Flatten(nn.Module):
    """
    Flattens the given input in the forward pass to be [B,-1] in shape.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.view(x.size(0), -1)
class Reshape(nn.Module):
    """
    Reshapes input tensors to the given shape (minus batch dimension), retaining original batch size.
    """

    def __init__(self, *shape: int) -> None:
        """
        Given a shape list/tuple `shape` of integers (s0, s1, ..., sn), this layer will reshape input tensors of
        shape (batch, s0 * s1 * ... * sn) to shape (batch, s0, s1, ..., sn).

        Args:
            shape: list/tuple of integer shape dimensions
        """
        super().__init__()
        self.shape = (1,) + tuple(shape)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = list(self.shape)
        shape[0] = x.shape[0]  # done this way for Torchscript
        return x.reshape(shape)


def separable_filtering(x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Tensor:
    """
    Apply 1-D convolutions along each spatial dimension of `x`.

    Args:
        x: the input image. must have shape (batch, channels, H[, W, ...]).
        kernels: kernel along each spatial dimension.
            could be a single kernel (duplicated for all spatial dimensions), or
            `spatial_dims` number of kernels.

    Raises:
        TypeError: When ``x`` is not a ``torch.Tensor``.
    """
    if not torch.is_tensor(x):
        raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.")

    spatial_dims = len(x.shape) - 2
    _kernels = [
        torch.as_tensor(s, dtype=torch.float, device=s.device if torch.is_tensor(s) else None)
        for s in ensure_tuple_rep(kernels, spatial_dims)
    ]
    _paddings = [cast(int, same_padding(k.shape[0])) for k in _kernels]
    n_chns = x.shape[1]

    def _conv(input_: torch.Tensor, d: int) -> torch.Tensor:
        if d < 0:
            return input_
        s = [1] * len(input_.shape)
        s[d + 2] = -1
        # index the normalized ``_kernels`` (not the raw ``kernels`` argument),
        # so that a single shared kernel is handled correctly
        _kernel = _kernels[d].reshape(s)
        _kernel = _kernel.repeat([n_chns, 1] + [1] * spatial_dims)
        _padding = [0] * spatial_dims
        _padding[d] = _paddings[d]
        conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1]
        return conv_type(input=_conv(input_, d - 1), weight=_kernel, padding=_padding, groups=n_chns)

    return _conv(x, spatial_dims - 1)
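
# Editor's usage sketch, not part of the original module: smooth a 2-D batch
# with a 3-tap box kernel along each spatial dimension.
def _separable_filtering_example() -> None:
    kernel = torch.ones(3) / 3.0
    img = torch.rand(1, 1, 8, 8)
    smoothed = separable_filtering(img, kernels=[kernel, kernel])  # one kernel per spatial dim
    assert smoothed.shape == img.shape  # "same" padding preserves spatial size
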
class HilbertTransform(nn.Module):
    """
    Determine the analytical signal of a Tensor along a particular axis.
    Requires PyTorch 1.7.0+ and the PyTorch FFT module (which is not included in
    NVIDIA PyTorch Release 20.10).

    Args:
        axis: Axis along which to apply Hilbert transform. Default 2 (first spatial dimension).
        n: Number of Fourier components (i.e. FFT size). Default: ``x.shape[axis]``.
    """

    def __init__(self, axis: int = 2, n: Union[int, None] = None) -> None:

        if PT_BEFORE_1_7:
            raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__)

        super().__init__()
        self.axis = axis
        self.n = n
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor or array-like to transform. Must be real and in shape ``[Batch, chns, spatial1, spatial2, ...]``.

        Returns:
            torch.Tensor: Analytical signal of ``x``, transformed along the axis specified in ``self.axis`` using
            an FFT of size ``self.n``. The absolute value of the result relates to the envelope of ``x``
            along axis ``self.axis``.
        """

        # Make input a real tensor
        x = torch.as_tensor(x, device=x.device if torch.is_tensor(x) else None)
        if torch.is_complex(x):
            raise ValueError("x must be real.")
        x = x.to(dtype=torch.float)

        if (self.axis < 0) or (self.axis > len(x.shape) - 1):
            raise ValueError("Invalid axis for shape of x.")

        n = x.shape[self.axis] if self.n is None else self.n
        if n <= 0:
            raise ValueError("n must be positive.")
        x = torch.as_tensor(x, dtype=torch.complex64)

        # Create frequency axis
        f = torch.cat(
            [
                torch.true_divide(torch.arange(0, (n - 1) // 2 + 1, device=x.device), float(n)),
                torch.true_divide(torch.arange(-(n // 2), 0, device=x.device), float(n)),
            ]
        )
        xf = fft.fft(x, n=n, dim=self.axis)

        # Create step function
        u = torch.heaviside(f, torch.tensor([0.5], device=f.device))
        u = torch.as_tensor(u, dtype=x.dtype, device=u.device)
        new_dims_before = self.axis
        new_dims_after = len(xf.shape) - self.axis - 1
        for _ in range(new_dims_before):
            u.unsqueeze_(0)
        for _ in range(new_dims_after):
            u.unsqueeze_(-1)

        # Apply transform
        ht = fft.ifft(xf * 2 * u, dim=self.axis)

        return torch.as_tensor(ht, device=ht.device, dtype=ht.dtype)
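
# Editor's usage sketch, not part of the original module (requires PyTorch 1.7.0+):
# the magnitude of the analytical signal recovers the envelope of an
# amplitude-modulated tone. The signal parameters below are illustrative.
def _hilbert_transform_example() -> None:
    t = torch.linspace(0, 1, 201)
    envelope = torch.exp(-((t - 0.5) ** 2) / 0.01)
    signal = (envelope * torch.sin(2 * math.pi * 10 * t))[None, None, :]  # [Batch, chns, spatial]
    analytic = HilbertTransform(axis=2)(signal)
    recovered = analytic.abs()[0, 0]  # approximately equal to ``envelope``
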
class GaussianFilter(nn.Module):
    def __init__(
        self,
        spatial_dims: int,
        sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor],
        truncated: float = 4.0,
        approx: str = "erf",
        requires_grad: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions of the input image.
                must have shape (Batch, channels, H[, W, ...]).
            sigma: std. could be a single value, or `spatial_dims` number of values.
            truncated: how many standard deviations the kernel extends.
            approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace".

                - ``erf`` approximation interpolates the error function;
                - ``sampled`` uses a sampled Gaussian kernel;
                - ``scalespace`` corresponds to
                  https://en.wikipedia.org/wiki/Scale_space_implementation#The_discrete_Gaussian_kernel
                  based on the modified Bessel functions.

            requires_grad: whether to store the gradients for sigma.
                if True, `sigma` will be the initial value of the parameters of this module
                (for example the `parameters()` iterator could be used to get the parameters);
                otherwise this module will fix the kernels using `sigma` as the std.
        """
        super().__init__()
        self.sigma = [
            torch.nn.Parameter(
                torch.as_tensor(s, dtype=torch.float, device=s.device if torch.is_tensor(s) else None),
                requires_grad=requires_grad,
            )
            for s in ensure_tuple_rep(sigma, int(spatial_dims))
        ]
        self.truncated = truncated
        self.approx = approx
        for idx, param in enumerate(self.sigma):
            self.register_parameter(f"kernel_sigma_{idx}", param)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: in shape [Batch, chns, H, W, D].
        """
        _kernel = [gaussian_1d(s, truncated=self.truncated, approx=self.approx) for s in self.sigma]
        return separable_filtering(x=x, kernels=_kernel)
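
# Editor's usage sketch, not part of the original module: isotropic smoothing of
# a 3-D volume with a fixed (non-trainable) sigma of 1.0 per spatial dimension.
def _gaussian_filter_example() -> None:
    blur = GaussianFilter(spatial_dims=3, sigma=1.0)
    vol = torch.rand(1, 1, 16, 16, 16)
    out = blur(vol)
    assert out.shape == vol.shape  # separable filtering preserves the input shape
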
class LLTMFunction(Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        outputs = _C.lltm_forward(input, weights, bias, old_h, old_cell)
        new_h, new_cell = outputs[:2]
        variables = outputs[1:] + [weights]
        ctx.save_for_backward(*variables)

        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        outputs = _C.lltm_backward(grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors)
        d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs[:5]

        return d_input, d_weights, d_bias, d_old_h, d_old_cell
class LLTM(nn.Module):
    """
    This recurrent unit is similar to an LSTM, but differs in that it lacks a forget
    gate and uses an Exponential Linear Unit (ELU) as its internal activation function.
    Because this unit never forgets, call it LLTM, or Long-Long-Term-Memory unit.
    It has both C++ and CUDA implementations; the appropriate one is selected
    automatically according to the device this module is placed on.

    Args:
        input_features: size of input feature data
        state_size: size of the state of recurrent unit

    Referring to: https://pytorch.org/tutorials/advanced/cpp_extension.html
    """

    def __init__(self, input_features: int, state_size: int):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        self.weights = nn.Parameter(torch.empty(3 * state_size, input_features + state_size))
        self.bias = nn.Parameter(torch.empty(1, 3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)
    def forward(self, input, state):
        return LLTMFunction.apply(input, self.weights, self.bias, *state)
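
# Editor's usage sketch, not part of the original module. Runs only where the
# compiled ``monai._C`` extension is available; the sizes below are illustrative.
def _lltm_example() -> None:
    batch, features, state_size = 4, 32, 128
    rnn = LLTM(input_features=features, state_size=state_size)
    x = torch.randn(batch, features)
    h = torch.zeros(batch, state_size)
    c = torch.zeros(batch, state_size)
    new_h, new_c = rnn(x, (h, c))  # one recurrent step
    assert new_h.shape == (batch, state_size)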