# Source code for pina._src.callback.optim.switch_optimizer
"""Module for the SwitchOptimizer callback."""
from lightning.pytorch.callbacks import Callback
from pina._src.optim.optimizer_interface import OptimizerInterface
from pina._src.core.utils import check_consistency, check_positive_integer
[docs]
class SwitchOptimizer(Callback):
    """
    Lightning callback for dynamically replacing optimizers during training.

    This callback enables switching to one or more new optimizers at a
    specified epoch without restarting the training loop. It is particularly
    useful for staged optimization strategies (e.g., coarse-to-fine training
    or optimizer warm-up phases), where different optimizers are applied
    sequentially.

    At the target epoch, the provided optimizers are hooked to the model
    parameters and replace the current optimizers in both the PINA solver and
    the Lightning trainer strategy.
    """

    def __init__(self, new_optimizers, epoch_switch):
        """
        Initialization of the :class:`SwitchOptimizer` class.

        :param new_optimizers: The model optimizers to switch to. Can be a
            single :class:`pina.optim.OptimizerInterface` instance, or a list
            (or tuple) of them for solvers with multiple models. The i-th
            optimizer is hooked to the parameters of the i-th solver model.
        :type new_optimizers: pina.optim.OptimizerInterface | list | tuple
        :param int epoch_switch: The epoch at which the optimizer switch
            occurs. It is compared against ``trainer.current_epoch`` at the
            start of every training epoch.
        :raises AssertionError: If ``epoch_switch`` is not a positive integer.
        :raises ValueError: If any of the provided optimizers are not instances
            of :class:`pina.optim.OptimizerInterface`.

        Example:
            >>> optimizer = TorchOptimizer(torch.optim.Adam, lr=0.01)
            >>> switch_callback = SwitchOptimizer(
            ...     new_optimizers=optimizer, epoch_switch=10
            ... )
        """
        super().__init__()

        # Check consistency
        check_positive_integer(epoch_switch, strict=True)
        check_consistency(new_optimizers, OptimizerInterface)

        # Normalize to a list. A tuple is treated as a sequence of optimizers
        # (not wrapped as one element, which would later break ``hook``), and
        # a fresh list is built so that later mutation of the caller's
        # container cannot silently change which optimizers get installed.
        if isinstance(new_optimizers, (list, tuple)):
            new_optimizers = list(new_optimizers)
        else:
            new_optimizers = [new_optimizers]

        # Store the new optimizers and epoch switch
        self._new_optimizers = new_optimizers
        self._epoch_switch = epoch_switch

    def on_train_epoch_start(self, trainer, __):
        """
        Switch the optimizer at the start of the specified training epoch.

        On every other epoch this hook does nothing.

        :param Trainer trainer: The trainer object managing the training
            process.
        :param __: Placeholder argument, not used.
        :raises ValueError: If more optimizers were provided than the solver
            has models.
        """
        # Only act on the configured epoch.
        if trainer.current_epoch != self._epoch_switch:
            return

        models = trainer.solver._pina_models
        optims = []
        # Hook the i-th new optimizer to the parameters of the i-th model.
        for idx, optim in enumerate(self._new_optimizers):
            try:
                model = models[idx]
            except IndexError as err:
                # Surface a descriptive error instead of a bare IndexError.
                raise ValueError(
                    f"SwitchOptimizer received {len(self._new_optimizers)} "
                    f"optimizers, but the solver has no model at index {idx}."
                ) from err
            optim.hook(model.parameters())
            optims.append(optim)

        # Update the solver's optimizers
        trainer.solver._pina_optimizers = optims
        # Update the trainer's strategy optimizers (the wrapped instances)
        trainer.strategy.optimizers = [o.instance for o in optims]
        # NOTE(review): any learning-rate schedulers bound to the previous
        # optimizers are not touched here — confirm whether the solver or
        # trainer rebuilds them after a switch.