Source code for pina.solver.physics_informed_solver.pinn_interface
"""Module for the Physics-Informed Neural Network Interface."""
from abc import ABCMeta, abstractmethod
import torch
from ..supervised_solver import SupervisedSolverInterface
from ...condition import (
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition,
)
[docs]
class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
"""
Base class for Physics-Informed Neural Network (PINN) solvers, implementing
the :class:`~pina.solver.solver.SolverInterface` class.
The `PINNInterface` class can be used to define PINNs that work with one or
multiple optimizers and/or models. By default, it is compatible with
problems defined by :class:`~pina.problem.abstract_problem.AbstractProblem`,
and users can choose the problem type the solver is meant to address.
"""
accepted_conditions_types = (
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition,
)
def __init__(self, **kwargs):
"""
Initialization of the :class:`PINNInterface` class.
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
:param kwargs: Additional keyword arguments to be passed to the
:class:`~pina.solver.supervised_solver.SupervisedSolverInterface`
class.
"""
kwargs["use_lt"] = True
super().__init__(**kwargs)
# current condition name
self.__metric = None
[docs]
def optimization_cycle(self, batch, loss_residuals=None):
"""
The optimization cycle for the PINN solver.
This method allows to call `_run_optimization_cycle` with the physics
loss as argument, thus distinguishing the training step from the
validation and test steps.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The losses computed for all conditions in the batch, casted
to a subclass of :class:`torch.Tensor`. It should return a dict
containing the condition name and the associated scalar loss.
:rtype: dict
"""
# which losses to use
if loss_residuals is None:
loss_residuals = self.loss_phys
# compute optimization cycle
condition_loss = {}
for condition_name, points in batch:
self.__metric = condition_name
# if equations are passed
if "target" not in points:
input_pts = points["input"]
condition = self.problem.conditions[condition_name]
loss = loss_residuals(
input_pts.requires_grad_(), condition.equation
)
# if data are passed
else:
input_pts = points["input"]
output_pts = points["target"]
loss = self.loss_data(
input=input_pts.requires_grad_(), target=output_pts
)
# append loss
condition_loss[condition_name] = loss
return condition_loss
[docs]
@torch.set_grad_enabled(True)
def validation_step(self, batch):
"""
The validation step for the PINN solver. It returns the average residual
computed with the ``loss`` function not aggregated.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The loss of the validation step.
:rtype: torch.Tensor
"""
return super().validation_step(
batch, loss_residuals=self._residual_loss
)
[docs]
@torch.set_grad_enabled(True)
def test_step(self, batch):
"""
The test step for the PINN solver. It returns the average residual
computed with the ``loss`` function not aggregated.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The loss of the test step.
:rtype: torch.Tensor
"""
return super().test_step(batch, loss_residuals=self._residual_loss)
[docs]
def loss_data(self, input, target):
"""
Compute the data loss for the PINN solver by evaluating the loss
between the network's output and the true solution. This method should
be overridden by the derived class.
:param LabelTensor input: The input to the neural network.
:param LabelTensor target: The target to compare with the
network's output.
:return: The supervised loss, averaged over the number of observations.
:rtype: LabelTensor
:raises NotImplementedError: If the method is not implemented.
"""
raise NotImplementedError(
"PINN is being used in a supervised learning context, but the "
"'loss_data' method has not been implemented. "
)
[docs]
@abstractmethod
def loss_phys(self, samples, equation):
"""
Computes the physics loss for the physics-informed solver based on the
provided samples and equation. This method must be overridden in
subclasses. It distinguishes different types of PINN solvers.
:param LabelTensor samples: The samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation.
:return: The computed physics loss.
:rtype: LabelTensor
"""
[docs]
def compute_residual(self, samples, equation):
"""
Compute the residuals of the equation.
:param LabelTensor samples: The samples to evaluate the loss.
:param EquationInterface equation: The governing equation.
:return: The residual of the solution of the model.
:rtype: LabelTensor
"""
try:
residual = equation.residual(samples, self.forward(samples))
except TypeError:
# this occurs when the function has three inputs (inverse problem)
residual = equation.residual(
samples, self.forward(samples), self._params
)
return residual
def _residual_loss(self, samples, equation):
"""
Computes the physics loss for the physics-informed solver based on the
provided samples and equation. This method should never be overridden
by the user, if not intentionally,
since it is used internally to compute validation loss.
:param LabelTensor samples: The samples to evaluate the loss.
:param EquationInterface equation: The governing equation.
:return: The residual loss.
:rtype: torch.Tensor
"""
residuals = self.compute_residual(samples, equation)
return self._loss_fn(residuals, torch.zeros_like(residuals))
@property
def current_condition_name(self):
"""
The current condition name.
:return: The current condition name.
:rtype: str
"""
return self.__metric