"""Module for Self-Adaptive Weighting class."""
import torch
from .weighting_interface import WeightingInterface


class SelfAdaptiveWeighting(WeightingInterface):
    """
    A self-adaptive weighting scheme to tackle the imbalance among the loss
    components. This formulation equalizes the gradient norms of the losses,
    preventing bias toward any particular term during training.
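
    Given loss components :math:`\mathcal{L}_1, \dots, \mathcal{L}_K` and
    model parameters :math:`\theta`, :meth:`weights_update` assigns to each
    component the weight

    .. math::

        \lambda_i = \frac{\sum_{k=1}^{K} \|\nabla_{\theta}
        \mathcal{L}_k\|}{\|\nabla_{\theta} \mathcal{L}_i\|},

    so that every weighted gradient :math:`\lambda_i \nabla_{\theta}
    \mathcal{L}_i` has the same norm.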
    .. seealso::

        **Original reference**:
        Wang, S., Sankaran, S., Stinis, P., and Perdikaris, P. (2025).
        *Simulating Three-dimensional Turbulence with Physics-informed
        Neural Networks*.
        arXiv preprint: `arXiv:2507.08972
        <https://arxiv.org/abs/2507.08972>`_.
    """

    def __init__(self, update_every_n_epochs=1):
        """
        Initialization of the :class:`SelfAdaptiveWeighting` class.

        :param int update_every_n_epochs: The number of training epochs
            between weight updates. If set to ``1``, the weights are updated
            at every epoch. Default is ``1``.
        """
        super().__init__(update_every_n_epochs=update_every_n_epochs)

    def weights_update(self, losses):
        """
        Update the weighting scheme based on the given losses.

        :param dict losses: The dictionary of losses, mapping each condition
            name to its loss tensor.
        :return: The updated weights, one per condition.
        :rtype: dict
        """
        # Get the trainable model parameters and prepare a dictionary to
        # store the gradient norm of each loss component.
        params = [p for p in self.solver.model.parameters() if p.requires_grad]
        norms = {}

        # Iterate over conditions.
        for condition, loss in losses.items():

            # Compute the gradient of this loss w.r.t. the model parameters.
            # ``retain_graph=True`` keeps the graph alive for the remaining
            # conditions; ``allow_unused=True`` tolerates parameters that do
            # not contribute to this particular loss.
            grads = torch.autograd.grad(
                loss,
                params,
                retain_graph=True,
                allow_unused=True,
            )

            # Stack all gradients into a single vector and take its norm.
            norms[condition] = torch.cat(
                [g.flatten() for g in grads if g is not None]
            ).norm()

        # Weight each condition by the ratio between the total gradient norm
        # and its own gradient norm, so all terms contribute comparably.
        return {
            condition: sum(norms.values()) / norms[condition]
            for condition in losses
        }
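

# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the gradient-norm balancing performed
# by ``weights_update``, on a toy model with two loss components. The model,
# data, and condition names below are illustrative assumptions, not part of
# PINA; only the weighting logic mirrors the method above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = torch.nn.Linear(2, 1)
    params = [p for p in model.parameters() if p.requires_grad]

    # Two toy loss components with very different magnitudes.
    x = torch.randn(16, 2)
    losses = {
        "data": (model(x) - 1.0).pow(2).mean(),
        "physics": 100.0 * model(x).pow(2).mean(),
    }

    # Same computation as ``weights_update``: per-condition gradient norms.
    norms = {}
    for condition, loss in losses.items():
        grads = torch.autograd.grad(loss, params, retain_graph=True)
        norms[condition] = torch.cat([g.flatten() for g in grads]).norm()
    weights = {c: sum(norms.values()) / norms[c] for c in losses}

    # After reweighting, both components have the same gradient norm,
    # equal to the sum of the original norms.
    for condition in losses:
        print(condition, (weights[condition] * norms[condition]).item())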