Source code for pina._src.solver.competitive_physics_informed_solver
"""Module for the competitive physics-informed multi-model solver."""
import copy
from pina._src.solver.mixin.physics_informed_mixin import PhysicsInformedMixin
from pina._src.condition.input_equation_condition import InputEquationCondition
from pina._src.condition.input_target_condition import InputTargetCondition
from pina._src.solver.multi_model_solver import MultiModelSolver
from pina._src.condition.domain_equation_condition import (
DomainEquationCondition,
)
[docs]
class CompetitivePhysicsInformedSolver(PhysicsInformedMixin, MultiModelSolver):
    r"""
    Multi-model solver for competitive physics-informed learning problems.

    This solver approximates the solution of a differential problem using a
    trainable model together with a discriminator network. It is intended for
    problems whose conditions may include supervised data, equation residuals
    evaluated on input points, and equation residuals sampled from domains.

    Given a model :math:`\mathcal{M}`, the predicted solution is

    .. math::

        \hat{\mathbf{u}}(\mathbf{x}) = \mathcal{M}(\mathbf{x}).

    The discriminator :math:`D` assigns pointwise weights to the residuals,
    encouraging the model to focus on regions where the approximation performs
    poorly. The model parameters are optimized by minimizing the loss, while
    the discriminator parameters are optimized by maximizing it.

    For a problem with governing equation operator :math:`\mathcal{A}` in the
    domain :math:`\Omega` and boundary operator :math:`\mathcal{B}` on the
    boundary :math:`\partial\Omega`, the competitive objective can be written
    as

    .. math::

        \mathcal{L}_{\mathrm{problem}} = \frac{1}{N_{\Omega}}
        \sum_{i=1}^{N_{\Omega}} \mathcal{L}
        \left(D(\mathbf{x}_i)\mathcal{A}[\hat{\mathbf{u}}](\mathbf{x}_i)\right)
        + \frac{1}{N_{\partial\Omega}} \sum_{i=1}^{N_{\partial\Omega}}
        \mathcal{L}
        \left(D(\mathbf{x}_i)\mathcal{B}[\hat{\mathbf{u}}](\mathbf{x}_i)\right),

    where :math:`D` is the discriminator network and :math:`\mathcal{L}` is
    the selected loss function, typically the mean squared error.

    The model and discriminator are trained through a min-max problem:

    .. math::

        \min_{\theta} \max_{\phi} \mathcal{L}_{\mathrm{problem}},

    where :math:`\theta` denotes the model parameters and :math:`\phi` denotes
    the discriminator parameters.

    .. seealso::

        **Original reference**: Zeng, Q., Kothari, Y., Bryngelson, S. H., &
        Schäfer, F. (2023).
        *Competitive physics informed networks.*
        International Conference on Learning Representations, ICLR 2023.
        `OpenReview Preprint <https://openreview.net/forum?id=z9SIj-IM7tn>`_.
    """

    # Accepted conditions types for this solver
    accepted_conditions_types = (
        InputTargetCondition,
        InputEquationCondition,
        DomainEquationCondition,
    )

    def __init__(
        self,
        problem,
        model,
        discriminator=None,
        optimizer_model=None,
        optimizer_discriminator=None,
        scheduler_model=None,
        scheduler_discriminator=None,
        weighting=None,
        loss=None,
    ):
        """
        Initialization of the :class:`CompetitivePhysicsInformedSolver` class.

        :param BaseProblem problem: The problem to be solved.
        :param torch.nn.Module model: The model used by the solver.
        :param torch.nn.Module discriminator: The discriminator used by the
            solver. If ``None``, a deep copy of ``model`` is used as
            discriminator. Default is ``None``.
        :param TorchOptimizer optimizer_model: The optimizer of the main model.
            If ``None``, the ``torch.optim.Adam`` optimizer with a learning rate
            of ``0.001`` is used. Default is ``None``.
        :param TorchOptimizer optimizer_discriminator: The optimizer of the
            discriminator. If ``None``, the ``torch.optim.Adam`` optimizer with
            a learning rate of ``0.001`` is used. Default is ``None``.
        :param TorchScheduler scheduler_model: The scheduler of the main model.
            If ``None``, the ``torch.optim.lr_scheduler.ConstantLR`` scheduler
            with a factor of ``1.0`` is used. Default is ``None``.
        :param TorchScheduler scheduler_discriminator: The scheduler of the
            discriminator. If ``None``, the
            ``torch.optim.lr_scheduler.ConstantLR`` scheduler with a factor of
            ``1.0`` is used. Default is ``None``.
        :param BaseWeighting weighting: The weighting strategy used to combine
            condition losses. If ``None``, no weighting is applied. Default is
            ``None``.
        :param loss: The loss function used to compute residual losses.
            If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
        :raises ValueError: If not all domains have been discretised.
        """
        # Initialize the discriminator if not provided. A deep copy has the
        # same architecture (and initial weights) as the model, but its
        # parameters are independent objects.
        if discriminator is None:
            discriminator = copy.deepcopy(model)

        # Prepare optimizers: when neither is given, pass ``None`` so that the
        # base solver handles the defaults for both models at once.
        optimizers = [optimizer_model, optimizer_discriminator]
        if all(optimizer is None for optimizer in optimizers):
            optimizers = None

        # Prepare schedulers, following the same convention as the optimizers
        schedulers = [scheduler_model, scheduler_discriminator]
        if all(scheduler is None for scheduler in schedulers):
            schedulers = None

        # Initialize the base solver; index 0 is the model, index 1 is the
        # discriminator (see the ``model`` / ``discriminator`` properties).
        MultiModelSolver.__init__(
            self,
            problem=problem,
            models=[model, discriminator],
            optimizers=optimizers,
            schedulers=schedulers,
            weighting=weighting,
            loss=loss,
            use_lt=True,
        )

    def training_step(self, batch, batch_idx):
        """
        Solver training step.

        The model is first updated to minimize the loss. The loss is then
        re-evaluated, and the discriminator is updated to maximize it by
        back-propagating the negated loss.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
            tuple containing a condition name and a dictionary of points.
        :param int batch_idx: The index of the current batch.
        :return: The loss evaluated after the model update, i.e. the loss used
            for the discriminator step.
        :rtype: torch.Tensor
        """
        # Zero the gradients of the model optimizer and compute the loss
        self.optimizer_model.instance.zero_grad()
        loss = self.batch_evaluation_step(batch, batch_idx)

        # Perform the backward pass and complete a step for the model. This
        # backward also writes gradients into the discriminator parameters
        # (the discriminator is part of the loss graph). Those gradients are
        # cleared by the discriminator ``zero_grad`` below.
        self.manual_backward(loss)
        self.optimizer_model.instance.step()
        self.scheduler_model.instance.step()

        # Zero the gradients of the discriminator optimizer and recompute the
        # loss with the freshly updated model
        self.optimizer_discriminator.instance.zero_grad()
        loss = self.batch_evaluation_step(batch, batch_idx)

        # Perform the backward pass on the negated loss (gradient ascent) and
        # complete a step for the discriminator
        self.manual_backward(-loss)
        self.optimizer_discriminator.instance.step()
        self.scheduler_discriminator.instance.step()

        # Log the training loss
        self.log(
            name="train_loss",
            value=loss.item(),
            batch_size=self.get_batch_size(batch),
            **self.trainer.logging_kwargs,
        )

        return loss

    def forward(self, x):
        """
        Forward pass through the model.

        :param x: The input data.
        :type x: torch.Tensor | LabelTensor | Data | Graph
        :return: The output of the model.
        :rtype: torch.Tensor | LabelTensor | Data | Graph
        """
        return self.model(x)

    def _compute_condition_loss(self, condition, data, batch_idx):
        """
        Compute the scalar loss for a given condition and its data.

        The residual returned by ``condition.evaluate`` is multiplied pointwise
        by the discriminator output, evaluated on the condition input points,
        before it is turned into a loss and reduced to a scalar.

        :param BaseCondition condition: The condition for which to compute the
            loss.
        :param dict data: The data corresponding to the condition.
        :param int batch_idx: The index of the current batch.
        :return: The scalar loss for the condition.
        :rtype: torch.Tensor
        """
        # Clone the input tensor if it exists to avoid in-place modifications.
        # The shallow ``dict`` copy leaves the caller's dictionary untouched.
        if "input" in data and hasattr(data["input"], "clone"):
            data = dict(data)
            data["input"] = data["input"].clone()

        # Prepare condition data, e.g. by enabling gradient for regularizations
        data = self._prepare_condition_data(data=data)

        # Compute and store the residual tensor for the condition
        self.residual_tensor = condition.evaluate(data, self)

        # Compute the discriminator bets on the problem input variables.
        # NOTE(review): assumes ``data["input"]`` is present and supports
        # label-based indexing (e.g. a LabelTensor) for every accepted
        # condition type. Confirm this.
        discriminator_input = data["input"][self.problem.input_variables]
        discriminator_bets = self.discriminator(discriminator_input)

        # Weight the residual tensor using the discriminator bets. The result
        # is stored on ``self`` because ``_loss_from_residual`` receives only
        # the condition name (presumably it reads ``self.residual_tensor``).
        self.residual_tensor = self.residual_tensor * discriminator_bets

        # Retrieve condition name for more complex weighting schemes
        condition_name = getattr(condition, "name", None)

        # Compute the tensor loss from the residual tensor
        condition_tensor_loss = self._loss_from_residual(condition_name)

        # Optional regularization hook, e.g gradient-enhanced or residual-based
        condition_tensor_loss = self._regularize_condition_loss(
            condition_tensor_loss=condition_tensor_loss,
            condition_name=condition_name,
            data=data,
            batch_idx=batch_idx,
        )

        # Compute the scalar loss from the tensor loss and return it
        return self._apply_reduction(condition_tensor_loss)

    @property
    def model(self):
        """
        The main model used by the solver (first of the solver models).

        :return: The main model used by the solver.
        :rtype: torch.nn.Module
        """
        return self._pina_models[0]

    @property
    def discriminator(self):
        """
        The discriminator used by the solver (second of the solver models).

        :return: The discriminator used by the solver.
        :rtype: torch.nn.Module
        """
        return self._pina_models[1]

    @property
    def optimizer_model(self):
        """
        The optimizer for the model used by the solver.

        :return: The optimizer for the model used by the solver.
        :rtype: TorchOptimizer
        """
        return self.optimizers[0]

    @property
    def optimizer_discriminator(self):
        """
        The optimizer for the discriminator used by the solver.

        :return: The optimizer for the discriminator used by the solver.
        :rtype: TorchOptimizer
        """
        return self.optimizers[1]

    @property
    def scheduler_model(self):
        """
        The scheduler for the model used by the solver.

        :return: The scheduler for the model used by the solver.
        :rtype: TorchScheduler
        """
        return self.schedulers[0]

    @property
    def scheduler_discriminator(self):
        """
        The scheduler for the discriminator used by the solver.

        :return: The scheduler for the discriminator used by the solver.
        :rtype: TorchScheduler
        """
        return self.schedulers[1]