Source code for pina.solver.supervised_solver.supervised_solver_interface
"""Module for the Supervised solver interface."""
from abc import abstractmethod
import torch
from torch.nn.modules.loss import _Loss
from ..solver import SolverInterface
from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...condition import InputTargetCondition


class SupervisedSolverInterface(SolverInterface):
    r"""
    Base class for Supervised solvers. This class implements a Supervised
    solver, using a user-specified ``model`` to solve a specific ``problem``.

    The ``SupervisedSolverInterface`` class can be used to define Supervised
    solvers that work with one or multiple optimizers and/or models. By
    default, it is compatible with problems defined by
    :class:`~pina.problem.abstract_problem.AbstractProblem`, and users can
    choose the problem type the solver is meant to address.
    """

    accepted_conditions_types = InputTargetCondition

    def __init__(self, loss=None, **kwargs):
        """
        Initialization of the :class:`SupervisedSolverInterface` class.

        :param AbstractProblem problem: The problem to be solved.
        :param torch.nn.Module loss: The loss function to be minimized.
            If ``None``, the :class:`torch.nn.MSELoss` loss is used.
            Default is ``None``.
        :param kwargs: Additional keyword arguments to be passed to the
            :class:`~pina.solver.solver.SolverInterface` class.
        """
        if loss is None:
            loss = torch.nn.MSELoss()

        super().__init__(**kwargs)

        # check consistency
        check_consistency(loss, (LossInterface, _Loss), subclass=False)

        # assign variables
        self._loss_fn = loss
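
    # Illustrative note (not part of the original source): the consistency
    # check above accepts either a PINA ``LossInterface`` object or any
    # ``torch.nn`` loss (a subclass of ``_Loss``), so, for example,
    # ``torch.nn.L1Loss()`` would be a valid alternative to the default
    # ``torch.nn.MSELoss()``.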

    def optimization_cycle(self, batch):
        """
        The optimization cycle for the solvers.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is
            a tuple containing a condition name and a dictionary of points.
        :return: The losses computed for all conditions in the batch, cast to
            a subclass of :class:`torch.Tensor`. It should return a dict
            containing the condition name and the associated scalar loss.
        :rtype: dict
        """
        condition_loss = {}
        for condition_name, points in batch:
            condition_loss[condition_name] = self.loss_data(
                input=points["input"], target=points["target"]
            )
        return condition_loss
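
    # Illustrative note (not part of the original source): ``batch`` is a
    # list of ``(condition_name, points)`` tuples, for example
    #     [("data", {"input": input_pts, "target": target_pts})]
    # in which case the loop above returns ``{"data": <scalar loss>}``.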

    @abstractmethod
    def loss_data(self, input, target):
        """
        Compute the data loss for the Supervised solver. This method is
        abstract and must be overridden by derived classes.

        :param input: The input to the neural network.
        :type input: LabelTensor | torch.Tensor | Graph | Data
        :param target: The target to compare with the network's output.
        :type target: LabelTensor | torch.Tensor | Graph | Data
        :return: The supervised loss, averaged over the number of
            observations.
        :rtype: LabelTensor | torch.Tensor | Graph | Data
        """

    @property
    def loss(self):
        """
        The loss function to be minimized.

        :return: The loss function to be minimized.
        :rtype: torch.nn.Module
        """
        return self._loss_fn
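

# Illustrative sketch (not part of the PINA source): a minimal concrete
# solver derived from SupervisedSolverInterface. It assumes the parent
# SolverInterface machinery provides a ``forward`` method that evaluates the
# wrapped model on the given input, as concrete PINA solvers typically do;
# the class name below is hypothetical.
class _ExampleSupervisedSolver(SupervisedSolverInterface):
    """Toy solver comparing the model output with the supervised target."""

    def loss_data(self, input, target):
        # ``self.loss`` is the loss function stored by
        # SupervisedSolverInterface (``torch.nn.MSELoss`` by default);
        # ``self.forward(input)`` is assumed to return the model prediction.
        return self.loss(self.forward(input), target)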