# Source code for monai.networks.blocks.downsample

from __future__ import annotations

from collections.abc import Sequence

import torch
import torch.nn as nn

from monai.networks.layers.factories import Pool
from monai.utils import ensure_tuple_rep

class MaxAvgPool(nn.Module):
    """
    Downsample with both maxpooling and avgpooling,
    double the channel size by concatenating the downsampled feature maps.
    """

    def __init__(
        self,
        spatial_dims: int,
        kernel_size: Sequence[int] | int,
        stride: Sequence[int] | int | None = None,
        padding: Sequence[int] | int = 0,
        ceil_mode: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            kernel_size: the kernel size of both pooling operations.
            stride: the stride of the window. Default value is `kernel_size`.
            padding: implicit zero padding to be added to both pooling operations.
            ceil_mode: when True, will use ceil instead of floor to compute the output shape.
        """
        super().__init__()
        # Expand scalar arguments to one value per spatial dimension; a ``None``
        # stride is passed through unchanged so the pooling layer applies its own default.
        _params = {
            "kernel_size": ensure_tuple_rep(kernel_size, spatial_dims),
            "stride": None if stride is None else ensure_tuple_rep(stride, spatial_dims),
            "padding": ensure_tuple_rep(padding, spatial_dims),
            "ceil_mode": ceil_mode,
        }
        # Both pools are built from the same parameters so that their outputs
        # are expected to share spatial shape, which channel-wise concatenation requires.
        self.max_pool = Pool[Pool.MAX, spatial_dims](**_params)
        self.avg_pool = Pool[Pool.AVG, spatial_dims](**_params)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...]).

        Returns:
            Tensor in shape (batch, 2*channel, spatial_1[, spatial_2, ...]).
        """
        # Max-pooled channels come first, followed by the avg-pooled channels.
        return torch.cat([self.max_pool(x), self.avg_pool(x)], dim=1)