Source code for monai.networks.blocks.downsample

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn

from monai.networks.layers.factories import Pool


[docs]class MaxAvgPool(nn.Module): """ Downsample with both maxpooling and avgpooling, double the channel size by concatenating the downsampled feature maps. """ def __init__(self, spatial_dims: int, kernel_size, stride=None, padding=0, ceil_mode: bool = False): """ Args: spatial_dims: number of spatial dimensions of the input image. kernel_size: the kernel size of both pooling operations. stride: the stride of the window. Default value is `kernel_size`. padding: implicit zero padding to be added to both pooling operations. ceil_mode: when True, will use ceil instead of floor to compute the output shape. """ super().__init__() _params = {"kernel_size": kernel_size, "stride": stride, "padding": padding, "ceil_mode": ceil_mode} self.max_pool = Pool[Pool.MAX, spatial_dims](**_params) self.avg_pool = Pool[Pool.AVG, spatial_dims](**_params)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...]). Returns: Tensor in shape (batch, 2*channel, spatial_1[, spatial_2, ...]). """ x_d = torch.cat([self.max_pool(x), self.avg_pool(x)], dim=1) return x_d