Source code for pina._src.solver.autoregressive_ensemble_solver
"""Module for the autoregressive ensemble solver class."""
from pina._src.solver.mixin.autoregressive_mixin import AutoregressiveMixin
from pina._src.condition.time_series_condition import TimeSeriesCondition
from pina._src.solver.ensemble_solver import EnsembleSolver
[docs]
class AutoregressiveEnsembleSolver(AutoregressiveMixin, EnsembleSolver):
    r"""
    Autoregressive solver built on an ensemble of models.

    The solver learns how a dynamical system evolves in time by training
    several models side by side. It works with time-series data only: the
    single condition type it accepts is
    :class:`~pina._src.condition.time_series_condition.TimeSeriesCondition`.

    Let :math:`\{\mathbf{u}_t\}_{t=0}^{T}` be a sequence of states and
    :math:`\{\mathcal{M}_j\}_{j=1}^{M}` the members of the ensemble. Every
    member maps the current state to an estimate of the next one:

    .. math::

        \hat{\mathbf{u}}_{t+1}^{(j)} = \mathcal{M}_j(\mathbf{u}_t),
        \qquad j = 1, \ldots, M.

    Training minimizes the mismatch between these estimates and the true
    next states :math:`\mathbf{u}_{t+1}`, averaged over the time steps and
    over the ensemble members:

    .. math::

        \mathcal{L}_{\mathrm{problem}} = \frac{1}{M} \sum_{j=1}^{M}
        \frac{1}{T} \sum_{t=0}^{T-1} \mathcal{L}
        \left( \mathbf{u}_{t+1} - \hat{\mathbf{u}}_{t+1}^{(j)} \right),

    where :math:`\mathcal{L}` is the selected loss function, typically the
    mean squared error.

    The formula above treats every autoregressive step equally, which is
    the behavior obtained with ``eps = 0`` (adaptive weighting disabled).
    A larger ``eps`` lets each step contribute to the total loss according
    to its accumulated difficulty. Steps with larger running losses get
    larger weights, so training concentrates on the parts of the rollout
    where prediction errors tend to build up. The larger ``eps``, the
    stronger the pull toward high-loss steps.
    """

    # Condition types this solver is able to handle.
    accepted_conditions_types = (TimeSeriesCondition,)

    def __init__(
        self,
        problem,
        models,
        optimizers=None,
        schedulers=None,
        weighting=None,
        loss=None,
        use_lt=False,
        eps=0.0,
        reset_weights_at_epoch_start=True,
        kwargs=None,
    ):
        """
        Create an :class:`AutoregressiveEnsembleSolver` instance.

        :param BaseProblem problem: The problem the solver has to solve.
        :param models: A single model, or a list of models, forming the
            ensemble.
        :type models: torch.nn.Module | list[torch.nn.Module]
        :param optimizers: A single optimizer, or a list of optimizers.
            When ``None``, each model gets a ``torch.optim.Adam`` optimizer
            with learning rate ``0.001``. Default is ``None``.
        :type optimizers: TorchOptimizer | list[TorchOptimizer]
        :param schedulers: A single scheduler, or a list of schedulers.
            When ``None``, each model gets a
            ``torch.optim.lr_scheduler.ConstantLR`` scheduler with factor
            ``1.0``. Default is ``None``.
        :type schedulers: TorchScheduler | list[TorchScheduler]
        :param BaseWeighting weighting: Strategy used to combine the losses
            of the different conditions. When ``None``, no weighting is
            applied. Default is ``None``.
        :param loss: Loss function used for the residual losses. When
            ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
        :param bool use_lt: Whether the solver takes LabelTensors as input.
            Default is ``False``.
        :param eps: Strength with which the cumulative loss shapes the
            adaptive step weights. Larger values weight high-loss steps more
            aggressively. Default is ``0.0``.
        :type eps: float | int
        :param bool reset_weights_at_epoch_start: If ``True``, the running
            average and step count are reset at the start of each epoch.
            The adaptive weights are then recalibrated on that epoch's
            training dynamics. Default is ``True``.
        :param dict kwargs: Extra keyword arguments for the preprocessing
            and postprocessing steps.
        """
        # Arguments forwarded unchanged to the ensemble base class.
        ensemble_args = dict(
            problem=problem,
            models=models,
            optimizers=optimizers,
            schedulers=schedulers,
            weighting=weighting,
            loss=loss,
            use_lt=use_lt,
        )
        # Only EnsembleSolver's initializer is invoked here, explicitly by
        # class name; the mixin's state is configured separately below.
        EnsembleSolver.__init__(self, **ensemble_args)

        # Configure the autoregressive machinery provided by the mixin.
        autoregressive_args = {
            "eps": eps,
            "reset_weights_at_epoch_start": reset_weights_at_epoch_start,
            "kwargs": kwargs,
        }
        self._init_autoregressive_components(**autoregressive_args)