# Source code for monai.engines.multi_gpu_supervised_trainer

# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Callable, Dict, Optional, Sequence, Tuple

import torch
import torch.nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.optim.optimizer import Optimizer

from monai.engines.utils import get_devices_spec
from monai.utils import exact_version, optional_import

create_supervised_trainer, _ = optional_import("ignite.engine", "0.4.4", exact_version, "create_supervised_trainer")
create_supervised_evaluator, _ = optional_import("ignite.engine", "0.4.4", exact_version, "create_supervised_evaluator")
_prepare_batch, _ = optional_import("ignite.engine", "0.4.4", exact_version, "_prepare_batch")
if TYPE_CHECKING:
    from ignite.engine import Engine
    from ignite.metrics import Metric
else:
    Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")
    Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric")

# Public API of this module: the two multi-GPU engine factory functions.
__all__ = ["create_multigpu_supervised_trainer", "create_multigpu_supervised_evaluator"]


def _default_transform(_x: torch.Tensor, _y: torch.Tensor, _y_pred: torch.Tensor, loss: torch.Tensor) -> float:
    return loss.item()


def _default_eval_transform(
    x: torch.Tensor, y: torch.Tensor, y_pred: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    return y_pred, y


def create_multigpu_supervised_trainer(
    net: torch.nn.Module,
    optimizer: Optimizer,
    loss_fn: Callable,
    devices: Optional[Sequence[torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable = _default_transform,
    distributed: bool = False,
) -> Engine:
    """
    Derived from `create_supervised_trainer` in Ignite.

    Factory function for creating a trainer for supervised models.

    Args:
        net: the network to train.
        optimizer: the optimizer to use.
        loss_fn: the loss function to use.
        devices: device(s) type specification (default: None).
            Applies to both model and batches. None is all devices used, empty list is CPU only.
            When more than one device is resolved (and `distributed` is False), `net` is wrapped in
            `DataParallel` over exactly these devices. Batches are moved to the first device.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.
        output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
            to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
        distributed: whether convert model to `DistributedDataParallel`, if have multiple devices, use
            the first device as output device. For CPU-only devices, `device_ids` is left as None,
            as `DistributedDataParallel` requires for CPU modules.

    Returns:
        Engine: a trainer engine with supervised update function.

    Note:
        `engine.state.output` for this engine is defined by `output_transform` parameter and is the loss
        of the processed batch by default.
        This function wraps `net` but does not itself move it to any device.
    """
    devices_ = get_devices_spec(devices)
    if distributed:
        # DistributedDataParallel only accepts explicit `device_ids` for GPU modules;
        # CPU modules must pass None, otherwise construction raises a ValueError.
        on_cpu = all(torch.device(d).type == "cpu" for d in devices_)
        net = DistributedDataParallel(net, device_ids=None if on_cpu else devices_)
    elif len(devices_) > 1:
        # Honour the requested devices rather than replicating over every visible GPU.
        net = DataParallel(net, device_ids=devices_)
    # The first resolved device is where `prepare_batch` places each batch.
    return create_supervised_trainer(
        net, optimizer, loss_fn, devices_[0], non_blocking, prepare_batch, output_transform
    )
def create_multigpu_supervised_evaluator(
    net: torch.nn.Module,
    metrics: Optional[Dict[str, Metric]] = None,
    devices: Optional[Sequence[torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable = _default_eval_transform,
    distributed: bool = False,
) -> Engine:
    """
    Derived from `create_supervised_evaluator` in Ignite.

    Factory function for creating an evaluator for supervised models.

    Args:
        net: the model to evaluate.
        metrics: a map of metric names to Metrics.
        devices: device(s) type specification (default: None).
            Applies to both model and batches. None is all devices used, empty list is CPU only.
            When more than one device is resolved (and `distributed` is False), `net` is wrapped in
            `DataParallel` over exactly these devices. Batches are moved to the first device.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.
        output_transform: function that receives 'x', 'y', 'y_pred' and returns value
            to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)`
            which fits output expected by metrics. If you change it you should use `output_transform` in metrics.
        distributed: whether convert model to `DistributedDataParallel`, if have multiple devices, use
            the first device as output device. For CPU-only devices, `device_ids` is left as None,
            as `DistributedDataParallel` requires for CPU modules.

    Note:
        `engine.state.output` for this engine is defined by `output_transform` parameter and is
        a tuple of `(batch_pred, batch_y)` by default.
        This function wraps `net` but does not itself move it to any device.

    Returns:
        Engine: an evaluator engine with supervised inference function.
    """
    devices_ = get_devices_spec(devices)
    if distributed:
        # DistributedDataParallel only accepts explicit `device_ids` for GPU modules;
        # CPU modules must pass None, otherwise construction raises a ValueError.
        on_cpu = all(torch.device(d).type == "cpu" for d in devices_)
        net = DistributedDataParallel(net, device_ids=None if on_cpu else devices_)
    elif len(devices_) > 1:
        # Honour the requested devices rather than replicating over every visible GPU.
        net = DataParallel(net, device_ids=devices_)
    # The first resolved device is where `prepare_batch` places each batch.
    return create_supervised_evaluator(net, metrics, devices_[0], non_blocking, prepare_batch, output_transform)