Source code for pina.callback.optim.switch_scheduler

"""Module for the SwitchScheduler callback."""

from lightning.pytorch.callbacks import Callback
from ...optim import TorchScheduler
from ...utils import check_consistency, check_positive_integer


class SwitchScheduler(Callback):
    """
    Callback to switch scheduler during training.
    """

    def __init__(self, new_schedulers, epoch_switch):
        """
        This callback allows switching between different schedulers during
        training, enabling the exploration of multiple optimization
        strategies without interrupting the training process.

        :param new_schedulers: The scheduler or list of schedulers to switch
            to. Use a single scheduler for single-model solvers, or a list of
            schedulers when working with multiple models.
        :type new_schedulers: pina.optim.TorchScheduler |
            list[pina.optim.TorchScheduler]
        :param int epoch_switch: The epoch at which the scheduler switch
            occurs. Must be greater than 1.
        :raise AssertionError: If ``epoch_switch`` is not greater than 1.
        :raise ValueError: If any scheduler in ``new_schedulers`` is not an
            instance of :class:`pina.optim.TorchScheduler`.

        Example:
            >>> scheduler = TorchScheduler(
            ...     torch.optim.lr_scheduler.StepLR, step_size=5
            ... )
            >>> switch_callback = SwitchScheduler(
            ...     new_schedulers=scheduler, epoch_switch=10
            ... )
        """
        super().__init__()

        # Check that epoch_switch is greater than 1
        check_positive_integer(epoch_switch - 1, strict=True)

        # If new_schedulers is not a list, convert it to a list
        if not isinstance(new_schedulers, list):
            new_schedulers = [new_schedulers]

        # Check consistency of each scheduler
        for scheduler in new_schedulers:
            check_consistency(scheduler, TorchScheduler)

        # Store the new schedulers and the switch epoch
        self._new_schedulers = new_schedulers
        self._epoch_switch = epoch_switch
    def on_train_epoch_start(self, trainer, __):
        """
        Switch the scheduler at the start of the specified training epoch.

        :param lightning.pytorch.Trainer trainer: The trainer object managing
            the training process.
        :param __: Placeholder argument (not used).
        """
        # Check if the current epoch matches the switch epoch
        if trainer.current_epoch == self._epoch_switch:
            schedulers = []
            # Hook each new scheduler to the corresponding solver optimizer
            for idx, scheduler in enumerate(self._new_schedulers):
                scheduler.hook(trainer.solver._pina_optimizers[idx])
                schedulers.append(scheduler)
                # Update the trainer's scheduler configs
                trainer.lr_scheduler_configs[idx].scheduler = (
                    scheduler.instance
                )
            # Update the solver's schedulers
            trainer.solver._pina_schedulers = schedulers
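

# --- Usage sketch (illustrative, not part of the library source). A minimal
# example of how this callback might be wired into training: construct a
# TorchScheduler as in the class docstring, wrap it in SwitchScheduler, and
# pass it to the trainer via the standard Lightning ``callbacks`` argument.
# The trainer wiring in the final comment is an assumption based on Lightning
# conventions and is not exercised here.
if __name__ == "__main__":
    import torch

    from pina.optim import TorchScheduler

    # Scheduler to switch to: StepLR halving the learning rate every 5 epochs
    # (``gamma`` is forwarded to StepLR, like ``step_size`` in the docstring).
    new_scheduler = TorchScheduler(
        torch.optim.lr_scheduler.StepLR, step_size=5, gamma=0.5
    )

    # Perform the switch once training reaches epoch 10.
    switch_callback = SwitchScheduler(
        new_schedulers=new_scheduler, epoch_switch=10
    )

    # The callback would then be passed to a trainer, e.g.
    # Trainer(solver, max_epochs=50, callbacks=[switch_callback]).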