# Source code for monai.handlers.earlystop_handler

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING

from monai.config import IgniteInfo
from monai.utils import min_version, optional_import

Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
EarlyStopping, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EarlyStopping")

if TYPE_CHECKING:
    from ignite.engine import Engine
else:
    Engine, _ = optional_import(
        "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine", as_type="decorator"
    )

class EarlyStopHandler:
    """
    EarlyStopHandler is an Ignite handler that stops training when there has been no improvement
    for a given number of events. It's a thin wrapper around the `EarlyStopping` handler in ignite.

    Args:
        patience: number of events to wait without improvement before stopping the training.
        score_function: a function taking a single argument, the :class:`~ignite.engine.engine.Engine`
            object this handler is attached to (a trainer or a validator), and returning a score `float`.
            An improvement is considered to have happened if the score is higher.
        trainer: trainer engine whose run is stopped when there is no improvement. If None,
            `set_trainer()` must be called before training.
        min_delta: minimum increase in the score to qualify as an improvement, i.e. an increase of
            less than or equal to `min_delta` counts as no improvement.
        cumulative_delta: if True, `min_delta` defines an increase since the last `patience` reset;
            otherwise, it defines an increase after the last event. Defaults to False.
        epoch_level: whether to check early stopping at every epoch (`True`) or at every iteration
            (`False`) of the attached engine. Defaults to epoch level.

    Note:
        In distributed training, if the loss value of every iteration is used to detect early stopping,
        the values may differ across ranks. Also note that, to prevent `dist.destroy_process_group()`
        from hanging, an `all_reduce` operation can be used to synchronize the stop signal across all
        ranks. This can be implemented in the `score_function`, for example:

        .. code-block:: python

            import os

            import torch
            import torch.distributed as dist


            def score_function(engine):
                val_metric = engine.state.metrics["val_mean_dice"]
                if dist.is_initialized():
                    device = torch.device("cuda:" + os.environ["LOCAL_RANK"])
                    val_metric = torch.tensor([val_metric]).to(device)
                    dist.all_reduce(val_metric, op=dist.ReduceOp.SUM)
                    val_metric /= dist.get_world_size()
                    return val_metric.item()
                return val_metric

    This handler may be attached to a validator engine to monitor validation metrics and stop the
    training. In that case, `score_function` is executed on the validator engine, and `trainer` is
    the trainer engine.
    """

    def __init__(
        self,
        patience: int,
        score_function: Callable,
        trainer: Engine | None = None,
        min_delta: float = 0.0,
        cumulative_delta: bool = False,
        epoch_level: bool = True,
    ) -> None:
        # The configuration is kept on the instance so that `set_trainer()` can build the wrapped
        # ignite handler later, not only during construction.
        self.epoch_level = epoch_level
        self.patience = patience
        self.score_function = score_function
        self.min_delta = min_delta
        self.cumulative_delta = cumulative_delta
        # No wrapped handler exists until a trainer is known; `__call__` checks for this.
        self._handler = None
        if trainer is not None:
            self.set_trainer(trainer=trainer)

    def attach(self, engine: Engine) -> None:
        """
        Register this handler on `engine`.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator. The handler is
                registered for `Events.EPOCH_COMPLETED` when `epoch_level` is True, otherwise
                for `Events.ITERATION_COMPLETED`.
        """
        trigger = Events.EPOCH_COMPLETED if self.epoch_level else Events.ITERATION_COMPLETED
        engine.add_event_handler(trigger, self)

    def set_trainer(self, trainer: Engine) -> None:
        """
        Set the trainer engine to be stopped, for use when it was not provided to `__init__()`.

        Each call builds a new ignite `EarlyStopping` instance from the stored configuration,
        replacing any previously created one.

        Args:
            trainer: the engine whose run is terminated when no improvement is detected.
        """
        self._handler = EarlyStopping(
            patience=self.patience,
            score_function=self.score_function,
            trainer=trainer,
            min_delta=self.min_delta,
            cumulative_delta=self.cumulative_delta,
        )

    def __call__(self, engine: Engine) -> None:
        """
        Forward `engine` to the wrapped ignite `EarlyStopping` handler.

        Raises:
            RuntimeError: when no trainer was given to `__init__()` and `set_trainer()` was never called.
        """
        handler = self._handler
        if handler is None:
            raise RuntimeError("please set trainer in __init__() or call set_trainer() before training.")
        handler(engine)