Source code for pina._src.solver.causal_physics_informed_single_model_solver
"""Module for the causal physics-informed single-model solver class."""
import torch
from pina._src.condition.input_equation_condition import InputEquationCondition
from pina._src.solver.mixin.physics_informed_mixin import PhysicsInformedMixin
from pina._src.condition.input_target_condition import InputTargetCondition
from pina._src.core.utils import check_consistency, check_positive_integer
from pina._src.problem.time_dependent_problem import TimeDependentProblem
from pina._src.solver.single_model_solver import SingleModelSolver
from pina._src.core.label_tensor import LabelTensor
from pina._src.condition.domain_equation_condition import (
DomainEquationCondition,
)
[docs]
class CausalPhysicsInformedSingleModelSolver(
PhysicsInformedMixin, SingleModelSolver
):
r"""
Single-model solver for causal physics-informed learning problems.
This solver approximates the solution of a time-dependent differential
problem using a single model and a causality-aware training objective. It is
intended for problems whose conditions include equation residuals and
boundary residuals evaluated across ordered time snapshots. It can be used
only for forward problems, due to the causal nature of the training
objective.
Given a model :math:`\mathcal{M}`, the predicted solution is
.. math::
\hat{\mathbf{u}}(\mathbf{x}, t) = \mathcal{M}(\mathbf{x}, t).
The solver minimizes a causal residual loss that weights each time snapshot
according to the residuals accumulated at previous times. For a time
dependent problem with governing equation operator :math:`\mathcal{A}` in
the domain :math:`\Omega` and boundary operator :math:`\mathcal{B}` on the
boundary :math:`\partial\Omega`, the objective can be written as
.. math::
\mathcal{L}_{\mathrm{problem}} = \frac{1}{N_t} \sum_{i=1}^{N_t}
\omega_i \mathcal{L}_r(t_i),
where the residual loss at time :math:`t` is
.. math::
\mathcal{L}_r(t) = \frac{1}{N_{\Omega}} \sum_{j=1}^{N_{\Omega}}
\mathcal{L}\left( \mathcal{A}[\hat{\mathbf{u}}](\mathbf{x}_j, t) \right)
+ \frac{1}{N_{\partial\Omega}} \sum_{j=1}^{N_{\partial\Omega}}
\mathcal{L} \left( \mathcal{B}[\hat{\mathbf{u}}](\mathbf{x}_j, t)
\right).
The causal weights are defined as
.. math::
\omega_i = \exp \left( -\epsilon \sum_{k=1}^{i-1} \mathcal{L}_r(t_k)
\right),
where :math:`\epsilon` is a hyperparameter controlling the strength of the
causal weighting, and :math:`\mathcal{L}` is the selected loss function,
typically the mean squared error.
.. seealso::
**Original reference**: Wang, S., Sankaran, S., & Perdikaris, P. (2024).
*Respecting causality for training physics-informed neural networks.*
Computer Methods in Applied Mechanics and Engineering, 421, 116813.
DOI: `10.1016/j.cma.2024.116813
<https://doi.org/10.1016/j.cma.2024.116813>`_.
.. note::
This solver is compatible only with problems inheriting from
:class:`~pina.problem.time_dependent_problem.TimeDependentProblem`.
"""
# Accepted conditions types for this solver
accepted_conditions_types = (
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition,
)
def __init__(
self,
problem,
model,
optimizer=None,
scheduler=None,
weighting=None,
loss=None,
eps=100,
n_steps=10,
regularized_conditions=None,
):
"""
Initialization of the :class:`CausalPhysicsInformedSingleModelSolver`
class.
:param BaseProblem problem: The problem to be solved.
:param torch.nn.Module model: The model used by the solver.
:param TorchOptimizer optimizer: The optimizer used by the solver.
If ``None``, the ``torch.optim.Adam`` optimizer with a learning rate
of ``0.001`` is used. Default is ``None``.
:param TorchScheduler scheduler: The scheduler used by the solver.
If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
with a factor of ``1.0`` is used. Default is ``None``.
:param BaseWeighting weighting: The weighting strategy used to combine
condition losses. If ``None``, no weighting is applied. Default is
``None``.
:param loss: The loss function used to compute residual losses.
If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
:param eps: The exponential decay parameter. Default is ``100``.
:type eps: float | int
:param int n_steps: The number of equispaced temporal steps used to
compute the causal loss. Default is ``10``.
:param regularized_conditions: The names of the conditions that should
receive causal regularization. Default is ``None``.
:raises ValueError: If the problem is not time-dependent.
:raises ValueError: If the user does not specify any regularized
conditions.
:raises ValueError: If any of the specified ``regularized_conditions``
are not present in the ``problem``'s conditions.
:raises ValueError: If ``eps`` is not a float or int.
:raises ValueError: If ``n_steps`` is not a positive integer.
"""
# Ensure the problem is time-dependent
if not isinstance(problem, TimeDependentProblem):
raise ValueError(
"Causal physics-informed solvers require the problem to be an "
f"instance of TimeDependentProblem. Got {type(problem)}."
)
# Ensure the user specified valid regularized conditions
if regularized_conditions is None:
raise ValueError(
"Causal physics-informed solvers require the user to specify "
"the conditions that should receive causal regularization. "
"Apply causal regularization only to time-dependent conditions."
)
# Check consistency
check_consistency(eps, (int, float))
check_consistency(regularized_conditions, str)
check_positive_integer(n_steps, strict=True)
# Map conditions to list if a single condition is provided
if not isinstance(regularized_conditions, (list, tuple)):
regularized_conditions = [regularized_conditions]
# Ensure that all regularized conditions are present in the problem
problem_conditions = set(problem.conditions.keys())
for condition in regularized_conditions:
if condition not in problem_conditions:
raise ValueError(
f"Condition '{condition}' is not present in the problem."
)
# Initialize the parent class
SingleModelSolver.__init__(
self,
problem=problem,
model=model,
optimizer=optimizer,
scheduler=scheduler,
weighting=weighting,
loss=loss,
use_lt=True,
)
# Initialize the causal regularization parameters
self.eps = eps
self.n_steps = n_steps
self.regularized_conditions = regularized_conditions
def _compute_condition_loss(self, condition, data, batch_idx):
"""
Compute the scalar loss for a given condition and its data.
:param BaseCondition condition: The condition for which to compute the
loss.
:param dict data: The data corresponding to the condition.
:param int batch_idx: The index of the current batch.
:return: The scalar loss for the condition.
:rtype: torch.Tensor
"""
# If the condition is not regularized, or is a supervised (target)
# condition, use the standard loss computation
if condition.name not in self.regularized_conditions or isinstance(
condition, InputTargetCondition
):
return super()._compute_condition_loss(condition, data, batch_idx)
# Clone the input tensor if it exists to avoid in-place modifications
if "input" in data and hasattr(data["input"], "clone"):
data = dict(data)
data["input"] = data["input"].clone()
# Prepare condition data, e.g. by enabling gradient for regularizations
data = self._prepare_condition_data(data=data)
# Extract the temporal domain
time_domain = self.problem.temporal_domain
# Define the time steps for the causal loss computation
if time_domain.range:
time_steps = torch.linspace(
time_domain.range[self.temporal_variable][0],
time_domain.range[self.temporal_variable][1],
self.n_steps,
device=data["input"].device,
dtype=data["input"].dtype,
)
# If no range is defined, use the unique temporal value
else:
time_steps = torch.tensor(
[time_domain.fixed[self.temporal_variable]],
device=data["input"].device,
dtype=data["input"].dtype,
)
# Initialize the list to store the loss for each time step
time_loss = []
# Iterate over the time steps
for step in time_steps:
# Append the temporal variable to the spatial input tensor
spatial_pts = data["input"].extract(self.spatial_variables)
time_pts = LabelTensor(
torch.ones(spatial_pts.shape[0], 1, device=spatial_pts.device)
* step,
labels=[self.temporal_variable],
)
pts = {
"input": LabelTensor.cat(
[spatial_pts, time_pts], dim=1
).requires_grad_(True)
}
# Compute and store the residual tensor for the condition
self.residual_tensor = condition.evaluate(pts, self)
# Retrieve condition name for more complex weighting schemes
condition_name = (
condition.name if hasattr(condition, "name") else None
)
# Compute the tensor loss from the residual tensor
condition_tensor_loss = self._loss_from_residual(condition_name)
# Optional regularization hook
condition_tensor_loss = self._regularize_condition_loss(
condition_tensor_loss=condition_tensor_loss,
condition_name=condition_name,
data=data,
batch_idx=batch_idx,
)
# Append the loss for the current time step to the list
time_loss.append(condition_tensor_loss)
# Compute the time-adaptive weights for the causal loss
time_loss = torch.stack(time_loss)
with torch.no_grad():
weights = self._compute_weights(time_loss)
# Compute the scalar loss from the tensor loss and return it
condition_scalar_loss = self._apply_reduction(weights * time_loss)
return condition_scalar_loss
def _compute_weights(self, loss):
"""
Compute the temporal adaptive weights for the causal loss based on the
cumulative loss.
:param LabelTensor loss: The physics loss values.
:return: The computed weights for the physics loss.
:rtype: LabelTensor
"""
# Compute the cumulative loss and apply exponential decay
cumulative_loss = torch.cumsum(loss, dim=0)
return torch.exp(-self.eps * cumulative_loss)
@property
def temporal_variable(self):
"""
The temporal variable of the problem.
:return: The temporal variable name.
:rtype: str
:raises ValueError: If the problem does not have exactly one temporal
variable.
"""
# Extract the temporal variable from the problem
temporal_variables = self.problem.temporal_variables
# Raise error if there is not exactly one temporal variable
if len(temporal_variables) != 1:
raise ValueError(
"Causal physics-informed solvers require exactly one temporal "
f"variable. Got {temporal_variables}."
)
return temporal_variables[0]
@property
def spatial_variables(self):
"""
The spatial variables of the problem.
:return: The spatial variable names.
:rtype: list[str]
:raises ValueError: If the problem does not have any spatial variables.
"""
# Determine the spatial variables by excluding the temporal variable
spatial_variables = [
v
for v in self.problem.input_variables
if v != self.temporal_variable
]
# Raise error if there are no spatial variables left
if not spatial_variables:
raise ValueError(
"Causal physics-informed solvers require at least one spatial "
"variable in addition to time."
)
return spatial_variables