Autoregressive Ensemble Solver

Module for the autoregressive ensemble solver class.

class AutoregressiveEnsembleSolver(problem, models, optimizers=None, schedulers=None, weighting=None, loss=None, use_lt=False, eps=0.0, reset_weights_at_epoch_start=True, kwargs=None)

Bases: AutoregressiveMixin, EnsembleSolver

Ensemble-model solver for autoregressive learning problems.

This solver learns the time evolution of dynamical systems using an ensemble of models. It is intended for problems defined by time-series data and accepts only conditions of type TimeSeriesCondition.

Given a sequence of states \(\{\mathbf{u}_t\}_{t=0}^{T}\), the solver trains an ensemble of models \(\{\mathcal{M}_j\}_{j=1}^{M}\) to predict the next state from the current one. The prediction of each model is

\[\hat{\mathbf{u}}_{t+1}^{(j)} = \mathcal{M}_j(\mathbf{u}_t), \qquad j = 1, \ldots, M.\]

The autoregressive training objective minimizes the discrepancy between the predicted states \(\hat{\mathbf{u}}_{t+1}^{(j)}\) and the target states \(\mathbf{u}_{t+1}\) over the sequence and across the ensemble:

\[\mathcal{L}_{\mathrm{problem}} = \frac{1}{M} \sum_{j=1}^{M} \frac{1}{T} \sum_{t=0}^{T-1} \mathcal{L} \left( \mathbf{u}_{t+1} - \hat{\mathbf{u}}_{t+1}^{(j)} \right),\]

where \(\mathcal{L}\) is the selected loss function, typically the mean squared error.
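The objective above can be sketched in plain PyTorch. This is only an illustration of the formula, not the solver's actual implementation: the function name ensemble_one_step_loss, the toy linear models, and the random state sequence are all assumptions made for the example. The loss function is called as loss_fn(prediction, target), which for the mean squared error is equivalent to applying the loss to the difference.

```python
import torch
from torch import nn

torch.manual_seed(0)


def ensemble_one_step_loss(models, states, loss_fn=None):
    """Average one-step prediction loss over time and over the ensemble.

    `states` has shape (T + 1, D); its rows are u_0, ..., u_T.
    Hypothetical helper written for illustration only.
    """
    loss_fn = loss_fn if loss_fn is not None else nn.MSELoss()
    inputs, targets = states[:-1], states[1:]  # u_t and u_{t+1}, t = 0..T-1
    # One scalar loss per model; MSELoss already averages over the T steps.
    per_model = [loss_fn(model(inputs), targets) for model in models]
    # Average across the M ensemble members.
    return torch.stack(per_model).mean()


T, D, M = 10, 2, 3                       # sequence length, state size, ensemble size
states = torch.randn(T + 1, D)           # toy sequence standing in for real data
models = [nn.Linear(D, D) for _ in range(M)]
loss = ensemble_one_step_loss(models, states)
```

Each model receives the true state u_t as input, matching the one-step prediction written above.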

The solver supports adaptive weighting of autoregressive steps through the eps parameter. During training, each autoregressive step can contribute differently to the total loss depending on its accumulated difficulty. Steps with larger running losses are assigned larger weights, so that the solver focuses more on parts of the rollout where prediction errors tend to accumulate. The parameter eps controls the strength of this effect: eps = 0 disables adaptive weighting, while larger values increase the influence of high-loss steps on the final training objective.
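The exact weighting rule is not stated here, so the following is a minimal sketch under an assumed form: each step keeps a running average of its loss, and its weight is proportional to exp(eps * running_loss), normalized so that the mean weight is 1. The class StepWeighter and its methods are hypothetical names chosen for the example. The sketch reproduces the two properties described above: eps = 0 yields uniform weights, and larger running losses yield larger weights.

```python
import math


class StepWeighter:
    """Hypothetical helper: running average loss and weight per rollout step."""

    def __init__(self, n_steps, eps=0.0):
        self.eps = float(eps)
        self.running = [0.0] * n_steps  # running average loss of each step
        self.count = 0                  # number of updates seen so far

    def reset(self):
        """Forget the accumulated statistics, e.g. at the start of an epoch."""
        self.running = [0.0] * len(self.running)
        self.count = 0

    def update(self, step_losses):
        """Fold a new batch of per-step losses into the running averages."""
        self.count += 1
        self.running = [
            r + (loss - r) / self.count
            for r, loss in zip(self.running, step_losses)
        ]

    def weights(self):
        """Assumed rule: w_t proportional to exp(eps * running_t), mean 1."""
        raw = [math.exp(self.eps * r) for r in self.running]
        mean = sum(raw) / len(raw)
        return [w / mean for w in raw]


weighter = StepWeighter(n_steps=3, eps=1.0)
weighter.update([0.1, 0.5, 2.0])  # toy per-step losses
adaptive = weighter.weights()     # the last step gets the largest weight
```

The weighted objective would then be the sum of each step loss multiplied by its weight, divided by the number of steps.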

Initialization of the AutoregressiveEnsembleSolver class.

Parameters:
  • problem (BaseProblem) – The problem to be solved.

  • models (torch.nn.Module | list[torch.nn.Module]) – The model or list of models used by the solver.

  • optimizers (TorchOptimizer | list[TorchOptimizer]) – The optimizer or list of optimizers used by the solver. If None, the torch.optim.Adam optimizer with a learning rate of 0.001 is used for each model. Default is None.

  • schedulers (TorchScheduler | list[TorchScheduler]) – The scheduler or list of schedulers used by the solver. If None, the torch.optim.lr_scheduler.ConstantLR scheduler with a factor of 1.0 is used for each model. Default is None.

  • weighting (BaseWeighting) – The weighting strategy used to combine condition losses. If None, no weighting is applied. Default is None.

  • loss – The loss function used to compute residual losses. If None, torch.nn.MSELoss is used. Default is None.

  • use_lt (bool) – If True, the solver uses LabelTensors as input. Default is False.

  • eps (float | int) – The hyperparameter controlling the influence of the cumulative loss on the adaptive weights. Higher values of eps will lead to more aggressive weighting of steps with higher cumulative loss. Default is 0.0.

  • reset_weights_at_epoch_start (bool) – Whether to reset the running average and step count at the start of each epoch. If True, the adaptive weights will be recalibrated at the beginning of each epoch based on the new training dynamics. Default is True.

  • kwargs (dict) – Additional keyword arguments for preprocessing and postprocessing steps. Default is None.
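For reference, the default optimizer and scheduler described above correspond to the following plain PyTorch objects, created once per model. This sketch does not show how the solver wraps them in TorchOptimizer and TorchScheduler; the toy linear models and the short training loop are assumptions made for the example.

```python
import torch
from torch import nn

# Toy ensemble standing in for the user's models.
models = [nn.Linear(2, 2) for _ in range(3)]

# One Adam optimizer (lr = 0.001) and one ConstantLR scheduler (factor = 1.0)
# per model, mirroring the documented defaults.
optimizers = [torch.optim.Adam(m.parameters(), lr=0.001) for m in models]
schedulers = [
    torch.optim.lr_scheduler.ConstantLR(opt, factor=1.0) for opt in optimizers
]

# A few dummy steps: with factor = 1.0 the learning rate never changes.
for _ in range(3):
    for model, opt, sched in zip(models, optimizers, schedulers):
        opt.zero_grad()
        model(torch.ones(4, 2)).pow(2).mean().backward()
        opt.step()
        sched.step()

current_lrs = [sched.get_last_lr()[0] for sched in schedulers]
```

With a factor of 1.0 the scheduler is effectively a no-op, so the default setup trains every model at a fixed learning rate.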