Source code for monai.networks.layers.simplelayers

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Sequence, Union, cast

import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Function

from monai.networks.layers.convutils import gaussian_1d, same_padding
from monai.utils import ensure_tuple_rep, optional_import

_C, _ = optional_import("monai._C")

__all__ = ["SkipConnection", "Flatten", "GaussianFilter", "LLTM", "Reshape"]


class SkipConnection(nn.Module):
    """
    Concatenates the forward pass input with the result from the given submodule.
    """

    def __init__(self, submodule, cat_dim: int = 1) -> None:
        super().__init__()
        self.submodule = submodule
        self.cat_dim = cat_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cat([x, self.submodule(x)], self.cat_dim)
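
# A minimal usage sketch (not part of the original source; the submodule choice
# and tensor shapes below are illustrative assumptions). The output channel count
# is the sum of the input channels and the submodule's output channels:
#
# >>> block = SkipConnection(nn.Conv2d(3, 8, kernel_size=1))
# >>> y = block(torch.rand(2, 3, 16, 16))
# >>> y.shape
# torch.Size([2, 11, 16, 16])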

class Flatten(nn.Module):
    """
    Flattens the given input in the forward pass to be [B,-1] in shape.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.view(x.size(0), -1)
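
# A minimal usage sketch (not part of the original source; shapes are
# illustrative assumptions). All non-batch dimensions collapse into one:
#
# >>> flat = Flatten()
# >>> flat(torch.rand(4, 2, 3, 3)).shape
# torch.Size([4, 18])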

class Reshape(nn.Module):
    """
    Reshapes input tensors to the given shape (minus batch dimension), retaining original batch size.
    """

    def __init__(self, *shape: int) -> None:
        """
        Given a shape list/tuple `shape` of integers (s0, s1, ... , sn), this layer will reshape input tensors of
        shape (batch, s0 * s1 * ... * sn) to shape (batch, s0, s1, ... , sn).

        Args:
            shape: list/tuple of integer shape dimensions
        """
        super().__init__()
        self.shape = (1,) + tuple(shape)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = list(self.shape)
        shape[0] = x.shape[0]  # done this way for Torchscript
        return x.reshape(shape)
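
# A minimal usage sketch (not part of the original source; shapes are
# illustrative assumptions). A (batch, 12) input is reshaped to (batch, 3, 4):
#
# >>> reshape = Reshape(3, 4)
# >>> reshape(torch.rand(2, 12)).shape
# torch.Size([2, 3, 4])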

class GaussianFilter(nn.Module):
    def __init__(self, spatial_dims: int, sigma: Union[Sequence[float], float], truncated: float = 4.0) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions of the input image.
                must have shape (Batch, channels, H[, W, ...]).
            sigma: standard deviation(s) of the Gaussian kernel, a single value or one per spatial dimension.
            truncated: the kernel is truncated at this many standard deviations from the mean.
        """
        super().__init__()
        self.spatial_dims = int(spatial_dims)
        _sigma = ensure_tuple_rep(sigma, self.spatial_dims)
        self.kernel = [
            torch.nn.Parameter(torch.as_tensor(gaussian_1d(s, truncated), dtype=torch.float), False) for s in _sigma
        ]
        self.padding = [cast(int, (same_padding(k.size()[0]))) for k in self.kernel]
        self.conv_n = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1]
        for idx, param in enumerate(self.kernel):
            self.register_parameter(f"kernel_{idx}", param)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: in shape [Batch, chns, H, W, D].

        Raises:
            TypeError: When ``x`` is not a ``torch.Tensor``.
        """
        if not torch.is_tensor(x):
            raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.")
        chns = x.shape[1]
        sp_dim = self.spatial_dims
        x = x.clone()  # no in-place change of x

        def _conv(input_: torch.Tensor, d: int) -> torch.Tensor:
            # recursively apply the 1-D Gaussian kernel along each spatial
            # dimension in turn (separable convolution), grouped per channel
            if d < 0:
                return input_
            s = [1] * (sp_dim + 2)
            s[d + 2] = -1
            kernel = self.kernel[d].reshape(s)
            kernel = kernel.repeat([chns, 1] + [1] * sp_dim)
            padding = [0] * sp_dim
            padding[d] = self.padding[d]
            return self.conv_n(input=_conv(input_, d - 1), weight=kernel, padding=padding, groups=chns)

        return _conv(x, sp_dim - 1)
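
# A minimal usage sketch (not part of the original source; the sigma value and
# shapes are illustrative assumptions). Same-padding keeps the spatial size of
# the smoothed output equal to the input:
#
# >>> gf = GaussianFilter(spatial_dims=2, sigma=1.5)
# >>> gf(torch.rand(1, 3, 32, 32)).shape
# torch.Size([1, 3, 32, 32])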

class LLTMFunction(Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        outputs = _C.lltm_forward(input, weights, bias, old_h, old_cell)
        new_h, new_cell = outputs[:2]
        variables = outputs[1:] + [weights]
        ctx.save_for_backward(*variables)

        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        outputs = _C.lltm_backward(grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors)
        d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs[:5]

        return d_input, d_weights, d_bias, d_old_h, d_old_cell

class LLTM(nn.Module):
    """
    This recurrent unit is similar to an LSTM, but differs in that it lacks a forget
    gate and uses an Exponential Linear Unit (ELU) as its internal activation function.
    Because this unit never forgets, call it LLTM, or Long-Long-Term-Memory unit.
    It has both C++ and CUDA implementations, switching automatically according to
    the device the module is placed on.

    Args:
        input_features: size of input feature data
        state_size: size of the state of recurrent unit

    Referring to: https://pytorch.org/tutorials/advanced/cpp_extension.html
    """

    def __init__(self, input_features: int, state_size: int):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        self.weights = nn.Parameter(torch.empty(3 * state_size, input_features + state_size))
        self.bias = nn.Parameter(torch.empty(1, 3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)

    def forward(self, input, state):
        return LLTMFunction.apply(input, self.weights, self.bias, *state)
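
# A minimal usage sketch (not part of the original source; requires the compiled
# monai._C extension, and the sizes below are illustrative assumptions). `state`
# is the (hidden, cell) pair, each shaped (batch, state_size):
#
# >>> rnn = LLTM(input_features=32, state_size=16)
# >>> h = torch.zeros(8, 16)
# >>> c = torch.zeros(8, 16)
# >>> new_h, new_c = rnn(torch.rand(8, 32), (h, c))
# >>> new_h.shape, new_c.shape
# (torch.Size([8, 16]), torch.Size([8, 16]))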