Source code for pina._src.weighting.self_adaptive_weighting
"""Module for Self-Adaptive Weighting class."""
import torch
from pina._src.weighting.base_weighting import BaseWeighting
[docs]
class SelfAdaptiveWeighting(BaseWeighting):
"""
The self-adaptive weighting strategy based on gradient norm balancing.
This scheme dynamically adjusts the weights assigned to each loss term by
computing the norm of their gradients with respect to the model parameters.
The resulting weights are chosen to counterbalance disparities in gradient
magnitudes, promoting a more uniform contribution of all loss components
during optimization.
In practice, loss terms with smaller gradient norms are assigned larger
weights, while those with larger gradients are down-weighted. This helps
mitigate training imbalance and prevents dominance of specific loss terms.
.. seealso::
**Original reference**:
Wang, S., Sankaran, S., Stinis., P., Perdikaris, P. (2025).
*Simulating Three-dimensional Turbulence with Physics-informed Neural
Networks*.
DOI: `arXiv preprint arXiv:2507.08972.
<https://arxiv.org/abs/2507.08972>`_
"""
def __init__(self, update_every_n_epochs=1):
"""
Initialization of the :class:`SelfAdaptiveWeighting` class.
:param int update_every_n_epochs: The number of training epochs between
weight updates. If set to 1, the weights are updated at every epoch.
This parameter is ignored by static weighting schemes.
Default is ``1``.
"""
super().__init__(
update_every_n_epochs=update_every_n_epochs, aggregator="sum"
)
[docs]
def update_weights(self, losses):
"""
Update the weights based on the current losses.
This method defines how the weighting strategy adapts over time. It is
responsible for computing and storing updated weights that will be used
during aggregation.
:param dict losses: The mapping from loss names to loss tensors.
:return: The updated weights.
:rtype: dict
"""
# Get model parameters and define a dictionary to store the norms
params = [p for p in self.solver.model.parameters() if p.requires_grad]
norms = {}
# Iterate over conditions
for condition, loss in losses.items():
# Compute gradients
grads = torch.autograd.grad(
loss,
params,
retain_graph=True,
allow_unused=True,
)
# Compute norms
norms[condition] = torch.cat(
[g.flatten() for g in grads if g is not None]
).norm()
# Update the weights
return {
condition: sum(norms.values()) / norms[condition]
for condition in losses
}