# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from collections.abc import Sequence

import torch
import torch.nn as nn

from monai.networks.layers.factories import Act, Norm, split_args
from monai.networks.nets.regressor import Regressor

__all__ = ["Classifier", "Discriminator", "Critic"]


class Classifier(Regressor):
"""
    Defines a classification network from Regressor by specifying the output shape as a one-dimensional tensor
    with size equal to the number of classes to predict. The final activation function can also be specified, e.g.
    softmax or sigmoid.

    Args:
in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension)
classes: integer stating the dimension of the final output tensor
channels: tuple of integers stating the output channels of each convolutional layer
strides: tuple of integers stating the stride (downscale factor) of each convolutional layer
kernel_size: integer or tuple of integers stating size of convolutional kernels
num_res_units: integer stating number of convolutions in residual units, 0 means no residual units
act: name or type defining activation layers
norm: name or type defining normalization layers
dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout
bias: boolean stating if convolution layers should have a bias component
last_act: name defining the last activation layer
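
    Example (a minimal usage sketch; the shapes and channel counts here are illustrative assumptions)::

        # classify a batch of single-channel 2D images (64x64) into 3 classes;
        # with no `last_act` given, the network outputs raw class logits
        net = Classifier(in_shape=(1, 64, 64), classes=3, channels=(8, 16, 32), strides=(2, 2, 2))
        logits = net(torch.rand(4, 1, 64, 64))  # shape: (4, 3)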
"""
def __init__(
self,
in_shape: Sequence[int],
classes: int,
channels: Sequence[int],
strides: Sequence[int],
kernel_size: Sequence[int] | int = 3,
num_res_units: int = 2,
act=Act.PRELU,
norm=Norm.INSTANCE,
dropout: float | None = None,
bias: bool = True,
last_act: str | None = None,
) -> None:
        super().__init__(in_shape, (classes,), channels, strides, kernel_size, num_res_units, act, norm, dropout, bias)

        if last_act is not None:
            # resolve the activation through the layer factory and append it to the final layer sequence
            last_act_name, last_act_args = split_args(last_act)
            last_act_type = Act[last_act_name]

            self.final.add_module("lastact", last_act_type(**last_act_args))


class Discriminator(Classifier):
"""
    Defines a discriminator network from Classifier with a single output value and sigmoid activation by default. This
    is meant for use with GANs or other applications requiring a generic discriminator network.

    Args:
in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension)
channels: tuple of integers stating the output channels of each convolutional layer
strides: tuple of integers stating the stride (downscale factor) of each convolutional layer
kernel_size: integer or tuple of integers stating size of convolutional kernels
num_res_units: integer stating number of convolutions in residual units, 0 means no residual units
act: name or type defining activation layers
norm: name or type defining normalization layers
dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout
bias: boolean stating if convolution layers should have a bias component
last_act: name defining the last activation layer
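
    Example (a minimal usage sketch; the shapes and channel counts here are illustrative assumptions)::

        # score a batch of single-channel 2D images as real/fake; the default
        # sigmoid final activation maps each score into (0, 1)
        disc = Discriminator(in_shape=(1, 64, 64), channels=(8, 16, 32), strides=(2, 2, 2))
        scores = disc(torch.rand(4, 1, 64, 64))  # shape: (4, 1)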
"""
def __init__(
self,
in_shape: Sequence[int],
channels: Sequence[int],
strides: Sequence[int],
kernel_size: Sequence[int] | int = 3,
num_res_units: int = 2,
act=Act.PRELU,
norm=Norm.INSTANCE,
dropout: float | None = 0.25,
bias: bool = True,
last_act=Act.SIGMOID,
) -> None:
super().__init__(in_shape, 1, channels, strides, kernel_size, num_res_units, act, norm, dropout, bias, last_act)


class Critic(Classifier):
"""
    Defines a critic network from Classifier with a single output value and no final activation. The final layer is
    `nn.Flatten` instead of `nn.Linear`, and the result is computed as the mean over the flattened features. This is
    meant to be used with Wasserstein GANs.

    Args:
in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension)
channels: tuple of integers stating the output channels of each convolutional layer
strides: tuple of integers stating the stride (downscale factor) of each convolutional layer
kernel_size: integer or tuple of integers stating size of convolutional kernels
num_res_units: integer stating number of convolutions in residual units, 0 means no residual units
act: name or type defining activation layers
norm: name or type defining normalization layers
dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout
bias: boolean stating if convolution layers should have a bias component
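
    Example (a minimal usage sketch; the shapes and channel counts here are illustrative assumptions)::

        # assign an unbounded realness score to each image in the batch, as
        # used for the critic in Wasserstein GAN training
        critic = Critic(in_shape=(1, 64, 64), channels=(8, 16, 32), strides=(2, 2, 2))
        scores = critic(torch.rand(4, 1, 64, 64))  # shape: (4, 1)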
"""
def __init__(
self,
in_shape: Sequence[int],
channels: Sequence[int],
strides: Sequence[int],
kernel_size: Sequence[int] | int = 3,
num_res_units: int = 2,
act=Act.PRELU,
norm=Norm.INSTANCE,
dropout: float | None = 0.25,
bias: bool = True,
) -> None:
super().__init__(in_shape, 1, channels, strides, kernel_size, num_res_units, act, norm, dropout, bias, None)

    def _get_final_layer(self, in_shape: Sequence[int]):
        # override Regressor's final Linear layer with a flatten only; the
        # scalar output is produced by averaging in `forward` instead
        return nn.Flatten()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.net(x)
        x = self.final(x)
        # average the flattened features to a single critic value per item
        x = x.mean(1)

        return x.view((x.shape[0], -1))
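

if __name__ == "__main__":
    # Minimal smoke-test sketch: exercise each network on random 2D inputs.
    # The shapes and channel counts below are illustrative assumptions.
    x = torch.rand(4, 1, 64, 64)

    clf = Classifier(in_shape=(1, 64, 64), classes=3, channels=(8, 16, 32), strides=(2, 2, 2))
    disc = Discriminator(in_shape=(1, 64, 64), channels=(8, 16, 32), strides=(2, 2, 2))
    crit = Critic(in_shape=(1, 64, 64), channels=(8, 16, 32), strides=(2, 2, 2))

    print(clf(x).shape)   # torch.Size([4, 3]) - raw class logits
    print(disc(x).shape)  # torch.Size([4, 1]) - sigmoid probabilities
    print(crit(x).shape)  # torch.Size([4, 1]) - unbounded critic scores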