# Source code for monai.networks.nets.classifier

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch.nn as nn
from monai.networks.layers.factories import Norm, Act, split_args
from monai.networks.nets.regressor import Regressor


class Classifier(Regressor):
    """
    Classification network derived from :class:`Regressor`.

    The regressor is built with an output shape of ``(classes,)``, i.e. a single-dimensional
    result with one entry per class to predict. An optional final activation (e.g. softmax or
    sigmoid) can be appended after the regressor's final layer.

    Args:
        in_shape: input shape, forwarded to ``Regressor``.
        classes: number of classes to predict; the output shape passed to ``Regressor`` is ``(classes,)``.
        channels: forwarded to ``Regressor``.
        strides: forwarded to ``Regressor``.
        kernel_size: forwarded to ``Regressor``. Defaults to 3.
        num_res_units: forwarded to ``Regressor``. Defaults to 2.
        act: activation type forwarded to ``Regressor``. Defaults to ``Act.PRELU``.
        norm: normalization type forwarded to ``Regressor``. Defaults to ``Norm.INSTANCE``.
        dropout: forwarded to ``Regressor``. Defaults to ``None``.
        bias: forwarded to ``Regressor``. Defaults to ``True``.
        last_act: optional final activation. It is split by ``split_args`` into a name and
            keyword arguments, looked up in the ``Act`` factory, instantiated with those keyword
            arguments, and appended to ``self.final`` as a submodule named ``"lastact"``.
            ``None`` (the default) appends nothing.
    """

    def __init__(
        self,
        in_shape,
        classes,
        channels,
        strides,
        kernel_size=3,
        num_res_units=2,
        act=Act.PRELU,
        norm=Norm.INSTANCE,
        dropout=None,
        bias=True,
        last_act=None,
    ):
        # A classifier is a regressor whose output holds one value per class.
        out_shape = (classes,)
        super().__init__(
            in_shape, out_shape, channels, strides, kernel_size, num_res_units, act, norm, dropout, bias
        )

        if last_act is None:
            return

        # ``split_args`` separates the activation name from its constructor keyword arguments.
        act_name, act_kwargs = split_args(last_act)
        act_cls = Act[act_name]
        self.final.add_module("lastact", act_cls(**act_kwargs))


class Discriminator(Classifier):
    """
    Single-output discriminator network derived from :class:`Classifier`.

    Meant for GANs or any other application needing a generic discriminator network. The number
    of classes is fixed to 1. By default a sigmoid is applied as the final activation and a
    dropout of 0.25 is used.

    Args:
        in_shape: forwarded to ``Classifier``.
        channels: forwarded to ``Classifier``.
        strides: forwarded to ``Classifier``.
        kernel_size: forwarded to ``Classifier``. Defaults to 3.
        num_res_units: forwarded to ``Classifier``. Defaults to 2.
        act: forwarded to ``Classifier``. Defaults to ``Act.PRELU``.
        norm: forwarded to ``Classifier``. Defaults to ``Norm.INSTANCE``.
        dropout: forwarded to ``Classifier``. Defaults to 0.25.
        bias: forwarded to ``Classifier``. Defaults to ``True``.
        last_act: final activation appended by ``Classifier``. Defaults to ``Act.SIGMOID``.
    """

    def __init__(
        self,
        in_shape,
        channels,
        strides,
        kernel_size=3,
        num_res_units=2,
        act=Act.PRELU,
        norm=Norm.INSTANCE,
        dropout=0.25,
        bias=True,
        last_act=Act.SIGMOID,
    ):
        super().__init__(
            in_shape=in_shape,
            classes=1,  # a discriminator emits a single output value
            channels=channels,
            strides=strides,
            kernel_size=kernel_size,
            num_res_units=num_res_units,
            act=act,
            norm=norm,
            dropout=dropout,
            bias=bias,
            last_act=last_act,
        )


class Critic(Classifier):
    """
    Single-output critic network derived from :class:`Classifier`, with no final activation.

    The final layer is ``nn.Flatten`` instead of ``nn.Linear``. The result is the mean of the
    flattened features over dimension 1 (the first dimension after the leading one). This is
    meant to be used with Wasserstein GANs.

    Args:
        in_shape: forwarded to ``Classifier``.
        channels: forwarded to ``Classifier``.
        strides: forwarded to ``Classifier``.
        kernel_size: forwarded to ``Classifier``. Defaults to 3.
        num_res_units: forwarded to ``Classifier``. Defaults to 2.
        act: forwarded to ``Classifier``. Defaults to ``Act.PRELU``.
        norm: forwarded to ``Classifier``. Defaults to ``Norm.INSTANCE``.
        dropout: forwarded to ``Classifier``. Defaults to 0.25.
        bias: forwarded to ``Classifier``. Defaults to ``True``.
    """

    def __init__(
        self,
        in_shape,
        channels,
        strides,
        kernel_size=3,
        num_res_units=2,
        act=Act.PRELU,
        norm=Norm.INSTANCE,
        dropout=0.25,
        bias=True,
    ):
        super().__init__(
            in_shape=in_shape,
            classes=1,
            channels=channels,
            strides=strides,
            kernel_size=kernel_size,
            num_res_units=num_res_units,
            act=act,
            norm=norm,
            dropout=dropout,
            bias=bias,
            last_act=None,  # a critic's output is left unbounded: no final activation
        )

    def _get_final_layer(self, in_shape):
        # Override hook, presumably called by ``Regressor`` to build ``self.final``;
        # ``in_shape`` is ignored because a flatten needs no shape information.
        return nn.Flatten()

    def forward(self, x):
        """
        Run ``self.net`` and then ``self.final`` (an ``nn.Flatten``) on ``x``.

        The flattened features are averaged over dimension 1. The result is then reshaped so its
        leading dimension is kept and everything else is collapsed into a single trailing
        dimension (presumably giving ``(batch, 1)``; confirm against callers).
        """
        flat = self.final(self.net(x))
        averaged = flat.mean(1)
        return averaged.view((averaged.shape[0], -1))