DeepEnsembleSolverInterface#

class DeepEnsembleSolverInterface(problem, models, optimizers=None, schedulers=None, weighting=None, use_lt=True, ensemble_dim=0)[source]#

Bases: MultiSolverInterface

A class for handling ensemble models in a multi-solver training framework. It uses manual optimization and supports training, validating, and testing multiple models as part of an ensemble. The ensemble dimension can be customized to control how outputs are stacked.

By default, it is compatible with problems defined by AbstractProblem, and users can choose the problem type the solver is meant to address.

An ensemble model is constructed by combining multiple models that solve the same type of problem. Mathematically, this creates an implicit distribution \(p(\mathbf{u} \mid \mathbf{s})\) over the possible outputs \(\mathbf{u}\), given the original input \(\mathbf{s}\). The models \(\mathcal{M}_i,\ i \in \{1,\dots,r\}\), in the ensemble work collaboratively to capture different aspects of the data or task, each contributing a distinct prediction \(\mathbf{y}_{i}=\mathcal{M}_i(\mathbf{s})\). By aggregating these predictions, the ensemble can achieve greater robustness and accuracy than its individual members, leveraging their diversity to reduce overfitting and improve generalization. Furthermore, statistical metrics can be computed, e.g. the ensemble mean and variance:

\[\mathbf{\mu} = \frac{1}{r}\sum_{i=1}^r \mathbf{y}_{i}\]
\[\mathbf{\sigma}^2 = \frac{1}{r}\sum_{i=1}^r (\mathbf{y}_{i} - \mathbf{\mu})^2\]
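
As a concrete reading of the formulas above, the ensemble statistics reduce to plain reductions over the stacked outputs; a minimal torch sketch (the tensor shapes and random data here are assumptions for illustration):

    import torch

    # r ensemble members, each predicting a (batch, out_dim) tensor;
    # stacking along ensemble_dim=0 yields shape (r, batch, out_dim).
    r, batch, out_dim = 5, 16, 3
    predictions = torch.stack(
        [torch.randn(batch, out_dim) for _ in range(r)], dim=0
    )

    mu = predictions.mean(dim=0)                    # ensemble mean, (batch, out_dim)
    sigma2 = ((predictions - mu) ** 2).mean(dim=0)  # ensemble variance, (batch, out_dim)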

See also

Original reference: Lakshminarayanan, B., Pritzel, A., & Blundell, C. (2017). Simple and scalable predictive uncertainty estimation using deep ensembles. Advances in Neural Information Processing Systems, 30. arXiv:1612.01474.

Initialization of the DeepEnsembleSolverInterface class.

Parameters:
  • problem (AbstractProblem) – The problem to be solved.

  • models (list[torch.nn.Module]) – The neural network models forming the ensemble.

  • optimizers (list[Optimizer]) – The optimizers to be used. If None, a torch.optim.Adam optimizer is used for each model. Default is None.

  • schedulers (list[Scheduler]) – The learning rate schedulers. If None, a torch.optim.lr_scheduler.ConstantLR scheduler is used for each optimizer. Default is None.

  • weighting (WeightingInterface) – The weighting schema to be used. If None, no weighting schema is used. Default is None.

  • use_lt (bool) – If True, the solver uses LabelTensors as input. Default is True.

  • ensemble_dim (int) – The dimension along which the ensemble outputs are stacked. Default is 0.
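
For illustration, an ensemble is typically built from independently initialized copies of a single architecture; a minimal sketch (the concrete solver subclass and the problem object are assumptions, since this class is an interface and is not instantiated directly):

    import torch

    # Five independently initialized copies of one architecture.
    models = [
        torch.nn.Sequential(
            torch.nn.Linear(2, 64),
            torch.nn.Tanh(),
            torch.nn.Linear(64, 1),
        )
        for _ in range(5)
    ]

    # Passed to a concrete subclass of DeepEnsembleSolverInterface, e.g.:
    # solver = MyEnsembleSolver(problem=problem, models=models)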

forward(x, ensemble_idx=None)[source]#

Forward pass through the ensemble models. If an ensemble_idx is provided, it returns the output of the specific model corresponding to that index. If no index is given, it stacks the outputs of all models along the ensemble dimension.

Parameters:
  • x (LabelTensor) – The input tensor to the models.

  • ensemble_idx (int) – Optional index to select a specific model from the ensemble. If None, the outputs of all models are stacked along the ensemble_dim dimension. Default is None.

Returns:

The output of the selected model or the stacked outputs from all models.

Return type:

LabelTensor
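
A usage sketch of the two calling modes (the solver instance and the output shapes are assumptions, following from num_ensemble == 5 and ensemble_dim == 0):

    import torch
    from pina import LabelTensor

    # A labelled batch of 16 two-dimensional input points.
    x = LabelTensor(torch.rand(16, 2), labels=["x", "y"])

    # Assuming `solver` is an instance of a concrete subclass:
    # y_all = solver.forward(x)                  # stacked, shape (5, 16, out_dim)
    # y_one = solver.forward(x, ensemble_idx=2)  # third model only, (16, out_dim)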

training_step(batch)[source]#

Training step for the solver, overridden for manual optimization. This method performs a forward pass, calculates the loss, and applies manual backward propagation and optimization steps for each model in the ensemble.

Parameters:

batch (list[tuple[str, dict]]) – A batch of training data. Each element is a tuple containing a condition name and a dictionary of points.

Returns:

The aggregated loss after the training step.

Return type:

torch.Tensor
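
Conceptually, this follows PyTorch Lightning's manual-optimization loop; an illustrative sketch, not the exact implementation (the per-model loss helper `_loss_for_model` is hypothetical, and the real method also handles conditions and the weighting schema):

    import torch

    def training_step(self, batch):
        losses = []
        for idx, opt in enumerate(self.optimizers()):  # one optimizer per model
            opt.zero_grad()
            loss = self._loss_for_model(batch, ensemble_idx=idx)  # hypothetical
            self.manual_backward(loss)  # Lightning's manual backward pass
            opt.step()
            losses.append(loss.detach())
        return torch.stack(losses).mean()  # aggregated loss across the ensemble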

property ensemble_dim#

The dimension along which the ensemble outputs are stacked.

Returns:

The ensemble dimension.

Return type:

int

property num_ensemble#

The number of models in the ensemble.

Returns:

The number of models in the ensemble.

Return type:

int