# Source code for pina._src.weighting.ntk_weighting

"""Module for Neural Tangent Kernel Class"""

import torch
from pina._src.weighting.base_weighting import BaseWeighting
from pina._src.core.utils import check_consistency, in_range


class NeuralTangentKernelWeighting(BaseWeighting):
    """
    The Neural Tangent Kernel (NTK) weighting strategy.

    This weighting scheme dynamically adjusts the contribution of each loss
    term during training by leveraging gradient information with respect to
    the model parameters. For each loss component, the norm of its gradient
    is computed and used to derive relative importance weights. The resulting
    weights are smoothed over time using an exponential moving average
    controlled by the parameter ``alpha``.

    .. seealso::

        **Original reference**: Wang, Sifan, Xinling Yu, and Paris
        Perdikaris. *When and why PINNs fail to train: A neural tangent
        kernel perspective*. Journal of Computational Physics 449 (2022):
        110768.
        DOI: `10.1016 <https://doi.org/10.1016/j.jcp.2021.110768>`_.
    """

    def __init__(self, update_every_n_epochs=1, alpha=0.5):
        """
        Initialization of the :class:`NeuralTangentKernelWeighting` class.

        :param int update_every_n_epochs: The number of training epochs
            between weight updates. If set to 1, the weights are updated at
            every epoch. This parameter is ignored by static weighting
            schemes. Default is ``1``.
        :param float alpha: The parameter controlling the exponential moving
            average of the weights. It must be in the range [0, 1], where a
            value of ``0.0`` means that only the current gradient norms are
            used to compute the weights, and a value of ``1.0`` means that
            only the last saved weights are used. Default is ``0.5``.
        :raises ValueError: If ``alpha`` is not a float.
        :raises ValueError: If ``alpha`` is not between 0.0 and 1.0
            (inclusive).
        """
        super().__init__(
            update_every_n_epochs=update_every_n_epochs, aggregator="sum"
        )

        # Check consistency
        check_consistency(alpha, float)
        if not in_range(alpha, [0.0, 1.0], strict=False):
            raise ValueError(
                "The alpha parameter must be between 0.0 and 1.0 (inclusive)."
                f" Got {alpha}."
            )

        # Initialization
        self.alpha = alpha
        self.weights = {}

    @staticmethod
    def _gradient_norm(loss, params):
        """
        Compute the norm of the gradient of ``loss`` w.r.t. ``params``.

        :param torch.Tensor loss: The scalar loss to differentiate.
        :param list params: The parameters to differentiate against.
        :return: The norm of all available gradients flattened and
            concatenated into one vector, or a zero scalar tensor when no
            parameter receives a gradient from ``loss``.
        :rtype: torch.Tensor
        """
        # retain_graph keeps the graph alive so the remaining losses (and the
        # caller's own backward pass) can still differentiate through it.
        grads = torch.autograd.grad(
            loss,
            params,
            retain_graph=True,
            allow_unused=True,
        )

        # allow_unused=True yields None for parameters the loss does not
        # depend on; those carry no gradient information and are skipped.
        flat = [g.flatten() for g in grads if g is not None]

        # torch.cat rejects an empty list, so a loss that touches none of the
        # parameters is reported as having a zero gradient norm.
        if not flat:
            return loss.new_zeros(())
        return torch.cat(flat).norm()

    def update_weights(self, losses):
        """
        Update the weights based on the current losses.

        Each weight is the exponential moving average of its last saved value
        (``1`` when the condition has none) and the share of the condition's
        gradient norm in the sum of all gradient norms.

        If every gradient norm is zero the shares are undefined (0/0), so the
        last saved weights are returned unchanged instead of NaN values.

        :param dict losses: The mapping from loss names to loss tensors.
        :return: The updated weights.
        :rtype: dict
        """
        # Only trainable parameters can receive gradients.
        params = [p for p in self.solver.model.parameters() if p.requires_grad]

        # Gradient norm of every loss term, keyed by condition name.
        norms = {
            condition: self._gradient_norm(loss, params)
            for condition, loss in losses.items()
        }

        # Both values are the same for every condition: compute them once.
        previous = self.last_saved_weights()
        total = sum(norms.values())

        # No gradient information at all: keep the weights as they are rather
        # than dividing zero by zero.
        if total == 0:
            return {
                condition: previous.get(condition, 1) for condition in losses
            }

        return {
            condition: self.alpha * previous.get(condition, 1)
            + (1 - self.alpha) * norms[condition] / total
            for condition in losses
        }