# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from bisect import bisect_right
from collections.abc import Callable
from typing import TYPE_CHECKING
from monai.config import IgniteInfo
from monai.utils import min_version, optional_import
if TYPE_CHECKING:
from ignite.engine import Engine, Events
else:
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
[docs]
class ParamSchedulerHandler:
    """
    General purpose scheduler for parameter values. By default it can schedule with a linear, exponential, step or
    multistep function. One can also pass a Callable to have customized scheduling logic.

    On every triggering event the handler computes ``value_calculator(**vc_kwargs, current_step=<step>)`` and hands
    the result to ``parameter_setter``. ``<step>`` is ``engine.state.epoch`` when ``epoch_level`` is True and
    ``engine.state.iteration`` otherwise.

    Args:
        parameter_setter (Callable): Function that sets the required parameter. It is called with the newly
            computed value as its only argument.
        value_calculator (Union[str,Callable]): Either a string ('linear', 'exponential', 'step' or 'multistep')
            or Callable for custom logic. A custom Callable must accept the keys of ``vc_kwargs`` plus
            ``current_step`` as keyword arguments.
        vc_kwargs (Dict): Dictionary that stores the required parameters for the value_calculator
            (everything except ``current_step``, which the handler supplies). The dictionary is read on every
            call and is never modified by the handler.
        epoch_level (bool): Whether the step is based on epoch or iteration. Defaults to False.
        name (Optional[str]): Identifier of logging.logger to use, if None, defaulting to ``engine.logger``.
        event (Optional[str]): Event to which the handler attaches. Defaults to Events.ITERATION_COMPLETED.

    Raises:
        KeyError: If ``value_calculator`` is a string that is not one of the built-in calculator names.
        ValueError: If ``value_calculator`` is neither a string nor a Callable.
    """

    def __init__(
        self,
        parameter_setter: Callable,
        value_calculator: str | Callable,
        vc_kwargs: dict,
        epoch_level: bool = False,
        name: str | None = None,
        event: str | None = None,
    ):
        self.epoch_level = epoch_level
        # ``Events`` is only touched when no explicit event is given.
        self.event = event if event is not None else Events.ITERATION_COMPLETED

        # Names accepted for ``value_calculator`` when it is given as a string.
        self._calculators = {
            "linear": self._linear,
            "exponential": self._exponential,
            "step": self._step,
            "multistep": self._multistep,
        }

        self._parameter_setter = parameter_setter
        self._vc_kwargs = vc_kwargs
        self._value_calculator = self._get_value_calculator(value_calculator=value_calculator)

        self.logger = logging.getLogger(name)
        self._name = name

    def _get_value_calculator(self, value_calculator: str | Callable) -> Callable:
        """
        Resolve ``value_calculator`` to the function that computes the scheduled value.

        Args:
            value_calculator: Name of a built-in calculator or a custom Callable.

        Returns:
            The built-in calculator registered under the given name, or the Callable itself.

        Raises:
            KeyError: If a string is given that is not a registered calculator name.
            ValueError: If the argument is neither a string nor a Callable.
        """
        if isinstance(value_calculator, str):
            return self._calculators[value_calculator]
        if callable(value_calculator):
            return value_calculator
        raise ValueError(
            f"value_calculator must be either a string from {list(self._calculators.keys())} or a Callable."
        )

    def __call__(self, engine: Engine) -> None:
        """
        Compute the parameter value for the engine's current step and pass it to the parameter setter.

        Args:
            engine: Ignite Engine whose ``state.epoch`` or ``state.iteration`` provides the current step.
        """
        current_step = engine.state.epoch if self.epoch_level else engine.state.iteration
        # Build a fresh dict per call instead of writing ``current_step`` into ``self._vc_kwargs``: that dict is
        # the object the caller passed in, and mutating it would leak the key into the caller's data (and into any
        # other handler sharing the same dict). A pre-existing "current_step" entry is still overridden.
        call_kwargs = {**self._vc_kwargs, "current_step": current_step}
        new_value = self._value_calculator(**call_kwargs)
        self._parameter_setter(new_value)

    def attach(self, engine: Engine) -> None:
        """
        Register this handler on ``engine`` for the configured event.

        Args:
            engine: Ignite Engine that is used for training.
        """
        if self._name is None:
            # No logger name was given at construction, so use the engine's logger.
            self.logger = engine.logger
        engine.add_event_handler(self.event, self)

    @staticmethod
    def _linear(
        initial_value: float, step_constant: int, step_max_value: int, max_value: float, current_step: int
    ) -> float:
        """
        Keeps the parameter value at initial_value while current_step <= step_constant, then changes it linearly
        so that it reaches max_value at step_max_value. For any later step the value stays at max_value.

        Args:
            initial_value (float): Starting value of the parameter.
            step_constant (int): Step index until parameter's value is kept constant.
            step_max_value (int): Step index at which parameter's value becomes max_value.
            max_value (float): Max parameter value.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        """
        if current_step <= step_constant:
            delta = 0.0
        elif current_step > step_max_value:
            delta = max_value - initial_value
        else:
            # Here step_constant < current_step <= step_max_value, so the denominator is strictly positive.
            delta = (max_value - initial_value) / (step_max_value - step_constant) * (current_step - step_constant)

        return initial_value + delta

    @staticmethod
    def _exponential(initial_value: float, gamma: float, current_step: int) -> float:
        """
        Decays the parameter value by gamma every step.

        Based on the closed form of ExponentialLR from Pytorch:
        https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ExponentialLR.html.

        Args:
            initial_value (float): Starting value of the parameter.
            gamma (float): Multiplicative factor of parameter value decay.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        """
        return initial_value * gamma**current_step

    @staticmethod
    def _step(initial_value: float, gamma: float, step_size: int, current_step: int) -> float:
        """
        Decays the parameter value by gamma every step_size.

        Based on StepLR from Pytorch:
        https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html.

        Args:
            initial_value (float): Starting value of the parameter.
            gamma (float): Multiplicative factor of parameter value decay.
            step_size (int): Period of parameter value decay.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        """
        return initial_value * gamma ** (current_step // step_size)

    @staticmethod
    def _multistep(initial_value: float, gamma: float, milestones: list[int], current_step: int) -> float:
        """
        Decays the parameter value by gamma once the number of steps reaches one of the milestones.

        Based on MultiStepLR from Pytorch.
        https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiStepLR.html.

        Args:
            initial_value (float): Starting value of the parameter.
            gamma (float): Multiplicative factor of parameter value decay.
            milestones (List[int]): List of step indices. Must be increasing.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        """
        # bisect_right counts the milestones that are <= current_step, i.e. how many decays have been applied.
        return initial_value * gamma ** bisect_right(milestones, current_step)