# Source code for pina._src.solver.mixin.physics_informed_mixin
"""Module for the physics-informed mixin class."""
import torch
[docs]
class PhysicsInformedMixin:
"""
Mixin that enables physics-informed training by ensuring gradients are
enabled during validation and testing, which is necessary for computing
physics residuals.
Designed to be used in combination with any solver inheriting from
:class:`~pina._src.solver.base_solver.BaseSolver`.
"""
[docs]
@torch.enable_grad()
def validation_step(self, batch, batch_idx):
"""
Solver validation step.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param int batch_idx: The index of the current batch.
:return: The loss of the training step.
:rtype: torch.Tensor
"""
return super().validation_step(batch, batch_idx)
[docs]
@torch.enable_grad()
def test_step(self, batch, batch_idx):
"""
Solver test step.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param int batch_idx: The index of the current batch.
:return: The loss of the training step.
:rtype: torch.Tensor
"""
return super().test_step(batch, batch_idx)