Source code for pina._src.solver.self_adaptive_physics_informed_solver
"""Module for the self-adaptive physics-informed multi-model solver."""
import torch
from pina._src.solver.mixin.physics_informed_mixin import PhysicsInformedMixin
from pina._src.condition.input_equation_condition import InputEquationCondition
from pina._src.condition.input_target_condition import InputTargetCondition
from pina._src.solver.multi_model_solver import MultiModelSolver
from pina._src.core.utils import check_consistency
from pina._src.condition.domain_equation_condition import (
DomainEquationCondition,
)
[docs]
class SelfAdaptivePhysicsInformedSolver(PhysicsInformedMixin, MultiModelSolver):
r"""
Multi-model solver for self-adaptive physics-informed learning problems.
This solver approximates the solution of a differential problem using a
trainable model together with condition-wise self-adaptive weights. It is
intended for problems whose conditions may include supervised data, equation
residuals evaluated on input points, and equation residuals sampled from
domains.
Given a model :math:`\mathcal{M}`, the predicted solution is
.. math::
\hat{\mathbf{u}}(\mathbf{x}) = \mathcal{M}(\mathbf{x}).
For each condition, the solver introduces trainable pointwise weights. These
weights are passed through a user-defined weight function :math:`m`,
typically chosen to keep the effective weights bounded or positive. The
resulting weighted objective encourages the model to focus on regions where
the residual is larger.
For a problem with governing equation operator :math:`\mathcal{A}` in the
domain :math:`\Omega` and boundary operator :math:`\mathcal{B}` on the
boundary :math:`\partial\Omega`, the objective can be written as
.. math::
\mathcal{L}_{\mathrm{problem}} = \frac{1}{N_{\Omega}}
\sum_{i=1}^{N_{\Omega}} m(\lambda_{\Omega}^{i}) \mathcal{L}
\left( \mathcal{A}[\hat{\mathbf{u}}](\mathbf{x}_i) \right)
+ \frac{1}{N_{\partial\Omega}} \sum_{i=1}^{N_{\partial\Omega}}
m(\lambda_{\partial\Omega}^{i})
\mathcal{L} \left( \mathcal{B}[\hat{\mathbf{u}}](\mathbf{x}_i) \right),
where :math:`\lambda_{\Omega}^{i}` and :math:`\lambda_{\partial\Omega}^{i}`
are the self-adaptive weights associated with points in :math:`\Omega` and
:math:`\partial\Omega`, respectively, and :math:`\mathcal{L}` is the
selected loss function, typically the mean squared error.
The model parameters and the self-adaptive weights are optimized through a
min-max problem:
.. math::
\min_{\theta} \max_{\lambda} \mathcal{L}_{\mathrm{problem}},
where :math:`\theta` denotes the model parameters and :math:`\lambda`
denotes the collection of self-adaptive weights.
.. seealso::
**Original reference**: McClenny, L. D., & Braga-Neto, U. M. (2023).
*Self-adaptive physics-informed neural networks.*
Journal of Computational Physics, 474, 111722.
DOI: `10.1016/j.jcp.2022.111722
<https://doi.org/10.1016/j.jcp.2022.111722>`_.
"""
# Accepted conditions types for this solver
accepted_conditions_types = (
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition,
)
def __init__(
self,
problem,
model,
weight_function=torch.nn.Sigmoid(),
optimizer_model=None,
optimizer_weights=None,
scheduler_model=None,
scheduler_weights=None,
weighting=None,
loss=None,
):
"""
Initialization of the :class:`SelfAdaptivePhysicsInformedSolver` class.
:param BaseProblem problem: The problem to be solved.
:param torch.nn.Module model: The model used by the solver.
:param torch.nn.Module weight_function: The weight function used to
compute self-adaptive weights. Default is ``torch.nn.Sigmoid()``.
:param TorchOptimizer optimizer_model: The optimizer of the main model.
If ``None``, the ``torch.optim.Adam`` optimizer with a learning rate
of ``0.001`` is used. Default is ``None``.
:param TorchOptimizer optimizer_weights: The optimizer of the
self-adaptive weights. If ``None``, the ``torch.optim.Adam``
optimizer with a learning rate of ``0.001`` is used.
Default is ``None``.
:param TorchScheduler scheduler_model: The scheduler of the main model.
If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
with a factor of ``1.0`` is used. Default is ``None``.
:param TorchScheduler scheduler_weights: The scheduler of the
self-adaptive weights. If ``None``, the
``torch.optim.lr_scheduler.ConstantLR`` scheduler with a factor of
``1.0`` is used. Default is ``None``.
:param BaseWeighting weighting: The weighting strategy used to combine
condition losses. If ``None``, no weighting is applied. Default is
``None``.
:param loss: The loss function used to compute residual losses.
If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
:raises ValueError: If ``weight_function`` is not a ``torch.nn.Module``.
:raises ValueError: If not all domains have been discretised.
"""
# Check consistency
check_consistency(weight_function, torch.nn.Module)
# Check that all domains have been discretised
if not problem.are_all_domains_discretised:
raise ValueError(
"All domains must be discretised before initializing the "
"solver."
)
# Compute the number of points for each condition
num_points = {
cond: (
problem._discretised_domains[cond].shape[0]
if isinstance(problem.conditions[cond], DomainEquationCondition)
else problem.conditions[cond].data.input.shape[0]
)
for cond in problem.conditions
}
# Initialize weights container and per-condition parameters
weights = torch.nn.Module()
# Attach the weight function as a submodule
weights.func = weight_function
# Register a torch.nn.Parameter for each condition to store the weights
for cond in problem.conditions:
p = torch.nn.Parameter(torch.zeros(num_points[cond], 1))
setattr(weights, cond, p)
# Prepare optimizers
optimizers = (
[optimizer_model, optimizer_weights]
if any(o is not None for o in (optimizer_model, optimizer_weights))
else None
)
# Prepare schedulers
schedulers = (
[scheduler_model, scheduler_weights]
if any(s is not None for s in (scheduler_model, scheduler_weights))
else None
)
# Initialize the base solver
MultiModelSolver.__init__(
self,
problem=problem,
models=[model, weights],
optimizers=optimizers,
schedulers=schedulers,
weighting=weighting,
loss=loss,
use_lt=True,
)
[docs]
def training_step(self, batch, batch_idx):
"""
Solver training step.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param int batch_idx: The index of the current batch.
:return: The loss of the training step.
:rtype: torch.Tensor
"""
# Zero the gradients of weights optimizer and compute the loss
self.optimizer_weights.instance.zero_grad()
loss = self.batch_evaluation_step(batch, batch_idx)
# Perform the backward pass and complete a step for the weights
self.manual_backward(-loss)
self.optimizer_weights.instance.step()
self.scheduler_weights.instance.step()
# Zero the gradients of model optimizer and compute the loss again
self.optimizer_model.instance.zero_grad()
loss = self.batch_evaluation_step(batch, batch_idx)
# Perform the backward pass and complete a step for the model
self.manual_backward(loss)
self.optimizer_model.instance.step()
self.scheduler_model.instance.step()
# Log the training loss
self.log(
name="train_loss",
value=loss.item(),
batch_size=self.get_batch_size(batch),
**self.trainer.logging_kwargs,
)
return loss
[docs]
def forward(self, x):
"""
Forward pass through the model.
:param x: The input data.
:type x: torch.Tensor | LabelTensor | Data | Graph
:return: The output of the model.
:rtype: torch.Tensor | LabelTensor | Data | Graph
"""
return self.model(x)
def _compute_condition_loss(self, condition, data, batch_idx):
"""
Compute the scalar loss for a given condition and its data.
:param BaseCondition condition: The condition for which to compute the
loss.
:param dict data: The data corresponding to the condition.
:param int batch_idx: The index of the current batch.
:return: The scalar loss for the condition.
:rtype: torch.Tensor
"""
# Clone the input tensor if it exists to avoid in-place modifications
if "input" in data and hasattr(data["input"], "clone"):
data = dict(data)
data["input"] = data["input"].clone()
# Prepare condition data, e.g. by enabling gradient for regularizations
data = self._prepare_condition_data(data=data)
# Compute and store the residual tensor for the condition
self.residual_tensor = condition.evaluate(data, self)
# Retrieve condition name for more complex weighting schemes
condition_name = condition.name
# Apply the activation function to the condition-specific weights
weight_param = getattr(self.weights, condition_name)
weight_tensor = self.weights.func(weight_param)
# Compute the tensor loss from the residual tensor
condition_tensor_loss = self._loss_from_residual(condition_name)
# Optional regularization hook, e.g gradient-enhanced or residual-based
condition_tensor_loss = self._regularize_condition_loss(
condition_tensor_loss=condition_tensor_loss,
condition_name=condition_name,
data=data,
batch_idx=batch_idx,
)
# Get the correct indices to retrieve the weights for the current batch
len_residuals = self.residual_tensor.shape[0]
# Get the total number of points, together with the start / end indices
total_points = weight_param.shape[0]
start = (batch_idx * len_residuals) % total_points
end = start + len_residuals
# Retrieve the weights for the current batch using modular indexing
idx = torch.arange(start, end, device=self.residual_tensor.device)
idx = idx % total_points
# Compute the scalar loss from the tensor loss and return it
condition_scalar_loss = self._apply_reduction(
condition_tensor_loss * weight_tensor[idx]
)
return condition_scalar_loss
@property
def model(self):
"""
The single model used by the solver.
:return: The single model used by the solver.
:rtype: torch.nn.Module
"""
return self._pina_models[0]
@property
def weights(self):
"""
The self-adaptive weights used by the solver.
:return: The self-adaptive weights used by the solver.
:rtype: torch.nn.Module
"""
return self._pina_models[1]
@property
def optimizer_model(self):
"""
The optimizer for the model used by the solver.
:return: The optimizer for the model used by the solver.
:rtype: TorchOptimizer
"""
return self.optimizers[0]
@property
def optimizer_weights(self):
"""
The optimizer for the weights used by the solver.
:return: The optimizer for the weights used by the solver.
:rtype: TorchOptimizer
"""
return self.optimizers[1]
@property
def scheduler_model(self):
"""
The scheduler for the model used by the solver.
:return: The scheduler for the model used by the solver.
:rtype: TorchScheduler
"""
return self.schedulers[0]
@property
def scheduler_weights(self):
"""
The scheduler for the weights used by the solver.
:return: The scheduler for the weights used by the solver.
:rtype: TorchScheduler
"""
return self.schedulers[1]