Source code for pina._src.solver.mixin.manual_optimization_mixin

"""Module for the manual optimization mixin class."""


class ManualOptimizationMixin:
    """
    Mixin implementing a Lightning manual optimization loop.

    It is meant for solvers that must drive the optimization steps
    themselves, for instance when several optimizers are involved or when
    the training loop is customised. The mixin is intended to be combined
    with a solver deriving from
    :class:`~pina._src.solver.base_solver.BaseSolver`, which supplies the
    parent ``training_step`` and ``on_train_batch_end`` implementations
    reached through ``super()``.
    """

    def _init_manual_optimization(self):
        """
        Switch off Lightning's automatic optimization for this solver.
        """
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        """
        Run one manually optimized training step.

        Gradients are reset on every optimizer, the loss is obtained from
        the parent ``training_step``, back-propagated, and finally each
        optimizer/scheduler pair is stepped.

        :param list[tuple[str, dict]] batch: A batch of data. Each element
            is a tuple containing a condition name and a dictionary of
            points.
        :param int batch_idx: The index of the current batch.
        :return: The loss of the training step.
        :rtype: torch.Tensor
        """
        # Clear gradients on every optimizer before the forward pass
        for optimizer in self.optimizers:
            optimizer.instance.zero_grad()

        # Forward pass and loss evaluation are delegated to the parent class
        step_loss = super().training_step(batch, batch_idx)

        # Back-propagate through Lightning's manual backward entry point
        self.manual_backward(step_loss)

        # Optimizers and schedulers are paired positionally; ``zip`` stops at
        # the shorter of the two collections. Note that every scheduler is
        # stepped each time this method runs, right after its optimizer.
        for optimizer, scheduler in zip(self.optimizers, self.schedulers):
            optimizer.instance.step()
            scheduler.instance.step()

        return step_loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """
        Keep Lightning's manual optimization progress counters in sync.

        The completed optimization-step counter of Lightning's manual
        optimization loop is incremented by one, after which the call is
        forwarded to the parent implementation.

        :param torch.Tensor outputs: The loss of the training step.
        :param list[tuple[str, dict]] batch: A batch of data. Each element
            is a tuple containing a condition name and a dictionary of
            points.
        :param int batch_idx: The index of the current batch.
        :return: The result returned by the parent class implementation.
        :rtype: Any
        """
        # Walk down to the progress tracker owned by the manual optimization
        # loop and record one more completed optimization step
        manual_loop = self.trainer.fit_loop.epoch_loop.manual_optimization
        step_progress = manual_loop.optim_step_progress
        step_progress.total.completed += 1

        return super().on_train_batch_end(outputs, batch, batch_idx)