Autoregressive Mixin
Module for the autoregressive mixin class.
- class AutoregressiveMixin
Bases: object
Mixin that enables the autoregressive rollout loss logic by maintaining a running average of step losses and computing adaptive weights for each step based on the cumulative loss. This lets the solver focus more on steps that are currently underperforming, which can help improve training stability and convergence.
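The docstring does not state the exact weighting formula. As a minimal sketch, assuming each step's weight is simply its share of the summed running-average losses (a hypothetical choice, not necessarily the one this mixin uses), steps with a larger running loss receive proportionally larger weights. The function name `compute_step_weights` is illustrative only.

```python
from typing import List


def compute_step_weights(running_avg: List[float]) -> List[float]:
    """Hypothetical weighting: each step's weight is its share of the total running loss.

    Steps that currently have a higher average loss get a larger weight, so the
    weighted rollout loss emphasises the worst-performing steps.
    """
    total = sum(running_avg)
    if total == 0.0:
        # No loss information yet (e.g. right after a reset): use uniform weights.
        return [1.0 / len(running_avg)] * len(running_avg)
    return [avg / total for avg in running_avg]
```

For example, running averages of `[2.0, 2.0, 3.0]` give weights of 2/7, 2/7 and 3/7, so the last step contributes most to the weighted loss.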
Designed to be used in combination with any solver inheriting from BaseSolver.
- on_train_epoch_start()
Clear the running average and step count at the start of each epoch if reset_weights_at_epoch_start is True.
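The bookkeeping behind this hook can be sketched as a small tracker that keeps an incremental per-step mean and clears it only when the flag is set. This is a hypothetical stand-in: the class `StepLossTracker` and its `update` method are not part of the library; only the `reset_weights_at_epoch_start` flag and the hook name come from the description above.

```python
from typing import List


class StepLossTracker:
    """Hypothetical stand-in for the mixin's running-average state (not the library's class)."""

    def __init__(self, n_steps: int, reset_weights_at_epoch_start: bool = True):
        self.n_steps = n_steps
        self.reset_weights_at_epoch_start = reset_weights_at_epoch_start
        self.running_avg: List[float] = [0.0] * n_steps
        self.step_count = 0

    def update(self, step_losses: List[float]) -> None:
        # Incremental mean: avg_new = avg + (x - avg) / count.
        self.step_count += 1
        self.running_avg = [
            avg + (loss - avg) / self.step_count
            for avg, loss in zip(self.running_avg, step_losses)
        ]

    def on_train_epoch_start(self) -> None:
        # Mirror the documented behaviour: clear state only when the flag is True.
        if self.reset_weights_at_epoch_start:
            self.running_avg = [0.0] * self.n_steps
            self.step_count = 0
```

With the flag set to False, the statistics accumulate across epochs instead, so the weights reflect the whole training history rather than the current epoch.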
- preprocess_step(current_state, **kwargs)
Preprocess the current state before each step.
- Parameters:
current_state (torch.Tensor | LabelTensor) – The current state tensor.
kwargs (dict) – Additional keyword arguments for preprocessing.
- Returns:
The preprocessed state tensor.
- Return type:
torch.Tensor | LabelTensor
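A hook like this is typically customised by overriding it in a subclass. The sketch below assumes the default hook returns the state unchanged and uses plain Python lists in place of tensors; both classes are hypothetical and only illustrate where a per-step transformation, such as rescaling, would go.

```python
from typing import List


class IdentityHooks:
    """Hypothetical base class: the hook returns the state unchanged unless overridden."""

    def preprocess_step(self, current_state: List[float], **kwargs) -> List[float]:
        return current_state


class ScaledInputHooks(IdentityHooks):
    """Hypothetical subclass that rescales the state before it reaches the model."""

    def __init__(self, scale: float):
        self.scale = scale

    def preprocess_step(self, current_state: List[float], **kwargs) -> List[float]:
        # Divide every feature by a fixed scale; extra kwargs are accepted but unused.
        return [x / self.scale for x in current_state]
```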
- postprocess_step(predicted_state, **kwargs)
Postprocess the predicted state after each step. If multiple models are used, average the predictions across the model dimension.
- Parameters:
predicted_state (torch.Tensor | LabelTensor) – The predicted state tensor.
kwargs (dict) – Additional keyword arguments for postprocessing.
- Returns:
The postprocessed state tensor.
- Return type:
torch.Tensor | LabelTensor
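The multi-model averaging can be illustrated without tensors. This sketch assumes the model dimension is the leading one, so that `predicted_state[m]` holds model `m`'s flattened prediction. That layout and the helper name `average_over_models` are assumptions, not the library's API. With tensors, the equivalent would be a mean over that dimension.

```python
from typing import List


def average_over_models(predicted_state: List[List[float]]) -> List[float]:
    """Average predictions elementwise across the leading (model) dimension.

    Hypothetical helper: assumes predicted_state[m] is model m's flattened prediction.
    """
    n_models = len(predicted_state)
    if n_models == 0:
        raise ValueError("predicted_state must contain at least one model prediction")
    # zip(*...) groups the m-th model's value for each feature position together.
    return [sum(values) / n_models for values in zip(*predicted_state)]
```

With a single model the prediction passes through unchanged, which matches the docstring's condition that averaging only matters when multiple models are used.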
- predict(initial_state, n_steps, **kwargs)
Generate predictions by recursively calling the model's forward method, feeding each predicted state back in as the input for the next step.
- Parameters:
initial_state (torch.Tensor | LabelTensor) – The initial state from which to start prediction. The initial state must be of shape [trajectories, 1, *features].
n_steps (int) – The number of autoregressive steps to predict.
kwargs (dict) – Additional keyword arguments.
- Raises:
ValueError – If the provided initial_state tensor has fewer than 3 dimensions.
- Returns:
The predicted trajectory, including the initial state. It has shape [trajectories, n_steps + 1, *features], where the first step corresponds to the initial state.
- Return type:
torch.Tensor | LabelTensor