NeuralTangentKernelWeighting
Module for the Neural Tangent Kernel weighting class.
- class NeuralTangentKernelWeighting(model, alpha=0.5)
Bases: WeightingInterface
A neural tangent kernel (NTK) scheme for weighting the different loss terms in order to improve convergence.
See also
Original reference: Wang, Sifan, Xinling Yu, and Paris Perdikaris. When and why PINNs fail to train: A neural tangent kernel perspective. Journal of Computational Physics 449 (2022): 110768. DOI: 10.1016.
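The cited paper proposes setting each loss weight from the traces of the NTK blocks, roughly λ_i = Tr(K) / Tr(K_ii), where K_ii is the block belonging to loss term i and Tr(K) is the sum of all block traces. The sketch below is a minimal pure-Python illustration of that rule, assuming the per-point Jacobian rows are already available as plain lists. `ntk_trace_weights` is a hypothetical helper, not part of this class's API, and the class itself may compute its weights differently (for example, from gradient norms).

```python
def ntk_trace_weights(jacobians):
    """Return one weight per loss term from NTK block traces.

    `jacobians` maps a loss-term name to a list of Jacobian rows, one row
    per training point of that term. Each row is the flattened derivative
    of the quantity the term fits (network output or PDE residual) at
    that point with respect to the model parameters.
    """
    # Each diagonal entry of K_ii is the inner product of a row with
    # itself, so Tr(K_ii) is the sum of the squared row norms.
    traces = {
        name: sum(sum(g * g for g in row) for row in rows)
        for name, rows in jacobians.items()
    }
    # Tr(K), the trace of the full NTK, is the sum of the block traces.
    total = sum(traces.values())
    # lambda_i = Tr(K) / Tr(K_ii): a term whose block has a small trace
    # (and therefore converges slowly) gets a larger weight.
    # Assumes every block trace is strictly positive.
    return {name: total / trace for name, trace in traces.items()}


# Toy example: two boundary points and one residual point, 2 parameters.
weights = ntk_trace_weights({
    "boundary": [[1.0, 0.0], [0.0, 1.0]],  # trace 2
    "residual": [[3.0, 4.0]],              # trace 25
})
print(weights)  # boundary: 27 / 2 = 13.5, residual: 27 / 25 = 1.08
```

Because the weight is inversely proportional to a term's own trace, the rule pushes the optimizer toward the terms whose NTK eigenvalues are small, which the paper identifies as the slow-converging components.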
Initialization of the NeuralTangentKernelWeighting class.
- Parameters:
model (torch.nn.Module) – The neural network model.
alpha (float) – The alpha parameter, a value between 0 and 1 (inclusive).
- Raises:
ValueError – If alpha is not between 0 and 1 (inclusive).
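The documented range check on alpha is easy to mirror. How alpha is used afterwards is not stated here, so the sketch below assumes, purely for illustration, that it acts as a moving-average factor blending previously stored weights with newly computed ones. Both `check_alpha` and `smooth_weights` are hypothetical helpers, not part of the class's API.

```python
def check_alpha(alpha):
    """Validate alpha against the documented range [0, 1]."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be between 0 and 1 (inclusive)")
    return alpha


def smooth_weights(previous, current, alpha, default=1.0):
    """Blend stored weights with freshly computed ones.

    ASSUMPTION: alpha acts as a moving-average factor. alpha = 1 keeps the
    previous weights unchanged; alpha = 0 uses only the new values. Terms
    seen for the first time start from `default`.
    """
    alpha = check_alpha(alpha)
    return {
        name: alpha * previous.get(name, default) + (1.0 - alpha) * value
        for name, value in current.items()
    }


step1 = smooth_weights({}, {"boundary": 3.0}, alpha=0.5)      # 0.5*1 + 0.5*3
step2 = smooth_weights(step1, {"boundary": 4.0}, alpha=0.25)  # 0.25*2 + 0.75*4
print(step1, step2)  # {'boundary': 2.0} {'boundary': 3.5}
```

Under this reading, values of alpha close to 1 give smoother but slower-reacting weights, which can damp noise in per-step weight estimates.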
- aggregate(losses)
Weight the losses according to the Neural Tangent Kernel algorithm.
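Whatever rule produces the weights, the final step of an aggregation like this reduces the dictionary of losses to a single scalar through a weighted sum. A minimal sketch, with `weighted_sum` as a hypothetical helper (not this class's method) and plain floats standing in for scalar tensors:

```python
def weighted_sum(losses, weights, default=1.0):
    """Combine per-term losses into a single scalar.

    `losses` and `weights` map term names to numbers. Plain floats stand
    in for scalar tensors here; the same expression also works with
    torch.Tensor values, since `*` and the built-in `sum` apply to them.
    """
    # Terms without a stored weight fall back to `default` (an assumption).
    return sum(weights.get(name, default) * loss for name, loss in losses.items())


total = weighted_sum({"boundary": 2.0, "residual": 0.5}, {"boundary": 3.0})
print(total)  # 3.0*2.0 + 1.0*0.5 = 6.5
```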
- Parameters:
losses (dict(torch.Tensor)) – The dictionary of losses.
- Returns:
The aggregated loss, as a scalar Tensor.
- Return type: