Source code for monai.engines.evaluator

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Callable, Dict, Optional, Sequence, Tuple

import torch
from torch.utils.data import DataLoader

from monai.engines.utils import CommonKeys as Keys
from monai.engines.utils import default_prepare_batch
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
from monai.transforms import Transform
from monai.utils import ensure_tuple, exact_version, optional_import

if TYPE_CHECKING:
    from ignite.engine import Engine
    from ignite.metrics import Metric
else:
    Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine")
    Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric")

__all__ = ["Evaluator", "SupervisedEvaluator", "EnsembleEvaluator"]


class Evaluator(Workflow):
    """
    Base class for all kinds of evaluators, inherits from Workflow.

    Args:
        device: an object representing the device on which to run.
        val_data_loader: the data loader the Ignite engine iterates over, must be a torch DataLoader.
        epoch_length: number of iterations for one epoch, defaults to `len(val_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse image and label for the current iteration.
        iteration_update: callable function executed for every iteration, expected to accept `engine`
            and `batchdata` as input parameters. If not provided, `self._iteration()` is used instead.
        post_transform: additional transformation applied to the model output data, typically several
            Tensor based transforms composed by `Compose`.
        key_val_metric: metric computed when every iteration completes, with the average value saved to
            `engine.state.metrics` when the epoch completes. It is the main metric used to compare models
            and save checkpoints to files.
        additional_metrics: more Ignite metrics that are also attached to the Ignite Engine.
        val_handlers: a sequence of Ignite event handlers, each must have an `attach` function, like:
            CheckpointHandler, StatsHandler, SegmentationSaver, etc.
        amp: whether to enable auto-mixed-precision evaluation, default is False.

    """

    def __init__(
        self,
        device: torch.device,
        val_data_loader: DataLoader,
        epoch_length: Optional[int] = None,
        non_blocking: bool = False,
        prepare_batch: Callable = default_prepare_batch,
        iteration_update: Optional[Callable] = None,
        post_transform: Optional[Transform] = None,
        key_val_metric: Optional[Dict[str, Metric]] = None,
        additional_metrics: Optional[Dict[str, Metric]] = None,
        val_handlers: Optional[Sequence] = None,
        amp: bool = False,
    ) -> None:
        super().__init__(
            device=device,
            max_epochs=1,
            data_loader=val_data_loader,
            epoch_length=epoch_length,
            non_blocking=non_blocking,
            prepare_batch=prepare_batch,
            iteration_update=iteration_update,
            post_transform=post_transform,
            key_metric=key_val_metric,
            additional_metrics=additional_metrics,
            handlers=val_handlers,
            amp=amp,
        )
    def run(self, global_epoch: int = 1) -> None:
        """
        Execute validation/evaluation based on Ignite Engine.

        Args:
            global_epoch: the overall epoch when run during training; the evaluator engine
                can get it from the trainer.

        """
        # init env value for current validation process
        self.state.max_epochs = global_epoch
        self.state.epoch = global_epoch - 1
        self.state.iteration = 0
        super().run()
    def get_validation_stats(self) -> Dict[str, float]:
        return {"best_validation_metric": self.state.best_metric, "best_validation_epoch": self.state.best_metric_epoch}
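
The base `Evaluator` does not provide a concrete `_iteration`; a subclass such as the `SupervisedEvaluator`
below is typically run once per training epoch. A minimal sketch, assuming an Ignite `trainer` engine and a
constructed `evaluator` already exist elsewhere (both names are placeholders):

    from ignite.engine import Events

    @trainer.on(Events.EPOCH_COMPLETED)
    def _run_validation(engine):
        # pass the trainer's epoch so validation metrics are reported against
        # the overall training progress
        evaluator.run(global_epoch=engine.state.epoch)

    # after training, summarize the best validation metric and the epoch it occurred in
    print(evaluator.get_validation_stats())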
class SupervisedEvaluator(Evaluator):
    """
    Standard supervised evaluation method with image and label (optional), inherits from Evaluator and Workflow.

    Args:
        device: an object representing the device on which to run.
        val_data_loader: the data loader the Ignite engine iterates over, must be a torch DataLoader.
        network: the network used to run the model forward pass.
        epoch_length: number of iterations for one epoch, defaults to `len(val_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse image and label for the current iteration.
        iteration_update: callable function executed for every iteration, expected to accept `engine`
            and `batchdata` as input parameters. If not provided, `self._iteration()` is used instead.
        inferer: inference method that executes the model forward pass on the input data, like:
            SlidingWindow, etc. Defaults to `SimpleInferer()`.
        post_transform: additional transformation applied to the model output data, typically several
            Tensor based transforms composed by `Compose`.
        key_val_metric: metric computed when every iteration completes, with the average value saved to
            `engine.state.metrics` when the epoch completes. It is the main metric used to compare models
            and save checkpoints to files.
        additional_metrics: more Ignite metrics that are also attached to the Ignite Engine.
        val_handlers: a sequence of Ignite event handlers, each must have an `attach` function, like:
            CheckpointHandler, StatsHandler, SegmentationSaver, etc.
        amp: whether to enable auto-mixed-precision evaluation, default is False.

    """

    def __init__(
        self,
        device: torch.device,
        val_data_loader: DataLoader,
        network: torch.nn.Module,
        epoch_length: Optional[int] = None,
        non_blocking: bool = False,
        prepare_batch: Callable = default_prepare_batch,
        iteration_update: Optional[Callable] = None,
        inferer: Optional[Inferer] = None,
        post_transform: Optional[Transform] = None,
        key_val_metric: Optional[Dict[str, Metric]] = None,
        additional_metrics: Optional[Dict[str, Metric]] = None,
        val_handlers: Optional[Sequence] = None,
        amp: bool = False,
    ) -> None:
        super().__init__(
            device=device,
            val_data_loader=val_data_loader,
            epoch_length=epoch_length,
            non_blocking=non_blocking,
            prepare_batch=prepare_batch,
            iteration_update=iteration_update,
            post_transform=post_transform,
            key_val_metric=key_val_metric,
            additional_metrics=additional_metrics,
            val_handlers=val_handlers,
            amp=amp,
        )

        self.network = network
        self.inferer = SimpleInferer() if inferer is None else inferer

    def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Callback function for the supervised evaluation processing logic of 1 iteration in Ignite Engine.
        Return the below items in a dictionary:
            - IMAGE: image Tensor data for model input, already moved to device.
            - LABEL: label Tensor data corresponding to the image, already moved to device.
            - PRED: prediction result of model.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
        if len(batch) == 2:
            inputs, targets = batch
            args: Tuple = tuple()
            kwargs: Dict = dict()
        else:
            inputs, targets, args, kwargs = batch

        # execute forward computation
        self.network.eval()
        with torch.no_grad():
            if self.amp:
                with torch.cuda.amp.autocast():
                    predictions = self.inferer(inputs, self.network, *args, **kwargs)
            else:
                predictions = self.inferer(inputs, self.network, *args, **kwargs)

        return {Keys.IMAGE: inputs, Keys.LABEL: targets, Keys.PRED: predictions}
class EnsembleEvaluator(Evaluator):
    """
    Ensemble evaluation for multiple models, inherits from Evaluator and Workflow.
    It accepts a list of models for inference and outputs a list of predictions for further operations.

    Args:
        device: an object representing the device on which to run.
        val_data_loader: the data loader the Ignite engine iterates over, must be a torch DataLoader.
        networks: the networks used to run the model forward passes, in order.
        pred_keys: the keys under which to store every prediction; the length must exactly match the
            number of networks.
        epoch_length: number of iterations for one epoch, defaults to `len(val_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse image and label for the current iteration.
        iteration_update: callable function executed for every iteration, expected to accept `engine`
            and `batchdata` as input parameters. If not provided, `self._iteration()` is used instead.
        inferer: inference method that executes the model forward pass on the input data, like:
            SlidingWindow, etc. Defaults to `SimpleInferer()`.
        post_transform: additional transformation applied to the model output data, typically several
            Tensor based transforms composed by `Compose`.
        key_val_metric: metric computed when every iteration completes, with the average value saved to
            `engine.state.metrics` when the epoch completes. It is the main metric used to compare models
            and save checkpoints to files.
        additional_metrics: more Ignite metrics that are also attached to the Ignite Engine.
        val_handlers: a sequence of Ignite event handlers, each must have an `attach` function, like:
            CheckpointHandler, StatsHandler, SegmentationSaver, etc.
        amp: whether to enable auto-mixed-precision evaluation, default is False.

    """

    def __init__(
        self,
        device: torch.device,
        val_data_loader: DataLoader,
        networks: Sequence[torch.nn.Module],
        pred_keys: Sequence[str],
        epoch_length: Optional[int] = None,
        non_blocking: bool = False,
        prepare_batch: Callable = default_prepare_batch,
        iteration_update: Optional[Callable] = None,
        inferer: Optional[Inferer] = None,
        post_transform: Optional[Transform] = None,
        key_val_metric: Optional[Dict[str, Metric]] = None,
        additional_metrics: Optional[Dict[str, Metric]] = None,
        val_handlers: Optional[Sequence] = None,
        amp: bool = False,
    ) -> None:
        super().__init__(
            device=device,
            val_data_loader=val_data_loader,
            epoch_length=epoch_length,
            non_blocking=non_blocking,
            prepare_batch=prepare_batch,
            iteration_update=iteration_update,
            post_transform=post_transform,
            key_val_metric=key_val_metric,
            additional_metrics=additional_metrics,
            val_handlers=val_handlers,
            amp=amp,
        )

        self.networks = ensure_tuple(networks)
        self.pred_keys = ensure_tuple(pred_keys)
        self.inferer = SimpleInferer() if inferer is None else inferer

    def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Callback function for the ensemble evaluation processing logic of 1 iteration in Ignite Engine.
        Return the below items in a dictionary:
            - IMAGE: image Tensor data for model input, already moved to device.
            - LABEL: label Tensor data corresponding to the image, already moved to device.
            - pred_keys[0]: prediction result of network 0.
            - pred_keys[1]: prediction result of network 1.
            - ...
            - pred_keys[N]: prediction result of network N.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
        if len(batch) == 2:
            inputs, targets = batch
            args: Tuple = tuple()
            kwargs: Dict = dict()
        else:
            inputs, targets, args, kwargs = batch

        # execute forward computation for every network in turn
        predictions = {Keys.IMAGE: inputs, Keys.LABEL: targets}
        for idx, network in enumerate(self.networks):
            network.eval()
            with torch.no_grad():
                if self.amp:
                    with torch.cuda.amp.autocast():
                        predictions.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)})
                else:
                    predictions.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)})

        return predictions
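
A minimal usage sketch for the ensemble case (again, the toy data and the untrained networks are illustrative
assumptions; in practice each network would be a separately trained model, and the per-network outputs would
typically be merged downstream, e.g. by ensemble post transforms such as `MeanEnsembled`):

    import torch
    from torch.utils.data import DataLoader

    from monai.engines import EnsembleEvaluator

    data = [
        {"image": torch.rand(1, 32, 32), "label": torch.randint(0, 2, (1, 32, 32)).float()}
        for _ in range(10)
    ]
    loader = DataLoader(data, batch_size=2)

    # toy networks illustrate only the wiring; real use would load trained weights
    networks = [torch.nn.Conv2d(1, 1, kernel_size=3, padding=1) for _ in range(3)]

    ensemble_evaluator = EnsembleEvaluator(
        device=torch.device("cpu"),
        val_data_loader=loader,
        networks=networks,
        pred_keys=["pred0", "pred1", "pred2"],
    )
    ensemble_evaluator.run()
    # one prediction tensor per network, stored under the matching pred_keys entry
    print([ensemble_evaluator.state.output[k].shape for k in ("pred0", "pred1", "pred2")])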