Source code for monai.networks.nets.dints

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import datetime
import warnings
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks.dints_block import (
    ActiConvNormBlock,
    FactorizedIncreaseBlock,
    FactorizedReduceBlock,
    P3DActiConvNormBlock,
)
from monai.networks.layers.factories import Conv
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.utils import optional_import

# solving shortest path problem
csr_matrix, _ = optional_import("scipy.sparse", name="csr_matrix")
dijkstra, _ = optional_import("scipy.sparse.csgraph", name="dijkstra")

__all__ = ["DiNTS", "TopologyConstruction", "TopologyInstance", "TopologySearch"]


@torch.jit.interface
class CellInterface(torch.nn.Module):
    """interface for torchscriptable Cell"""

    def forward(self, x: torch.Tensor, weight: Optional[torch.Tensor]) -> torch.Tensor:  # type: ignore
        pass


@torch.jit.interface
class StemInterface(torch.nn.Module):
    """interface for torchscriptable Stem"""

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        pass


class StemTS(StemInterface):
    """wrapper for torchscriptable Stem"""

    def __init__(self, *mod):
        super().__init__()
        self.mod = torch.nn.Sequential(*mod)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mod(x)  # type: ignore


def _dfs(node, paths):
    """use depth first search to find all path activation combination"""
    if node == paths:
        return [[0], [1]]
    child = _dfs(node + 1, paths)
    return [[0] + _ for _ in child] + [[1] + _ for _ in child]


class _IdentityWithRAMCost(nn.Identity):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ram_cost = 0


class _ActiConvNormBlockWithRAMCost(ActiConvNormBlock):
    """The class wraps monai layers with ram estimation. The ram_cost = total_ram/output_size is estimated.
    Here is the estimation:
     feature_size = output_size/out_channel
     total_ram = ram_cost * output_size
     total_ram = in_channel * feature_size (activation map) +
                 in_channel * feature_size (convolution map) +
                 out_channel * feature_size (normalization)
               = (2*in_channel + out_channel) * output_size/out_channel
     ram_cost = total_ram/output_size = 2 * in_channel/out_channel + 1
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        kernel_size: int,
        padding: int,
        spatial_dims: int = 3,
        act_name: tuple | str = "RELU",
        norm_name: tuple | str = ("INSTANCE", {"affine": True}),
    ):
        super().__init__(in_channel, out_channel, kernel_size, padding, spatial_dims, act_name, norm_name)
        self.ram_cost = 1 + in_channel / out_channel * 2


class _P3DActiConvNormBlockWithRAMCost(P3DActiConvNormBlock):
    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        kernel_size: int,
        padding: int,
        p3dmode: int = 0,
        act_name: tuple | str = "RELU",
        norm_name: tuple | str = ("INSTANCE", {"affine": True}),
    ):
        super().__init__(in_channel, out_channel, kernel_size, padding, p3dmode, act_name, norm_name)
        # 1 in_channel (activation) + 1 in_channel (convolution) +
        # 1 out_channel (convolution) + 1 out_channel (normalization)
        self.ram_cost = 2 + 2 * in_channel / out_channel


class _FactorizedIncreaseBlockWithRAMCost(FactorizedIncreaseBlock):
    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        spatial_dims: int = 3,
        act_name: tuple | str = "RELU",
        norm_name: tuple | str = ("INSTANCE", {"affine": True}),
    ):
        super().__init__(in_channel, out_channel, spatial_dims, act_name, norm_name)
        # s0 is upsampled 2x from s1, representing feature sizes at two resolutions.
        # 2 * in_channel * s0 (upsample + activation) + 2 * out_channel * s0 (conv + normalization)
        # s0 = output_size/out_channel
        self.ram_cost = 2 * in_channel / out_channel + 2


class _FactorizedReduceBlockWithRAMCost(FactorizedReduceBlock):
    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        spatial_dims: int = 3,
        act_name: tuple | str = "RELU",
        norm_name: tuple | str = ("INSTANCE", {"affine": True}),
    ):
        super().__init__(in_channel, out_channel, spatial_dims, act_name, norm_name)
        # s0 is upsampled 2x from s1, representing feature sizes at two resolutions.
        # in_channel * s0 (activation) + 3 * out_channel * s1 (convolution, concatenation, normalization)
        # s0 = s1 * 2^(spatial_dims) = output_size / out_channel * 2^(spatial_dims)
        self.ram_cost = in_channel / out_channel * 2**self._spatial_dims + 3


class MixedOp(nn.Module):
    """
    The weighted averaging of cell operations.
    Args:
        c: number of output channels.
        ops: a dictionary of operations. See also: ``Cell.OPS2D`` or ``Cell.OPS3D``.
        arch_code_c: binary cell operation code. It represents the operation results added to the output.
    """

    def __init__(self, c: int, ops: dict, arch_code_c=None):
        super().__init__()
        if arch_code_c is None:
            arch_code_c = np.ones(len(ops))
        self.ops = nn.ModuleList()
        for arch_c, op_name in zip(arch_code_c, ops):
            if arch_c > 0:
                self.ops.append(ops[op_name](c))

    def forward(self, x: torch.Tensor, weight: Optional[torch.Tensor] = None):
        """
        Args:
            x: input tensor.
            weight: learnable architecture weights for cell operations. arch_code_c are derived from it.
        Return:
            out: weighted average of the operation results.
        """
        out = 0.0
        if weight is not None:
            weight = weight.to(x)
        for idx, _op in enumerate(self.ops):
            out = (out + _op(x)) if weight is None else out + _op(x) * weight[idx]
        return out


class Cell(CellInterface):
    """
    The basic class for cell operation search, which contains a preprocessing operation and a mixed cell operation.
    Each cell is defined on a `path` in the topology search space.
    Args:
        c_prev: number of input channels
        c: number of output channels
        rate: resolution change rate. It represents the preprocessing operation before the mixed cell operation.
            ``-1`` for 2x downsample, ``1`` for 2x upsample, ``0`` for no change of resolution.
        arch_code_c: cell operation code
    """

    DIRECTIONS = 3
    # Possible output paths for `Cell`.
    #
    #       - UpSample
    #      /
    # +--+/
    # |  |--- Identity or AlignChannels
    # +--+\
    #      \
    #       - Downsample

    # Define 2D operation set, parameterized by the number of channels
    OPS2D = {
        "skip_connect": lambda _c: _IdentityWithRAMCost(),
        "conv_3x3": lambda c: _ActiConvNormBlockWithRAMCost(c, c, 3, padding=1, spatial_dims=2),
    }

    # Define 3D operation set, parameterized by the number of channels
    OPS3D = {
        "skip_connect": lambda _c: _IdentityWithRAMCost(),
        "conv_3x3x3": lambda c: _ActiConvNormBlockWithRAMCost(c, c, 3, padding=1, spatial_dims=3),
        "conv_3x3x1": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=0),
        "conv_3x1x3": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=1),
        "conv_1x3x3": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=2),
    }

    # Define connection operation set, parameterized by the number of channels
    ConnOPS = {
        "up": _FactorizedIncreaseBlockWithRAMCost,
        "down": _FactorizedReduceBlockWithRAMCost,
        "identity": _IdentityWithRAMCost,
        "align_channels": _ActiConvNormBlockWithRAMCost,
    }

    def __init__(
        self,
        c_prev: int,
        c: int,
        rate: int,
        arch_code_c=None,
        spatial_dims: int = 3,
        act_name: tuple | str = "RELU",
        norm_name: tuple | str = ("INSTANCE", {"affine": True}),
    ):
        super().__init__()
        self._spatial_dims = spatial_dims
        self._act_name = act_name
        self._norm_name = norm_name

        if rate == -1:  # downsample
            self.preprocess = self.ConnOPS["down"](
                c_prev, c, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name
            )
        elif rate == 1:  # upsample
            self.preprocess = self.ConnOPS["up"](
                c_prev, c, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name
            )
        else:
            if c_prev == c:
                self.preprocess = self.ConnOPS["identity"]()
            else:
                self.preprocess = self.ConnOPS["align_channels"](
                    c_prev, c, 1, 0, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name
                )

        # Define 2D operation set, parameterized by the number of channels
        self.OPS2D = {
            "skip_connect": lambda _c: _IdentityWithRAMCost(),
            "conv_3x3": lambda c: _ActiConvNormBlockWithRAMCost(
                c, c, 3, padding=1, spatial_dims=2, act_name=self._act_name, norm_name=self._norm_name
            ),
        }

        # Define 3D operation set, parameterized by the number of channels
        self.OPS3D = {
            "skip_connect": lambda _c: _IdentityWithRAMCost(),
            "conv_3x3x3": lambda c: _ActiConvNormBlockWithRAMCost(
                c, c, 3, padding=1, spatial_dims=3, act_name=self._act_name, norm_name=self._norm_name
            ),
            "conv_3x3x1": lambda c: _P3DActiConvNormBlockWithRAMCost(
                c, c, 3, padding=1, p3dmode=0, act_name=self._act_name, norm_name=self._norm_name
            ),
            "conv_3x1x3": lambda c: _P3DActiConvNormBlockWithRAMCost(
                c, c, 3, padding=1, p3dmode=1, act_name=self._act_name, norm_name=self._norm_name
            ),
            "conv_1x3x3": lambda c: _P3DActiConvNormBlockWithRAMCost(
                c, c, 3, padding=1, p3dmode=2, act_name=self._act_name, norm_name=self._norm_name
            ),
        }

        self.OPS = {}
        if self._spatial_dims == 2:
            self.OPS = self.OPS2D
        elif self._spatial_dims == 3:
            self.OPS = self.OPS3D
        else:
            raise NotImplementedError(f"Spatial dimensions {self._spatial_dims} is not supported.")

        self.op = MixedOp(c, self.OPS, arch_code_c)

    def forward(self, x: torch.Tensor, weight: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            x: input tensor
            weight: weights for different operations.
        """
        x = self.preprocess(x)
        x = self.op(x, weight)
        return x


[docs]class DiNTS(nn.Module): """ Reimplementation of DiNTS based on "DiNTS: Differentiable Neural Network Topology Search for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.15954>". The model contains a pre-defined multi-resolution stem block (defined in this class) and a DiNTS space (defined in :py:class:`monai.networks.nets.TopologyInstance` and :py:class:`monai.networks.nets.TopologySearch`). The stem block is for: 1) input downsample and 2) output upsample to original size. The model downsamples the input image by 2 (if ``use_downsample=True``). The downsampled image is downsampled by [1, 2, 4, 8] times (``num_depths=4``) and used as input to the DiNTS search space (``TopologySearch``) or the DiNTS instance (``TopologyInstance``). - ``TopologyInstance`` is the final searched model. The initialization requires the searched architecture codes. - ``TopologySearch`` is a multi-path topology and cell operation search space. The architecture codes will be initialized as one. - ``TopologyConstruction`` is the parent class which constructs the instance and search space. To meet the requirements of the structure, the input size for each spatial dimension should be: divisible by 2 ** (num_depths + 1). Args: dints_space: DiNTS search space. The value should be instance of `TopologyInstance` or `TopologySearch`. in_channels: number of input image channels. num_classes: number of output segmentation classes. act_name: activation name, default to 'RELU'. norm_name: normalization used in convolution blocks. Default to `InstanceNorm`. spatial_dims: spatial 2D or 3D inputs. use_downsample: use downsample in the stem. If ``False``, the search space will be in resolution [1, 1/2, 1/4, 1/8], if ``True``, the search space will be in resolution [1/2, 1/4, 1/8, 1/16]. node_a: node activation numpy matrix. Its shape is `(num_depths, num_blocks + 1)`. +1 for multi-resolution inputs. In model searching stage, ``node_a`` can be None. In deployment stage, ``node_a`` cannot be None. """ def __init__( self, dints_space, in_channels: int, num_classes: int, act_name: tuple | str = "RELU", norm_name: tuple | str = ("INSTANCE", {"affine": True}), spatial_dims: int = 3, use_downsample: bool = True, node_a=None, ): super().__init__() self.dints_space = dints_space self.filter_nums = dints_space.filter_nums self.num_blocks = dints_space.num_blocks self.num_depths = dints_space.num_depths if spatial_dims not in (2, 3): raise NotImplementedError(f"Spatial dimensions {spatial_dims} is not supported.") self._spatial_dims = spatial_dims if node_a is None: self.node_a = torch.ones((self.num_blocks + 1, self.num_depths)) else: self.node_a = node_a # define stem operations for every block conv_type = Conv[Conv.CONV, spatial_dims] self.stem_down = nn.ModuleDict() self.stem_up = nn.ModuleDict() self.stem_finals = nn.Sequential( ActiConvNormBlock( self.filter_nums[0], self.filter_nums[0], act_name=act_name, norm_name=norm_name, spatial_dims=spatial_dims, ), conv_type( in_channels=self.filter_nums[0], out_channels=num_classes, kernel_size=1, stride=1, padding=0, groups=1, bias=True, dilation=1, ), ) mode = "trilinear" if self._spatial_dims == 3 else "bilinear" for res_idx in range(self.num_depths): # define downsample stems before DiNTS search if use_downsample: self.stem_down[str(res_idx)] = StemTS( nn.Upsample(scale_factor=1 / (2**res_idx), mode=mode, align_corners=True), conv_type( in_channels=in_channels, out_channels=self.filter_nums[res_idx], kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1, ), get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]), get_act_layer(name=act_name), conv_type( in_channels=self.filter_nums[res_idx], out_channels=self.filter_nums[res_idx + 1], kernel_size=3, stride=2, padding=1, groups=1, bias=False, dilation=1, ), get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx + 1]), ) self.stem_up[str(res_idx)] = StemTS( get_act_layer(name=act_name), conv_type( in_channels=self.filter_nums[res_idx + 1], out_channels=self.filter_nums[res_idx], kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1, ), get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]), nn.Upsample(scale_factor=2, mode=mode, align_corners=True), ) else: self.stem_down[str(res_idx)] = StemTS( nn.Upsample(scale_factor=1 / (2**res_idx), mode=mode, align_corners=True), conv_type( in_channels=in_channels, out_channels=self.filter_nums[res_idx], kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1, ), get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]), ) self.stem_up[str(res_idx)] = StemTS( get_act_layer(name=act_name), conv_type( in_channels=self.filter_nums[res_idx], out_channels=self.filter_nums[max(res_idx - 1, 0)], kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1, ), get_norm_layer( name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[max(res_idx - 1, 0)] ), nn.Upsample(scale_factor=2 ** (res_idx != 0), mode=mode, align_corners=True), ) def weight_parameters(self): return [param for name, param in self.named_parameters()]
[docs] def forward(self, x: torch.Tensor): """ Prediction based on dynamic arch_code. Args: x: input tensor. """ inputs = [] for d in range(self.num_depths): # allow multi-resolution input _mod_w: StemInterface = self.stem_down[str(d)] x_out = _mod_w.forward(x) if self.node_a[0][d]: inputs.append(x_out) else: inputs.append(torch.zeros_like(x_out)) outputs = self.dints_space(inputs) blk_idx = self.num_blocks - 1 start = False _temp: torch.Tensor = torch.empty(0) for res_idx in range(self.num_depths - 1, -1, -1): _mod_up: StemInterface = self.stem_up[str(res_idx)] if start: _temp = _mod_up.forward(outputs[res_idx] + _temp) elif self.node_a[blk_idx + 1][res_idx]: start = True _temp = _mod_up.forward(outputs[res_idx]) prediction = self.stem_finals(_temp) return prediction
[docs]class TopologyConstruction(nn.Module): """ The base class for `TopologyInstance` and `TopologySearch`. Args: arch_code: `[arch_code_a, arch_code_c]`, numpy arrays. The architecture codes defining the model. For example, for a ``num_depths=4, num_blocks=12`` search space: - `arch_code_a` is a 12x10 (10 paths) binary matrix representing if a path is activated. - `arch_code_c` is a 12x10x5 (5 operations) binary matrix representing if a cell operation is used. - `arch_code` in ``__init__()`` is used for creating the network and remove unused network blocks. If None, all paths and cells operations will be used, and must be in the searching stage (is_search=True). channel_mul: adjust intermediate channel number, default is 1. cell: operation of each node. num_blocks: number of blocks (depth in the horizontal direction) of the DiNTS search space. num_depths: number of image resolutions of the DiNTS search space: 1, 1/2, 1/4 ... in each dimension. use_downsample: use downsample in the stem. If False, the search space will be in resolution [1, 1/2, 1/4, 1/8], if True, the search space will be in resolution [1/2, 1/4, 1/8, 1/16]. device: `'cpu'`, `'cuda'`, or device ID. Predefined variables: `filter_nums`: default to 32. Double the number of channels after downsample. topology related variables: - `arch_code2in`: path activation to its incoming node index (resolution). For depth = 4, arch_code2in = [0, 1, 0, 1, 2, 1, 2, 3, 2, 3]. The first path outputs from node 0 (top resolution), the second path outputs from node 1 (second resolution in the search space), the third path outputs from node 0, etc. - `arch_code2ops`: path activation to operations of upsample 1, keep 0, downsample -1. For depth = 4, arch_code2ops = [0, 1, -1, 0, 1, -1, 0, 1, -1, 0]. The first path does not change resolution, the second path perform upsample, the third perform downsample, etc. - `arch_code2out`: path activation to its output node index. For depth = 4, arch_code2out = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3], the first and second paths connects to node 0 (top resolution), the 3,4,5 paths connects to node 1, etc. """ def __init__( self, arch_code: list | None = None, channel_mul: float = 1.0, cell=Cell, num_blocks: int = 6, num_depths: int = 3, spatial_dims: int = 3, act_name: tuple | str = "RELU", norm_name: tuple | str = ("INSTANCE", {"affine": True}), use_downsample: bool = True, device: str = "cpu", ): super().__init__() n_feats = tuple([32 * (2**_i) for _i in range(num_depths + 1)]) self.filter_nums = [int(n_feat * channel_mul) for n_feat in n_feats] self.num_blocks = num_blocks self.num_depths = num_depths print( "{} - Length of input patch is recommended to be a multiple of {:d}.".format( datetime.datetime.now(), 2 ** (num_depths + int(use_downsample)) ) ) self._spatial_dims = spatial_dims self._act_name = act_name self._norm_name = norm_name self.use_downsample = use_downsample self.device = device self.num_cell_ops = 0 if self._spatial_dims == 2: self.num_cell_ops = len(cell.OPS2D) elif self._spatial_dims == 3: self.num_cell_ops = len(cell.OPS3D) # Calculate predefined parameters for topology search and decoding arch_code2in, arch_code2out = [], [] for i in range(Cell.DIRECTIONS * self.num_depths - 2): arch_code2in.append((i + 1) // Cell.DIRECTIONS - 1 + (i + 1) % Cell.DIRECTIONS) arch_code2ops = ([-1, 0, 1] * self.num_depths)[1:-1] for m in range(self.num_depths): arch_code2out.extend([m, m, m]) arch_code2out = arch_code2out[1:-1] self.arch_code2in = arch_code2in self.arch_code2ops = arch_code2ops self.arch_code2out = arch_code2out # define NAS search space if arch_code is None: arch_code_a = torch.ones((self.num_blocks, len(self.arch_code2out))).to(self.device) arch_code_c = torch.ones((self.num_blocks, len(self.arch_code2out), self.num_cell_ops)).to(self.device) else: arch_code_a = torch.from_numpy(arch_code[0]).to(self.device) arch_code_c = F.one_hot(torch.from_numpy(arch_code[1]).to(torch.int64), self.num_cell_ops).to(self.device) self.arch_code_a = arch_code_a self.arch_code_c = arch_code_c # define cell operation on each path self.cell_tree = nn.ModuleDict() for blk_idx in range(self.num_blocks): for res_idx in range(len(self.arch_code2out)): if self.arch_code_a[blk_idx, res_idx] == 1: self.cell_tree[str((blk_idx, res_idx))] = cell( self.filter_nums[self.arch_code2in[res_idx] + int(use_downsample)], self.filter_nums[self.arch_code2out[res_idx] + int(use_downsample)], self.arch_code2ops[res_idx], self.arch_code_c[blk_idx, res_idx], self._spatial_dims, self._act_name, self._norm_name, )
[docs] def forward(self, x): """This function to be implemented by the architecture instances or search spaces.""" pass
[docs]class TopologyInstance(TopologyConstruction): """ Instance of the final searched architecture. Only used in re-training/inference stage. """
[docs] def __init__( self, arch_code=None, channel_mul: float = 1.0, cell=Cell, num_blocks: int = 6, num_depths: int = 3, spatial_dims: int = 3, act_name: tuple | str = "RELU", norm_name: tuple | str = ("INSTANCE", {"affine": True}), use_downsample: bool = True, device: str = "cpu", ): """ Initialize DiNTS topology search space of neural architectures. """ if arch_code is None: warnings.warn("arch_code not provided when not searching.") super().__init__( arch_code=arch_code, channel_mul=channel_mul, cell=cell, num_blocks=num_blocks, num_depths=num_depths, spatial_dims=spatial_dims, act_name=act_name, norm_name=norm_name, use_downsample=use_downsample, device=device, )
[docs] def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]: """ Args: x: input tensor. """ # generate path activation probability inputs = x for blk_idx in range(self.num_blocks): outputs = [torch.tensor(0.0, dtype=x[0].dtype, device=x[0].device)] * self.num_depths for res_idx, activation in enumerate(self.arch_code_a[blk_idx].data): if activation: mod: CellInterface = self.cell_tree[str((blk_idx, res_idx))] _out = mod.forward(x=inputs[self.arch_code2in[res_idx]], weight=None) outputs[self.arch_code2out[res_idx]] = outputs[self.arch_code2out[res_idx]] + _out inputs = outputs return inputs
[docs]class TopologySearch(TopologyConstruction): """ DiNTS topology search space of neural architectures. Examples: .. code-block:: python from monai.networks.nets.dints import TopologySearch topology_search_space = TopologySearch( channel_mul=0.5, num_blocks=8, num_depths=4, use_downsample=True, spatial_dims=3) topology_search_space.get_ram_cost_usage(in_size=(2, 16, 80, 80, 80), full=True) multi_res_images = [ torch.randn(2, 16, 80, 80, 80), torch.randn(2, 32, 40, 40, 40), torch.randn(2, 64, 20, 20, 20), torch.randn(2, 128, 10, 10, 10)] prediction = topology_search_space(image) for x in prediction: print(x.shape) # torch.Size([2, 16, 80, 80, 80]) # torch.Size([2, 32, 40, 40, 40]) # torch.Size([2, 64, 20, 20, 20]) # torch.Size([2, 128, 10, 10, 10]) Class method overview: - ``get_prob_a()``: convert learnable architecture weights to path activation probabilities. - ``get_ram_cost_usage()``: get estimated ram cost. - ``get_topology_entropy()``: get topology entropy loss in searching stage. - ``decode()``: get final binarized architecture code. - ``gen_mtx()``: generate variables needed for topology search. Predefined variables: - `tidx`: index used to convert path activation matrix T = (depth,depth) in transfer_mtx to path activation arch_code (1,3*depth-2), for depth = 4, tidx = [0, 1, 4, 5, 6, 9, 10, 11, 14, 15], A tidx (10 binary values) represents the path activation. - `transfer_mtx`: feasible path activation matrix (denoted as T) given a node activation pattern. It is used to convert path activation pattern (1, paths) to node activation (1, nodes) - `node_act_list`: all node activation [2^num_depths-1, depth]. For depth = 4, there are 15 node activation patterns, each of length 4. For example, [1,1,0,0] means nodes 0, 1 are activated (with input paths). - `all_connect`: All possible path activations. For depth = 4, all_connection has 1024 vectors of length 10 (10 paths). The return value will exclude path activation of all 0. """ node2out: list[list] node2in: list[list]
[docs] def __init__( self, channel_mul: float = 1.0, cell=Cell, arch_code: list | None = None, num_blocks: int = 6, num_depths: int = 3, spatial_dims: int = 3, act_name: tuple | str = "RELU", norm_name: tuple | str = ("INSTANCE", {"affine": True}), use_downsample: bool = True, device: str = "cpu", ): """ Initialize DiNTS topology search space of neural architectures. """ super().__init__( arch_code=arch_code, channel_mul=channel_mul, cell=cell, num_blocks=num_blocks, num_depths=num_depths, spatial_dims=spatial_dims, act_name=act_name, norm_name=norm_name, use_downsample=use_downsample, device=device, ) tidx = [] _d = Cell.DIRECTIONS for i in range(_d * self.num_depths - 2): tidx.append((i + 1) // _d * self.num_depths + (i + 1) // _d - 1 + (i + 1) % _d) self.tidx = tidx transfer_mtx, node_act_list, child_list = self.gen_mtx(num_depths) self.node_act_list = np.asarray(node_act_list) self.node_act_dict = {str(self.node_act_list[i]): i for i in range(len(self.node_act_list))} self.transfer_mtx = transfer_mtx self.child_list = np.asarray(child_list) self.ram_cost = np.zeros((self.num_blocks, len(self.arch_code2out), self.num_cell_ops)) for blk_idx in range(self.num_blocks): for res_idx in range(len(self.arch_code2out)): if self.arch_code_a[blk_idx, res_idx] == 1: self.ram_cost[blk_idx, res_idx] = np.array( [ op.ram_cost + self.cell_tree[str((blk_idx, res_idx))].preprocess.ram_cost for op in self.cell_tree[str((blk_idx, res_idx))].op.ops[: self.num_cell_ops] ] ) # define cell and macro architecture probabilities self.log_alpha_c = nn.Parameter( torch.zeros(self.num_blocks, len(self.arch_code2out), self.num_cell_ops) .normal_(1, 0.01) .to(self.device) .requires_grad_() ) self.log_alpha_a = nn.Parameter( torch.zeros(self.num_blocks, len(self.arch_code2out)).normal_(0, 0.01).to(self.device).requires_grad_() ) self._arch_param_names = ["log_alpha_a", "log_alpha_c"]
[docs] def gen_mtx(self, depth: int): """ Generate elements needed in decoding and topology. - `transfer_mtx`: feasible path activation matrix (denoted as T) given a node activation pattern. It is used to convert path activation pattern (1, paths) to node activation (1, nodes) - `node_act_list`: all node activation [2^num_depths-1, depth]. For depth = 4, there are 15 node activation patterns, each of length 4. For example, [1,1,0,0] means nodes 0, 1 are activated (with input paths). - `all_connect`: All possible path activations. For depth = 4, all_connection has 1024 vectors of length 10 (10 paths). The return value will exclude path activation of all 0. """ # total paths in a block, each node has three output paths, # except the two nodes at the top and the bottom scales paths = Cell.DIRECTIONS * depth - 2 # for 10 paths, all_connect has 1024 possible path activations. [1 0 0 0 0 0 0 0 0 0] means the top # path is activated. all_connect = _dfs(0, paths - 1) # Save all possible connections in mtx (might be redundant and infeasible) mtx = [] for m in all_connect: # convert path activation [1,paths] to path activation matrix [depth, depth] ma = np.zeros((depth, depth)) for i in range(paths): ma[(i + 1) // Cell.DIRECTIONS, (i + 1) // Cell.DIRECTIONS - 1 + (i + 1) % Cell.DIRECTIONS] = m[i] mtx.append(ma) # define all possible node activation node_act_list = _dfs(0, depth - 1)[1:] transfer_mtx = {} for arch_code in node_act_list: # make sure each activated node has an active connection, inactivated node has no connection arch_code_mtx = [_ for _ in mtx if ((np.sum(_, 0) > 0).astype(int) == np.array(arch_code)).all()] transfer_mtx[str(np.array(arch_code))] = arch_code_mtx return transfer_mtx, node_act_list, all_connect[1:]
def weight_parameters(self): return [param for name, param in self.named_parameters() if name not in self._arch_param_names]
[docs] def get_prob_a(self, child: bool = False): """ Get final path and child model probabilities from architecture weights `log_alpha_a`. This is used in forward pass, getting training loss, and final decoding. Args: child: return child probability (used in decoding) Return: arch_code_prob_a: the path activation probability of size: `[number of blocks, number of paths in each block]`. For 12 blocks, 4 depths search space, the size is [12,10] probs_a: The probability of all child models (size 1023x10). Each child model is a path activation pattern (1D vector of length 10 for 10 paths). In total 1023 child models (2^10 -1) """ _arch_code_prob_a = torch.sigmoid(self.log_alpha_a) # remove the case where all path are zero, and re-normalize. norm = 1 - (1 - _arch_code_prob_a).prod(-1) arch_code_prob_a = _arch_code_prob_a / norm.unsqueeze(1) if child: path_activation = torch.from_numpy(self.child_list).to(self.device) probs_a = [ ( path_activation * _arch_code_prob_a[blk_idx] + (1 - path_activation) * (1 - _arch_code_prob_a[blk_idx]) ).prod(-1) / norm[blk_idx] for blk_idx in range(self.num_blocks) ] probs_a = torch.stack(probs_a) # type: ignore return probs_a, arch_code_prob_a return None, arch_code_prob_a
[docs] def get_ram_cost_usage(self, in_size, full: bool = False): """ Get estimated output tensor size to approximate RAM consumption. Args: in_size: input image shape (4D/5D, ``[BCHW[D]]``) at the highest resolution level. full: full ram cost usage with all probability of 1. """ # convert input image size to feature map size at each level batch_size = in_size[0] image_size = np.array(in_size[-self._spatial_dims :]) sizes = [] for res_idx in range(self.num_depths): sizes.append(batch_size * self.filter_nums[res_idx] * (image_size // (2**res_idx)).prod()) sizes = torch.tensor(sizes, dtype=torch.float32, device=self.device) / (2 ** (int(self.use_downsample))) probs_a, arch_code_prob_a = self.get_prob_a(child=False) cell_prob = F.softmax(self.log_alpha_c, dim=-1) if full: arch_code_prob_a = arch_code_prob_a.detach() arch_code_prob_a.fill_(1) ram_cost = torch.from_numpy(self.ram_cost).to(dtype=torch.float32, device=self.device) usage = 0.0 for blk_idx in range(self.num_blocks): # node activation for input # cell operation for path_idx in range(len(self.arch_code2out)): usage += ( arch_code_prob_a[blk_idx, path_idx] * (1 + (ram_cost[blk_idx, path_idx] * cell_prob[blk_idx, path_idx]).sum()) * sizes[self.arch_code2out[path_idx]] ) return usage * 32 / 8 / 1024**2
[docs] def get_topology_entropy(self, probs): """ Get topology entropy loss at searching stage. Args: probs: path activation probabilities """ if hasattr(self, "node2in"): node2in = self.node2in # pylint: disable=E0203 node2out = self.node2out # pylint: disable=E0203 else: # node activation index to feasible input child_idx node2in = [[] for _ in range(len(self.node_act_list))] # node activation index to feasible output child_idx node2out = [[] for _ in range(len(self.node_act_list))] for child_idx in range(len(self.child_list)): _node_in, _node_out = np.zeros(self.num_depths), np.zeros(self.num_depths) for res_idx in range(len(self.arch_code2out)): _node_out[self.arch_code2out[res_idx]] += self.child_list[child_idx][res_idx] _node_in[self.arch_code2in[res_idx]] += self.child_list[child_idx][res_idx] _node_in = (_node_in >= 1).astype(int) _node_out = (_node_out >= 1).astype(int) node2in[self.node_act_dict[str(_node_out)]].append(child_idx) node2out[self.node_act_dict[str(_node_in)]].append(child_idx) self.node2in = node2in self.node2out = node2out # calculate entropy ent = 0 for blk_idx in range(self.num_blocks - 1): blk_ent = 0 # node activation probability for node_idx in range(len(self.node_act_list)): _node_p = probs[blk_idx, node2in[node_idx]].sum() _out_probs = probs[blk_idx + 1, node2out[node_idx]].sum() blk_ent += -(_node_p * torch.log(_out_probs + 1e-5) + (1 - _node_p) * torch.log(1 - _out_probs + 1e-5)) ent += blk_ent return ent
[docs] def decode(self): """ Decode network log_alpha_a/log_alpha_c using dijkstra shortest path algorithm. `[node_a, arch_code_a, arch_code_c, arch_code_a_max]` is decoded when using ``self.decode()``. For example, for a ``num_depths=4``, ``num_blocks=12`` search space: - ``node_a`` is a 4x13 binary matrix representing if a feature node is activated (13 because of multi-resolution inputs). - ``arch_code_a`` is a 12x10 (10 paths) binary matrix representing if a path is activated. - ``arch_code_c`` is a 12x10x5 (5 operations) binary matrix representing if a cell operation is used. Return: arch_code with maximum probability """ probs, arch_code_prob_a = self.get_prob_a(child=True) arch_code_a_max = self.child_list[torch.argmax(probs, -1).data.cpu().numpy()] arch_code_c = torch.argmax(F.softmax(self.log_alpha_c, -1), -1).data.cpu().numpy() probs = probs.data.cpu().numpy() # define adjacency matrix amtx = np.zeros( (1 + len(self.child_list) * self.num_blocks + 1, 1 + len(self.child_list) * self.num_blocks + 1) ) # build a path activation to child index searching dictionary path2child = {str(self.child_list[i]): i for i in range(len(self.child_list))} # build a submodel to submodel index sub_amtx = np.zeros((len(self.child_list), len(self.child_list))) for child_idx in range(len(self.child_list)): _node_act = np.zeros(self.num_depths).astype(int) for path_idx in range(len(self.child_list[child_idx])): _node_act[self.arch_code2out[path_idx]] += self.child_list[child_idx][path_idx] _node_act = (_node_act >= 1).astype(int) for mtx in self.transfer_mtx[str(_node_act)]: connect_child_idx = path2child[str(mtx.flatten()[self.tidx].astype(int))] sub_amtx[child_idx, connect_child_idx] = 1 # fill in source to first block, add 1e-5/1e-3 to avoid log0 and negative edge weights amtx[0, 1 : 1 + len(self.child_list)] = -np.log(probs[0] + 1e-5) + 0.001 # fill in the rest blocks for blk_idx in range(1, self.num_blocks): amtx[ 1 + (blk_idx - 1) * len(self.child_list) : 1 + blk_idx * len(self.child_list), 1 + blk_idx * len(self.child_list) : 1 + (blk_idx + 1) * len(self.child_list), ] = sub_amtx * np.tile(-np.log(probs[blk_idx] + 1e-5) + 0.001, (len(self.child_list), 1)) # fill in the last to the sink amtx[1 + (self.num_blocks - 1) * len(self.child_list) : 1 + self.num_blocks * len(self.child_list), -1] = 0.001 graph = csr_matrix(amtx) dist_matrix, predecessors, sources = dijkstra( csgraph=graph, directed=True, indices=0, min_only=True, return_predecessors=True ) index, a_idx = -1, -1 arch_code_a = np.zeros((self.num_blocks, len(self.arch_code2out))) node_a = np.zeros((self.num_blocks + 1, self.num_depths)) # decoding to paths while True: index = predecessors[index] if index == 0: break child_idx = (index - 1) % len(self.child_list) arch_code_a[a_idx, :] = self.child_list[child_idx] for res_idx in range(len(self.arch_code2out)): node_a[a_idx, self.arch_code2out[res_idx]] += arch_code_a[a_idx, res_idx] a_idx -= 1 for res_idx in range(len(self.arch_code2out)): node_a[a_idx, self.arch_code2in[res_idx]] += arch_code_a[0, res_idx] node_a = (node_a >= 1).astype(int) return node_a, arch_code_a, arch_code_c, arch_code_a_max
[docs] def forward(self, x): """ Prediction based on dynamic arch_code. Args: x: a list of `num_depths` input tensors as a multi-resolution input. tensor is of shape `BCHW[D]` where `C` must match `self.filter_nums`. """ # generate path activation probability probs_a, arch_code_prob_a = self.get_prob_a(child=False) inputs = x for blk_idx in range(self.num_blocks): outputs = [0.0] * self.num_depths for res_idx, activation in enumerate(self.arch_code_a[blk_idx].data.cpu().numpy()): if activation: _w = F.softmax(self.log_alpha_c[blk_idx, res_idx], dim=-1) outputs[self.arch_code2out[res_idx]] += ( self.cell_tree[str((blk_idx, res_idx))](inputs[self.arch_code2in[res_idx]], weight=_w) * arch_code_prob_a[blk_idx, res_idx] ) inputs = outputs return inputs