Source code for pina.loss.self_adaptive_weighting
"""Module for Self-Adaptive Weighting class."""
import torch
from .weighting_interface import WeightingInterface
class SelfAdaptiveWeighting(WeightingInterface):
"""
A self-adaptive weighting scheme to tackle the imbalance among the loss
components. This formulation equalizes the gradient norms of the losses,
preventing bias toward any particular term during training.
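
    More precisely, at each update the weight of the :math:`k`-th loss term
    is set to the ratio between the total gradient norm and the gradient
    norm of that term (see :meth:`weights_update`):

    .. math::
        \lambda_k = \frac{\sum_{j=1}^{K} \left\| \nabla_{\theta}
        \mathcal{L}_j \right\|}{\left\| \nabla_{\theta}
        \mathcal{L}_k \right\|},

    where :math:`\theta` are the model parameters and :math:`K` is the
    number of loss terms.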

    .. seealso::

        **Original reference**:
        Wang, S., Sankaran, S., Stinis, P., Perdikaris, P. (2025).
        *Simulating Three-dimensional Turbulence with Physics-informed
        Neural Networks*.
        ArXiv preprint: `arXiv:2507.08972
        <https://arxiv.org/abs/2507.08972>`_.
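
    :Example:
        A minimal usage sketch. It assumes the class is re-exported by
        :mod:`pina.loss` and that the solver accepts a ``weighting``
        argument; adapt the names to your setup.

        >>> from pina.loss import SelfAdaptiveWeighting
        >>> weighting = SelfAdaptiveWeighting(update_every_n_epochs=5)
        >>> # solver = PINN(problem=problem, model=model, weighting=weighting)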
"""

    def __init__(self, update_every_n_epochs=1):
        """
        Initialization of the :class:`SelfAdaptiveWeighting` class.

        :param int update_every_n_epochs: The number of training epochs
            between weight updates. If set to ``1``, the weights are updated
            at every epoch. Default is ``1``.
        """
        super().__init__(update_every_n_epochs=update_every_n_epochs)

    def weights_update(self, losses):
        """
        Update the weighting scheme based on the given losses.

        :param dict losses: A dictionary mapping each condition name to its
            scalar loss, as a tensor still attached to the computational
            graph, since gradients are computed here.
        :return: The updated weights, one per condition.
        :rtype: dict
        """
        # Define a dictionary to store the norms of the gradients
        losses_norm = {}

        # Compute the gradient norm of each loss component separately
        for condition, loss in losses.items():
            # Clear stale gradients so norms do not accumulate across terms
            self.solver.model.zero_grad()
            loss.backward(retain_graph=True)
            grads = torch.cat(
                [
                    p.grad.flatten()
                    for p in self.solver.model.parameters()
                    if p.grad is not None
                ]
            )
            losses_norm[condition] = grads.norm()

        # Clear the leftover gradients so they do not leak into the next
        # optimization step
        self.solver.model.zero_grad()

        # Update the weights: terms with smaller gradient norms receive
        # proportionally larger weights, and vice versa
        return {
            condition: sum(losses_norm.values()) / losses_norm[condition]
            for condition in losses
        }
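

# A minimal, self-contained sketch (not part of the original module) of the
# balancing rule implemented by ``weights_update``, written in plain PyTorch
# and independent of the solver machinery. ``torch.autograd.grad`` is used
# instead of ``backward`` so no ``.grad`` buffers need clearing. All names
# below are illustrative only.
if __name__ == "__main__":
    model = torch.nn.Linear(2, 1)
    x = torch.randn(8, 2)
    # Two loss terms with deliberately different gradient scales.
    losses = {
        "physics": model(x).pow(2).mean(),
        "boundary": 100.0 * (model(x) - 1.0).pow(2).mean(),
    }
    norms = {}
    for name, loss in losses.items():
        grads = torch.autograd.grad(
            loss, model.parameters(), retain_graph=True
        )
        norms[name] = torch.cat([g.flatten() for g in grads]).norm()
    # The term with the larger gradient norm receives the smaller weight,
    # so both terms contribute comparably after rescaling.
    weights = {name: sum(norms.values()) / norms[name] for name in norms}
    print({name: float(w) for name, w in weights.items()})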