Autoregressive Mixin

Module for the autoregressive mixin class.

class AutoregressiveMixin

Bases: object

Mixin that implements the autoregressive rollout loss. It maintains a running average of the loss at each rollout step and computes an adaptive weight for every step based on the cumulative loss. Steps that are currently underperforming therefore receive more weight, which can help improve training stability and convergence.

Designed to be used in combination with any solver inheriting from BaseSolver.
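The per-step weighting described above can be sketched as follows. This is a minimal sketch, not the library's implementation: it assumes, hypothetically, that each step's weight is its running-average loss normalized so that all weights sum to one. The names `adaptive_weights` and `weighted_rollout_loss` are illustrative only.

```python
import torch


def adaptive_weights(running_avg: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical scheme: weight each step by its share of the total running loss.

    Steps with a larger running-average loss receive a larger weight.
    """
    return running_avg / (running_avg.sum() + eps)


def weighted_rollout_loss(step_losses: torch.Tensor, running_avg: torch.Tensor) -> torch.Tensor:
    """Combine per-step losses of shape [n_steps] into a single scalar loss."""
    # Weights are treated as constants, so gradients flow only through step_losses.
    weights = adaptive_weights(running_avg).detach()
    return (weights * step_losses).sum()


running_avg = torch.tensor([1.0, 2.0, 1.0])  # step 1 is currently underperforming
step_losses = torch.tensor([0.5, 1.5, 0.5], requires_grad=True)
print(adaptive_weights(running_avg))  # tensor([0.2500, 0.5000, 0.2500])
loss = weighted_rollout_loss(step_losses, running_avg)
```

Detaching the weights keeps the weighting scheme itself out of the gradient, so the optimizer cannot lower the loss simply by shifting weight between steps.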

on_train_epoch_start()

Clear the running average and step count at the start of each epoch if reset_weights_at_epoch_start is True.
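A minimal sketch of the running-average state and its reset, assuming the state is a tensor of per-step means plus an integer count. Only the `on_train_epoch_start` name and the `reset_weights_at_epoch_start` flag come from this page. The `RunningStepLoss` class and its `update` method are hypothetical.

```python
import torch


class RunningStepLoss:
    """Hypothetical tracker of the running-average loss of each rollout step."""

    def __init__(self, n_steps: int, reset_weights_at_epoch_start: bool = True):
        self.reset_weights_at_epoch_start = reset_weights_at_epoch_start
        self.running_avg = torch.zeros(n_steps)
        self.count = 0

    def update(self, step_losses: torch.Tensor) -> None:
        # Incremental mean: avg_k = avg_{k-1} + (x_k - avg_{k-1}) / k
        self.count += 1
        self.running_avg += (step_losses.detach() - self.running_avg) / self.count

    def on_train_epoch_start(self) -> None:
        # Discard the previous epoch's statistics only if requested.
        if self.reset_weights_at_epoch_start:
            self.running_avg.zero_()
            self.count = 0


tracker = RunningStepLoss(n_steps=2)
tracker.update(torch.tensor([1.0, 3.0]))
tracker.update(torch.tensor([3.0, 1.0]))
print(tracker.running_avg)  # tensor([2., 2.])
tracker.on_train_epoch_start()
print(tracker.count)  # 0
```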

preprocess_step(current_state, **kwargs)

Preprocess the current state before each step.

Parameters:
  • current_state (torch.Tensor | LabelTensor) – The current state tensor.

  • kwargs (dict) – Additional keyword arguments for preprocessing.

Returns:

The preprocessed state tensor.

Return type:

torch.Tensor | LabelTensor
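Because `preprocess_step` runs on the current state before every step, a subclass can override it to transform the model input. The sketch below assumes the default hook is the identity. It shows a hypothetical override that standardizes the state using `mean` and `std` values passed through `kwargs`. This page confirms neither the default behaviour nor these keyword names, and both classes are illustrative stand-ins.

```python
import torch


class IdentityHooks:
    """Hypothetical stand-in for the default hook (assumed to be the identity)."""

    def preprocess_step(self, current_state, **kwargs):
        return current_state


class StandardizedHooks(IdentityHooks):
    """Hypothetical override that standardizes the state before it reaches the model."""

    def preprocess_step(self, current_state, **kwargs):
        mean = kwargs.get("mean", 0.0)
        std = kwargs.get("std", 1.0)
        return (current_state - mean) / std


state = torch.full((4, 1, 3), 5.0)  # [trajectories, 1, features]
out = StandardizedHooks().preprocess_step(state, mean=1.0, std=2.0)
print(out[0, 0])  # tensor([2., 2., 2.])
```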

postprocess_step(predicted_state, **kwargs)

Postprocess the predicted state after each step. If multiple models are used, average the predictions across the model dimension.

Parameters:
  • predicted_state (torch.Tensor | LabelTensor) – The predicted state tensor.

  • kwargs (dict) – Additional keyword arguments for postprocessing.

Returns:

The postprocessed state tensor.

Return type:

torch.Tensor | LabelTensor
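The averaging performed by `postprocess_step` can be sketched as below. The sketch makes two hypothetical assumptions: an ensemble stacks its predictions along a leading model dimension (dim 0), and the caller signals an ensemble through an `n_models` argument. The library may use a different dimension or detect multiple models in another way.

```python
import torch


def postprocess_step(predicted_state: torch.Tensor, n_models: int = 1, model_dim: int = 0) -> torch.Tensor:
    """Hypothetical sketch: average ensemble predictions over the model dimension."""
    if n_models > 1:
        # With model_dim=0: [n_models, trajectories, 1, features] -> [trajectories, 1, features]
        return predicted_state.mean(dim=model_dim)
    # A single model: return the prediction unchanged.
    return predicted_state


# Two models, 4 trajectories, 1 step, 3 features.
preds = torch.stack([torch.zeros(4, 1, 3), torch.ones(4, 1, 3)])
avg = postprocess_step(preds, n_models=2)
print(avg.shape)  # torch.Size([4, 1, 3])
```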

predict(initial_state, n_steps, **kwargs)

Generate predictions by repeatedly calling the model’s forward method, feeding each predicted state back in as the input of the next step.

Parameters:
  • initial_state (torch.Tensor | LabelTensor) – The initial state from which to start prediction. The initial state must be of shape [trajectories, 1, *features].

  • n_steps (int) – The number of autoregressive steps to predict.

  • kwargs (dict) – Additional keyword arguments.

Raises:

ValueError – If the provided initial_state tensor has fewer than 3 dimensions.

Returns:

The predicted trajectory, including the initial state. It has shape [trajectories, n_steps + 1, *features], where the first step corresponds to the initial state.

Return type:

torch.Tensor | LabelTensor
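The rollout performed by `predict` can be sketched as a loop that feeds each prediction back as the next input and then concatenates all states along the time dimension (dim 1). This minimal sketch rests on several assumptions:

- The input is a plain `torch.Tensor`.
- `model` is a callable that maps a `[trajectories, 1, *features]` state to a state of the same shape.
- The pre- and postprocessing hooks are the identity, so they are omitted.

The standalone `predict` function and its `model` argument are hypothetical.

```python
import torch


def predict(model, initial_state: torch.Tensor, n_steps: int) -> torch.Tensor:
    """Hypothetical rollout returning a tensor of shape [trajectories, n_steps + 1, *features]."""
    if initial_state.dim() < 3:
        raise ValueError(
            "initial_state must have shape [trajectories, 1, *features], "
            f"got {tuple(initial_state.shape)}"
        )
    states = [initial_state]
    current_state = initial_state
    for _ in range(n_steps):
        # Each prediction becomes the input of the next step.
        current_state = model(current_state)
        states.append(current_state)
    # Concatenate along the time dimension; index 0 holds the initial state.
    return torch.cat(states, dim=1)


def toy_model(state: torch.Tensor) -> torch.Tensor:
    # Toy "model" that adds 1 to the state at every step.
    return state + 1.0


trajectory = predict(toy_model, torch.zeros(4, 1, 3), n_steps=5)
print(trajectory.shape)  # torch.Size([4, 6, 3])
```

Collecting the states in a list and concatenating once at the end avoids reallocating the growing trajectory at every step.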