Engines

Multi-GPU data parallel

monai.engines.multi_gpu_supervised_trainer.create_multigpu_supervised_evaluator(net, metrics=None, devices=None, non_blocking=False, prepare_batch=<function _prepare_batch>, output_transform=<function _default_eval_transform>)

Derived from create_supervised_evaluator in Ignite.

Factory function for creating an evaluator for supervised models.

Parameters
  • net (torch.nn.Module) – the model to evaluate.

  • metrics (dict of str - Metric) – a map of metric names to Metrics.

  • devices (list, optional) – device specification (default: None), applied to both the model and the batches. None uses all available devices; an empty list uses the CPU only.

  • non_blocking (bool) – if True and the batch is copied between CPU and GPU, the copy may occur asynchronously with respect to the host. In all other cases, this argument has no effect.

  • prepare_batch (Callable) – function that receives batch, device and non_blocking and returns a tuple of tensors (batch_x, batch_y).

  • output_transform (Callable) – function that receives ‘x’, ‘y’ and ‘y_pred’ and returns the value to be assigned to the engine’s state.output after each iteration. The default returns (y_pred, y), which matches the input expected by metrics. If you change it, set a corresponding output_transform on the metrics.

Note

engine.state.output for this engine is defined by the output_transform parameter and is a tuple (batch_pred, batch_y) by default.

Returns

an evaluator engine with supervised inference function.

Return type

Engine
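The output_transform contract described above can be illustrated with a minimal sketch. The function names keep_inputs_transform and metric_input are hypothetical, and plain lists stand in for tensors so that the snippet runs without torch. In practice, the first function would be passed as output_transform to the factory and the second as output_transform to each metric.

```python
def keep_inputs_transform(x, y, y_pred):
    """Hypothetical evaluator output_transform: keep the input next to the default pair."""
    return {"image": x, "pred": y_pred, "label": y}


def metric_input(output):
    """Hypothetical metric-side output_transform: recover the (y_pred, y) pair a metric expects."""
    return output["pred"], output["label"]


# Plain lists stand in for tensors (assumption made only to keep the sketch dependency-free).
state_output = keep_inputs_transform(x=[0.1, 0.9], y=[0, 1], y_pred=[0.2, 0.8])
pred_and_label = metric_input(state_output)
```

Because the evaluator output is no longer the default (y_pred, y) tuple, every attached metric needs a transform like metric_input to get back the pair it expects.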

monai.engines.multi_gpu_supervised_trainer.create_multigpu_supervised_trainer(net, optimizer, loss_fn, devices=None, non_blocking=False, prepare_batch=<function _prepare_batch>, output_transform=<function _default_transform>)

Derived from create_supervised_trainer in Ignite.

Factory function for creating a trainer for supervised models.

Parameters
  • net (torch.nn.Module) – the network to train.

  • optimizer (torch.optim.Optimizer) – the optimizer to use.

  • loss_fn (torch.nn loss function) – the loss function to use.

  • devices (list, optional) – device specification (default: None), applied to both the model and the batches. None uses all available devices; an empty list uses the CPU only.

  • non_blocking (bool) – if True and the batch is copied between CPU and GPU, the copy may occur asynchronously with respect to the host. In all other cases, this argument has no effect.

  • prepare_batch (Callable) – function that receives batch, device and non_blocking and returns a tuple of tensors (batch_x, batch_y).

  • output_transform (Callable) – function that receives ‘x’, ‘y’, ‘y_pred’ and ‘loss’ and returns the value to be assigned to the engine’s state.output after each iteration. The default returns loss.item().

Returns

a trainer engine with supervised update function.

Return type

Engine

Note

engine.state.output for this engine is defined by the output_transform parameter and is the loss of the processed batch by default.
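As a rough sketch of the trainer-side output_transform contract, the snippet below contrasts a function that mirrors the documented default (returning loss.item()) with a hypothetical alternative that also records the batch size. FakeLoss is a stand-in for a scalar loss tensor, and lists stand in for batches; both are assumptions made so the example runs without torch.

```python
class FakeLoss:
    """Stand-in for a scalar loss tensor; only item() is used below."""

    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value


def default_like_transform(x, y, y_pred, loss):
    """Mirrors the documented default: return the loss as a Python number."""
    return loss.item()


def loss_with_batch_size(x, y, y_pred, loss):
    """Hypothetical alternative: also record how many samples were in the batch."""
    return {"loss": loss.item(), "batch_size": len(x)}


# Lists stand in for the tensors that prepare_batch and the network would produce.
batch_x, batch_y, batch_pred = [[1.0], [2.0]], [0, 1], [0.3, 0.7]
loss = FakeLoss(0.25)

default_output = default_like_transform(batch_x, batch_y, batch_pred, loss)
custom_output = loss_with_batch_size(batch_x, batch_y, batch_pred, loss)
```

Whatever the chosen function returns is what handlers later read from engine.state.output, so a handler written for the default scalar would need adjusting if a dict is returned instead.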

Workflows

Workflow

class monai.engines.workflow.Workflow(device, max_epochs, amp, data_loader, prepare_batch=<function default_prepare_batch>, iteration_update=None, post_transform=None, key_metric=None, additional_metrics=None, handlers=None)

Workflow defines the core work process and inherits from the Ignite Engine. Trainers, validators and evaluators all share this base class, because each of them can be treated as the same Ignite engine loop. It initializes all the shared data in the Ignite engine.state and attaches additional processing logic to the Ignite engine through the event-handler mechanism.

To develop new trainers or evaluators, consider inheriting from Trainer or Evaluator.

Parameters
  • device (torch.device) – an object representing the device on which to run.

  • max_epochs (int) – the total number of epochs for the engine to run; validators and evaluators run for only 1 epoch.

  • amp (bool) – whether to enable automatic mixed-precision training (reserved).

  • data_loader (torch.utils.data.DataLoader) – the data loader that the Ignite engine iterates over; must be a torch.utils.data.DataLoader.

  • prepare_batch (Callable) – function to parse image and label for every iteration.

  • iteration_update (Optional[Callable]) – the callable run at every iteration; it is expected to accept engine and batchdata as input parameters. If not provided, self._iteration() is used instead.

  • post_transform (Transform) – additional transformation to apply to the model output data, typically several Tensor-based transforms chained with Compose.

  • key_metric (ignite.metric) – metric computed when every iteration completes, with the average value saved to engine.state.metrics when the epoch completes. key_metric is the main metric used for comparison when saving checkpoints to files.

  • additional_metrics (dict) – more Ignite metrics that are also attached to the Ignite Engine.

  • handlers (list) – every handler is a set of Ignite event handlers and must have an attach function, e.g. CheckpointHandler, StatsHandler, SegmentationSaver.

run()

Execute training, validation or evaluation based on Ignite Engine.

Return type

None
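The requirement that every handler expose an attach function can be sketched as follows. Both classes are hypothetical: StubEngine is a stand-in that only records and fires callbacks, and the string "epoch_completed" replaces the event object that a real Ignite engine would use (ignite.engine.Events.EPOCH_COMPLETED). The point is only the shape of the handler, not how Workflow is implemented.

```python
class StubEngine:
    """Hypothetical stand-in for an Ignite Engine; it only records and fires handlers."""

    def __init__(self):
        self._handlers = {}

    def add_event_handler(self, event_name, handler):
        self._handlers.setdefault(event_name, []).append(handler)

    def fire(self, event_name):
        for handler in self._handlers.get(event_name, []):
            handler(self)


class EpochCounter:
    """Hypothetical handler: counts how many epochs have completed."""

    def __init__(self):
        self.epochs_seen = 0

    def attach(self, engine):
        # Register this object as the callback for the epoch-completed event.
        engine.add_event_handler("epoch_completed", self)

    def __call__(self, engine):
        self.epochs_seen += 1


engine = StubEngine()
counter = EpochCounter()
counter.attach(engine)

# Simulate three completed epochs.
for _ in range(3):
    engine.fire("epoch_completed")
```

Keeping the registration inside attach lets a workflow accept a plain list of handler objects without knowing which events each one listens to.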

Trainer

class monai.engines.Trainer(device, max_epochs, amp, data_loader, prepare_batch=<function default_prepare_batch>, iteration_update=None, post_transform=None, key_metric=None, additional_metrics=None, handlers=None)

Base class for all kinds of trainers, inherits from Workflow.

run()

Execute training based on Ignite Engine. If this function is called multiple times, it continues running from the previous state.

Return type

None

SupervisedTrainer

class monai.engines.SupervisedTrainer(device, max_epochs, train_data_loader, network, optimizer, loss_function, prepare_batch=<function default_prepare_batch>, iteration_update=None, inferer=<monai.inferers.inferer.SimpleInferer object>, amp=True, post_transform=None, key_train_metric=None, additional_metrics=None, train_handlers=None)

Standard supervised training method with image and label, inherits from Trainer and Workflow.

Parameters
  • device (torch.device) – an object representing the device on which to run.

  • max_epochs (int) – the total number of epochs for the engine to run; validators and evaluators run for only 1 epoch.

  • train_data_loader (torch.utils.data.DataLoader) – the data loader that the Ignite engine iterates over; must be a torch.utils.data.DataLoader.

  • network (Network) – the network to train.

  • optimizer (Optimizer) – the optimizer associated with the network.

  • loss_function (Loss) – the loss function associated with the optimizer.

  • prepare_batch (Callable) – function to parse image and label for the current iteration.

  • iteration_update (Optional[Callable]) – the callable run at every iteration; it is expected to accept engine and batchdata as input parameters. If not provided, self._iteration() is used instead.

  • inferer (Inferer) – inference method that executes the model forward pass on the input data, e.g. SlidingWindow.

  • amp (bool) – whether to enable automatic mixed-precision training (reserved).

  • post_transform (Transform) – additional transformation to apply to the model output data, typically several Tensor-based transforms chained with Compose.

  • key_train_metric (ignite.metric) – metric computed when every iteration completes, with the average value saved to engine.state.metrics when the epoch completes. key_train_metric is the main metric used for comparison when saving checkpoints to files.

  • additional_metrics (dict) – more Ignite metrics that are also attached to the Ignite Engine.

  • train_handlers (list) – every handler is a set of Ignite event handlers and must have an attach function, e.g. CheckpointHandler, StatsHandler, SegmentationSaver.

Evaluator

class monai.engines.Evaluator(device, val_data_loader, prepare_batch=<function default_prepare_batch>, iteration_update=None, post_transform=None, key_val_metric=None, additional_metrics=None, val_handlers=None)

Base class for all kinds of evaluators, inherits from Workflow.

Parameters
  • device (torch.device) – an object representing the device on which to run.

  • val_data_loader (torch.utils.data.DataLoader) – the data loader that the Ignite engine iterates over; must be a torch.utils.data.DataLoader.

  • prepare_batch (Callable) – function to parse image and label for the current iteration.

  • iteration_update (Optional[Callable]) – the callable run at every iteration; it is expected to accept engine and batchdata as input parameters. If not provided, self._iteration() is used instead.

  • post_transform (Transform) – additional transformation to apply to the model output data, typically several Tensor-based transforms chained with Compose.

  • key_val_metric (ignite.metric) – metric computed when every iteration completes, with the average value saved to engine.state.metrics when the epoch completes. key_val_metric is the main metric used for comparison when saving checkpoints to files.

  • additional_metrics (dict) – more Ignite metrics that are also attached to the Ignite Engine.

  • val_handlers (list) – every handler is a set of Ignite event handlers and must have an attach function, e.g. CheckpointHandler, StatsHandler, SegmentationSaver.

run(global_epoch=1)

Execute validation/evaluation based on Ignite Engine.

Parameters

global_epoch (int) – the overall epoch number when evaluation runs during training. The evaluator engine can get it from the trainer.
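One way the trainer's epoch could reach run(global_epoch) is sketched below. StubEvaluator and run_validation are hypothetical names, and SimpleNamespace objects stand in for the trainer engine and its state; the only detail taken from Ignite is that an engine exposes its current epoch as state.epoch.

```python
from types import SimpleNamespace


class StubEvaluator:
    """Hypothetical stand-in for an evaluator; it records the epochs it was run with."""

    def __init__(self):
        self.seen_epochs = []

    def run(self, global_epoch=1):
        self.seen_epochs.append(global_epoch)


def run_validation(trainer, evaluator, interval):
    """Run the evaluator every `interval` epochs, forwarding the trainer's epoch."""
    if trainer.state.epoch % interval == 0:
        evaluator.run(trainer.state.epoch)


# Stand-in trainer whose state.epoch is advanced manually to simulate training.
trainer = SimpleNamespace(state=SimpleNamespace(epoch=0))
evaluator = StubEvaluator()

for epoch in range(1, 7):
    trainer.state.epoch = epoch
    run_validation(trainer, evaluator, interval=2)
```

Forwarding the epoch this way lets validation logs and saved checkpoints be labeled with the training epoch rather than with the evaluator's own single-epoch counter.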

SupervisedEvaluator

class monai.engines.SupervisedEvaluator(device, val_data_loader, network, prepare_batch=<function default_prepare_batch>, iteration_update=None, inferer=<monai.inferers.inferer.SimpleInferer object>, post_transform=None, key_val_metric=None, additional_metrics=None, val_handlers=None)

Standard supervised evaluation method with image and (optional) label, inherits from Evaluator and Workflow.

Parameters
  • device (torch.device) – an object representing the device on which to run.

  • val_data_loader (torch.utils.data.DataLoader) – the data loader that the Ignite engine iterates over; must be a torch.utils.data.DataLoader.

  • network (Network) – the network used to run the model forward pass.

  • prepare_batch (Callable) – function to parse image and label for the current iteration.

  • iteration_update (Optional[Callable]) – the callable run at every iteration; it is expected to accept engine and batchdata as input parameters. If not provided, self._iteration() is used instead.

  • inferer (Inferer) – inference method that executes the model forward pass on the input data, e.g. SlidingWindow.

  • post_transform (Transform) – additional transformation to apply to the model output data, typically several Tensor-based transforms chained with Compose.

  • key_val_metric (ignite.metric) – metric computed when every iteration completes, with the average value saved to engine.state.metrics when the epoch completes. key_val_metric is the main metric used for comparison when saving checkpoints to files.

  • additional_metrics (dict) – more Ignite metrics that are also attached to the Ignite Engine.

  • val_handlers (list) – every handler is a set of Ignite event handlers and must have an attach function, e.g. CheckpointHandler, StatsHandler, SegmentationSaver.