Ensemble Solver

Module for the ensemble solver class.

class EnsembleSolver(problem, models, optimizers=None, schedulers=None, weighting=None, loss=None, use_lt=True)

Bases: ManualOptimizationMixin, EnsembleMixin, ConditionAggregatorMixin, BaseSolver

Base class for implementing ensemble-model solvers.

This class provides the standard starting point for solvers based on an ensemble of models. It combines the shared solver machinery from BaseSolver with ensemble-model handling, manual optimization, and condition-wise loss aggregation.

Subclasses can inherit from this class to implement solver-specific behavior while reusing the common logic for model registration, optimizer and scheduler setup, manual optimization, loss evaluation, weighting, and aggregation across problem conditions.
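To make the combination of ensemble handling, manual optimization, and condition-wise aggregation more concrete, here is a minimal plain-PyTorch sketch of one training step. It assumes that each ensemble member owns its own optimizer and scheduler, that condition losses are combined by a weighted sum, and that member predictions are averaged at inference time. The condition names, the `weights` dictionary, the `ensemble_step` helper, and the averaging are all hypothetical and are not taken from the actual EnsembleSolver internals.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical ensemble: three small networks with the same architecture
# but independent random initializations.
models = [nn.Sequential(nn.Linear(1, 16), nn.Tanh(), nn.Linear(16, 1)) for _ in range(3)]
# One optimizer and one scheduler per member (mirroring the documented defaults).
optimizers = [torch.optim.Adam(m.parameters(), lr=0.001) for m in models]
schedulers = [torch.optim.lr_scheduler.ConstantLR(o, factor=1.0) for o in optimizers]
loss_fn = nn.MSELoss()

# Hypothetical problem conditions: name -> (input, target).
x = torch.linspace(0, 1, 32).unsqueeze(-1)
conditions = {
    "data": (x, torch.sin(x)),
    "boundary": (torch.tensor([[0.0]]), torch.tensor([[0.0]])),
}
# Hypothetical per-condition weights; 1.0 everywhere would reduce to a plain sum.
weights = {"data": 1.0, "boundary": 10.0}


def ensemble_step(models, optimizers, schedulers):
    """Run one manual optimization step per ensemble member and return each member's aggregated loss."""
    total_losses = []
    for model, opt, sched in zip(models, optimizers, schedulers):
        opt.zero_grad()
        # Condition-wise losses computed for this member only.
        cond_losses = {
            name: loss_fn(model(inp), tgt) for name, (inp, tgt) in conditions.items()
        }
        # Aggregate the condition losses with a weighted sum.
        total = sum(weights[name] * value for name, value in cond_losses.items())
        # Manual optimization: explicit backward, optimizer step, scheduler step.
        total.backward()
        opt.step()
        sched.step()
        total_losses.append(total.item())
    return total_losses


losses = ensemble_step(models, optimizers, schedulers)
print(losses)

# Hypothetical inference: average the members' predictions.
with torch.no_grad():
    mean_pred = torch.stack([m(x) for m in models]).mean(dim=0)
```

Because every member is stepped by its own optimizer, the members never share gradients; in this sketch, the ensemble only comes together when the predictions are averaged.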

Initialize an instance of the EnsembleSolver class.

Parameters:
  • problem (BaseProblem) – The problem to be solved.

  • models (torch.nn.Module | list[torch.nn.Module]) – The model or list of models used by the solver.

  • optimizers (TorchOptimizer | list[TorchOptimizer]) – The optimizer or list of optimizers used by the solver. If None, the torch.optim.Adam optimizer with a learning rate of 0.001 is used for each model. Default is None.

  • schedulers (TorchScheduler | list[TorchScheduler]) – The scheduler or list of schedulers used by the solver. If None, the torch.optim.lr_scheduler.ConstantLR scheduler with a factor of 1.0 is used for each model. Default is None.

  • weighting (BaseWeighting) – The weighting strategy used to combine condition losses. If None, no weighting is applied. Default is None.

  • loss – The loss function used to compute residual losses. If None, torch.nn.MSELoss is used. Default is None.

  • use_lt (bool) – If True, the solver uses LabelTensors as input. Default is True.
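The default behaviour described in the parameter list can be sketched as a small argument-normalization helper. This is an illustrative sketch only: `resolve_defaults` is a hypothetical name, it uses plain `torch.optim` objects in place of PINA's TorchOptimizer and TorchScheduler wrappers, and the requirement of exactly one optimizer and one scheduler per model is an assumption rather than documented behaviour.

```python
import torch
from torch import nn


def resolve_defaults(models, optimizers=None, schedulers=None, loss=None):
    """Hypothetical helper: normalize solver arguments as the parameter list describes."""
    # Accept either a single module or a list of modules.
    if isinstance(models, nn.Module):
        models = [models]
    models = list(models)
    # Documented default: Adam with a learning rate of 0.001 for each model.
    if optimizers is None:
        optimizers = [torch.optim.Adam(m.parameters(), lr=0.001) for m in models]
    elif not isinstance(optimizers, (list, tuple)):
        optimizers = [optimizers]
    # Documented default: ConstantLR with a factor of 1.0 for each model.
    if schedulers is None:
        schedulers = [torch.optim.lr_scheduler.ConstantLR(o, factor=1.0) for o in optimizers]
    elif not isinstance(schedulers, (list, tuple)):
        schedulers = [schedulers]
    # Documented default residual loss.
    if loss is None:
        loss = nn.MSELoss()
    # Assumption: one optimizer and one scheduler per model.
    if not (len(models) == len(optimizers) == len(schedulers)):
        raise ValueError("Expected one optimizer and one scheduler per model.")
    return models, list(optimizers), list(schedulers), loss


models, opts, scheds, loss = resolve_defaults([nn.Linear(2, 1) for _ in range(4)])
print(len(opts), type(scheds[0]).__name__, type(loss).__name__)  # prints "4 ConstantLR MSELoss"
```

Since a ConstantLR factor of 1.0 multiplies the learning rate by one, the default scheduler leaves each optimizer's learning rate at 0.001 throughout training.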