CausalPINN

class CausalPINN(problem, model, optimizer=None, scheduler=None, weighting=None, loss=None, eps=100)

Bases: PINN

Causal Physics-Informed Neural Network (CausalPINN) solver class. It uses a user-specified model to solve a given problem and can be applied to both forward and inverse problems.

The Causal Physics-Informed Neural Network solver aims to find the solution \(\mathbf{u}:\Omega\rightarrow\mathbb{R}^m\) of a differential problem:

\[\begin{split}\begin{cases} \mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\ \mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad, \mathbf{x}\in\partial\Omega \end{cases}\end{split}\]

minimizing the loss function:

\[\mathcal{L}_{\rm{problem}} = \frac{1}{N_t}\sum_{i=1}^{N_t} \omega_{i}\mathcal{L}_r(t_i),\]

where:

\[\mathcal{L}_r(t) = \frac{1}{N}\sum_{i=1}^N \mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i, t)) + \frac{1}{N}\sum_{i=1}^N \mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i, t))\]

and,

\[\omega_i = \exp\left(-\epsilon \sum_{k=1}^{i-1}\mathcal{L}_r(t_k)\right).\]

\(\epsilon\) is a hyperparameter, set by default to \(100\), while \(\mathcal{L}\) is a specific loss function, typically the MSE:

\[\mathcal{L}(v) = \| v \|^2_2.\]
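The weighting can be illustrated with a minimal pure-Python sketch. It assumes the residual losses \(\mathcal{L}_r(t_1), \dots, \mathcal{L}_r(t_{N_t})\) have already been computed and are ordered by increasing time, and it follows the reference paper in letting the weights decay with the accumulated loss, \(\omega_i = \exp(-\epsilon \sum_{k=1}^{i-1}\mathcal{L}_r(t_k))\). The functions `causal_weights` and `causal_loss` are hypothetical helpers, not part of the library.

```python
import math
from typing import List


def causal_weights(residual_losses: List[float], eps: float = 100.0) -> List[float]:
    """Hypothetical helper: w_i = exp(-eps * sum_{k<i} L_r(t_k))."""
    weights = []
    cumulative = 0.0  # running sum of L_r(t_k) over all k < i
    for loss in residual_losses:
        weights.append(math.exp(-eps * cumulative))
        cumulative += loss
    return weights


def causal_loss(residual_losses: List[float], eps: float = 100.0) -> float:
    """Hypothetical helper: (1 / N_t) * sum_i w_i * L_r(t_i)."""
    weights = causal_weights(residual_losses, eps)
    total = sum(w * loss for w, loss in zip(weights, residual_losses))
    return total / len(residual_losses)


# Illustrative residual losses at four increasing time instants (made-up values).
losses = [0.01, 0.02, 0.05, 0.10]
print(causal_weights(losses, eps=100.0))
print(causal_loss(losses, eps=100.0))
```

With these values the weights are approximately 1, \(e^{-1} \approx 0.37\), \(e^{-3} \approx 0.05\) and \(e^{-8} \approx 3.4\times 10^{-4}\): later time instants contribute little until the residuals at earlier times are reduced. Setting `eps=0` recovers the unweighted mean of the residual losses.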

See also

Original reference: Wang, Sifan, Shyam Sankaran, and Paris Perdikaris. Respecting causality for training physics-informed neural networks. Computer Methods in Applied Mechanics and Engineering 421 (2024):116813. DOI: 10.1016.

Note

This class is only compatible with problems that inherit from the TimeDependentProblem class.

Initialization of the CausalPINN class.

Parameters:
Raises:

ValueError – If the problem is not a TimeDependentProblem.
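The compatibility check described above amounts to a type test on the problem. The sketch below uses hypothetical stand-in classes (the `TimeDependentProblem` defined here only imitates the library class of the same name) and a hypothetical `check_time_dependent` helper; it is not the library's own validation code.

```python
class TimeDependentProblem:
    """Hypothetical stand-in for the library's time-dependent problem base class."""


class SteadyHeatProblem:
    """Hypothetical problem without a time variable."""


class UnsteadyHeatProblem(TimeDependentProblem):
    """Hypothetical problem that inherits from the time-dependent base class."""


def check_time_dependent(problem: object) -> None:
    # Hypothetical check: reject any problem not derived from TimeDependentProblem.
    if not isinstance(problem, TimeDependentProblem):
        raise ValueError(f"{type(problem).__name__} is not a TimeDependentProblem.")


check_time_dependent(UnsteadyHeatProblem())  # accepted, no exception
try:
    check_time_dependent(SteadyHeatProblem())
except ValueError as err:
    print(err)  # SteadyHeatProblem is not a TimeDependentProblem.
```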

loss_phys(samples, equation)

Computes the physics loss for the physics-informed solver based on the provided samples and equation.

Parameters:
Returns:

The computed physics loss.

Return type:

LabelTensor
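The time grouping behind \(\mathcal{L}_r(t_i)\) can be sketched in the same hedged way. The sketch assumes each sample is a `(t, residual)` pair, where `residual` is the pointwise value of a single equation's residual (mirroring the `(samples, equation)` signature of `loss_phys`), and uses the squared residual as \(\mathcal{L}\). The function `causal_physics_loss` is hypothetical and is not the library's `loss_phys` implementation, which returns a `LabelTensor`.

```python
import math
from collections import defaultdict
from typing import Dict, List, Tuple


def causal_physics_loss(samples: List[Tuple[float, float]], eps: float = 100.0) -> float:
    """Hypothetical sketch; samples are (t, residual) pairs for one equation."""
    # Group squared residuals by their time coordinate.
    by_time: Dict[float, List[float]] = defaultdict(list)
    for t, residual in samples:
        by_time[t].append(residual ** 2)

    # L_r(t_i): mean squared residual at each time instant, in increasing time.
    times = sorted(by_time)
    time_losses = [sum(by_time[t]) / len(by_time[t]) for t in times]

    # Weighted mean with causal weights w_i = exp(-eps * sum_{k<i} L_r(t_k)).
    weighted_sum = 0.0
    cumulative = 0.0
    for loss in time_losses:
        weighted_sum += math.exp(-eps * cumulative) * loss
        cumulative += loss
    return weighted_sum / len(time_losses)


# Made-up residuals at two collocation points for each of three time instants.
samples = [(0.0, 0.1), (0.0, -0.1), (0.5, 0.2), (0.5, 0.0), (1.0, 0.3), (1.0, 0.3)]
print(causal_physics_loss(samples, eps=10.0))
```

Because samples are grouped by their exact time value, this sketch suits collocation points that share a discrete set of time instants; continuously sampled times would first need to be binned.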

property eps

The exponential decay parameter.

Returns:

The exponential decay parameter.

Return type:

float