Source code for pina._src.solver.supervised_single_model_solver

"""Module for the supervised single-model solver class."""

from pina._src.condition.input_target_condition import InputTargetCondition
from pina._src.solver.single_model_solver import SingleModelSolver


class SupervisedSingleModelSolver(SingleModelSolver):
    r"""
    Supervised solver built around one single model.

    This solver targets problems described by pairs of inputs and targets.
    It learns the map from input variables to target variables with a
    single model. The only accepted condition type is
    :class:`~pina._src.condition.input_target_condition.InputTargetCondition`.

    For a model :math:`\mathcal{M}`, training reduces the mismatch between
    each target :math:`\mathbf{u}_i` and the prediction
    :math:`\mathcal{M}(\mathbf{s}_i)`, computed on the matching input
    :math:`\mathbf{s}_i`. The supervised loss being optimized is

    .. math::
        \mathcal{L}_{\mathrm{problem}} = \frac{1}{N} \sum_{i=1}^{N}
        \mathcal{L} \left( \mathbf{u}_i - \mathcal{M}(\mathbf{s}_i) \right),

    where :math:`\mathcal{L}` denotes the chosen loss function, usually the
    mean squared error.
    """

    # Condition types this solver is able to handle.
    accepted_conditions_types = (InputTargetCondition,)

    def __init__(
        self,
        problem,
        model,
        optimizer=None,
        scheduler=None,
        weighting=None,
        loss=None,
        use_lt=True,
    ):
        """
        Build a :class:`SupervisedSingleModelSolver` instance.

        Every argument is forwarded unchanged to the
        :class:`SingleModelSolver` initializer.

        :param BaseProblem problem: The problem to solve.
        :param torch.nn.Module model: The single model trained by the solver.
        :param TorchOptimizer optimizer: Optimizer used during training.
            When ``None``, ``torch.optim.Adam`` with a learning rate of
            ``0.001`` is used. Defaults to ``None``.
        :param TorchScheduler scheduler: Learning-rate scheduler. When
            ``None``, ``torch.optim.lr_scheduler.ConstantLR`` with a factor
            of ``1.0`` is used. Defaults to ``None``.
        :param BaseWeighting weighting: Strategy used to combine the
            per-condition losses. When ``None``, no weighting is applied.
            Defaults to ``None``.
        :param loss: Loss function applied to the residuals. When ``None``,
            :class:`torch.nn.MSELoss` is used. Defaults to ``None``.
        :param bool use_lt: Whether the solver feeds LabelTensors as input.
            Defaults to ``True``.
        """
        # Gather the configuration once and hand it to the parent solver.
        solver_kwargs = {
            "problem": problem,
            "model": model,
            "optimizer": optimizer,
            "scheduler": scheduler,
            "weighting": weighting,
            "loss": loss,
            "use_lt": use_lt,
        }
        # NOTE(review): the parent initializer is called explicitly by class
        # rather than through super(). This is presumably deliberate, given
        # the solver class hierarchy; confirm before switching to super().
        SingleModelSolver.__init__(self, **solver_kwargs)