Source code for pina._src.callback.optim.switch_scheduler
"""Module for the SwitchScheduler callback."""
from lightning.pytorch.callbacks import Callback
from pina._src.optim.scheduler_interface import SchedulerInterface
from pina._src.core.utils import check_consistency, check_positive_integer
[docs]
class SwitchScheduler(Callback):
    """
    Lightning callback for dynamically replacing schedulers during training.

    This callback enables switching to new scheduler(s) at a specified epoch
    without interrupting the training loop. It is useful for staged training
    strategies where different learning rate policies are applied sequentially.
    """

    def __init__(self, new_schedulers, epoch_switch):
        """
        Initialization of the :class:`SwitchScheduler` class.

        :param new_schedulers: The scheduler or list of schedulers to switch to.
            Use a single scheduler for single-model solvers, or a list (or
            tuple) of schedulers, one per optimizer, when working with
            multiple models.
        :type new_schedulers: SchedulerInterface | list[SchedulerInterface]
        :param int epoch_switch: The epoch at which the scheduler switch occurs.
        :raises AssertionError: If ``epoch_switch`` is not a positive integer.
        :raises ValueError: If any of the provided schedulers are not instances
            of :class:`pina.optim.SchedulerInterface`, or if an empty
            collection of schedulers is given.

        Example:
            >>> scheduler = TorchScheduler(
            >>>     torch.optim.lr_scheduler.StepLR, step_size=5
            >>> )
            >>> switch_callback = SwitchScheduler(
            >>>     new_schedulers=scheduler, epoch_switch=10
            >>> )
        """
        super().__init__()

        # Check consistency
        check_positive_integer(epoch_switch, strict=True)
        check_consistency(new_schedulers, SchedulerInterface)

        # Normalize to a fresh list. Tuples are unpacked rather than wrapped,
        # so each element (not the tuple itself) is treated as a scheduler.
        # Copying also shields the callback from later mutation of the
        # caller's list.
        if isinstance(new_schedulers, (list, tuple)):
            new_schedulers = list(new_schedulers)
        else:
            new_schedulers = [new_schedulers]

        # An empty collection would otherwise silently wipe the solver's
        # schedulers at the switch epoch.
        if not new_schedulers:
            raise ValueError(
                "SwitchScheduler requires at least one scheduler."
            )

        # Store the new schedulers and epoch switch
        self._new_schedulers = new_schedulers
        self._epoch_switch = epoch_switch

    def on_train_epoch_start(self, trainer, __):
        """
        Switch the scheduler at the start of the specified training epoch.

        :param Trainer trainer: The trainer object managing
            the training process.
        :param __: Placeholder argument, not used.
        :raises ValueError: If the number of new schedulers differs from the
            number of optimizers held by the solver.
        """
        # Only act on the configured switch epoch
        if trainer.current_epoch != self._epoch_switch:
            return

        optimizers = trainer.solver._pina_optimizers

        # Validate before mutating anything, so a mismatch cannot leave the
        # trainer and the solver with a partially-switched set of schedulers.
        if len(self._new_schedulers) != len(optimizers):
            raise ValueError(
                f"SwitchScheduler received {len(self._new_schedulers)} "
                f"scheduler(s), but the solver has {len(optimizers)} "
                "optimizer(s). Provide exactly one scheduler per optimizer."
            )

        # Hook each new scheduler to its optimizer and point the trainer's
        # matching scheduler config at the freshly built instance.
        for idx, (scheduler, optimizer) in enumerate(
            zip(self._new_schedulers, optimizers)
        ):
            scheduler.hook(optimizer)
            trainer.lr_scheduler_configs[idx].scheduler = scheduler.instance

        # Update the solver's schedulers (a copy, so later edits to one list
        # do not leak into the other)
        trainer.solver._pina_schedulers = list(self._new_schedulers)