Ensemble Solver
Module for the ensemble solver class.
- class EnsembleSolver(problem, models, optimizers=None, schedulers=None, weighting=None, loss=None, use_lt=True)
Bases: ManualOptimizationMixin, EnsembleMixin, ConditionAggregatorMixin, BaseSolver
Base class for implementing ensemble-model solvers.
This class provides the standard starting point for solvers based on an ensemble of models. It combines the shared solver machinery from BaseSolver with ensemble-model handling, manual optimization, and condition-wise loss aggregation.
Subclasses can inherit from this class to implement solver-specific behavior while reusing the common logic for model registration, optimizer and scheduler setup, manual optimization, loss evaluation, weighting, and aggregation across problem conditions.
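The condition-wise aggregation described above can be illustrated with a small, framework-free sketch. It is not the library's implementation: the helper names `mse` and `aggregate_losses`, the toy models, and the condition names are all hypothetical, and plain Python floats stand in for tensors. Each ensemble member gets one loss per condition, and those losses are combined into a single value per member, optionally scaled by per-condition weights.

```python
from statistics import fmean


def mse(pred, target):
    """Mean squared error between two equal-length sequences of floats."""
    return fmean((p - t) ** 2 for p, t in zip(pred, target))


def aggregate_losses(condition_losses, weights=None):
    """Sum per-condition losses; scale each by its weight when weights are given."""
    if weights is None:
        return sum(condition_losses.values())
    return sum(weights[name] * value for name, value in condition_losses.items())


# Hypothetical ensemble: two toy "models" mapping an input to a prediction.
models = [lambda x: 2.0 * x, lambda x: 2.0 * x + 1.0]

# Hypothetical problem conditions: name -> (inputs, targets).
conditions = {
    "data": ([0.0, 1.0, 2.0], [0.0, 2.0, 4.0]),
    "boundary": ([0.0], [0.0]),
}

# Hypothetical per-condition weights (assumption: a simple name -> scalar map).
weights = {"data": 1.0, "boundary": 0.5}

unweighted, weighted = [], []
for model in models:
    # One loss value per condition for this ensemble member.
    per_condition = {
        name: mse([model(x) for x in xs], ys)
        for name, (xs, ys) in conditions.items()
    }
    unweighted.append(aggregate_losses(per_condition))
    weighted.append(aggregate_losses(per_condition, weights))

print(unweighted)  # one aggregated loss per ensemble member
print(weighted)
```

In a real solver the per-member losses would be tensors that each model's optimizer steps on; the sketch only shows the bookkeeping of evaluating every condition for every member and then reducing.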
Initialization of the EnsembleSolver class.
- Parameters:
problem (BaseProblem) – The problem to be solved.
models (torch.nn.Module | list[torch.nn.Module]) – The model or list of models used by the solver.
optimizers (TorchOptimizer | list[TorchOptimizer]) – The optimizer or list of optimizers used by the solver. If None, the torch.optim.Adam optimizer with a learning rate of 0.001 is used for each model. Default is None.
schedulers (TorchScheduler | list[TorchScheduler]) – The scheduler or list of schedulers used by the solver. If None, the torch.optim.lr_scheduler.ConstantLR scheduler with a factor of 1.0 is used for each model. Default is None.
weighting (BaseWeighting) – The weighting strategy used to combine condition losses. If None, no weighting is applied. Default is None.
loss – The loss function used to compute residual losses. If None, torch.nn.MSELoss is used. Default is None.
use_lt (bool) – If True, the solver uses LabelTensors as input. Default is True.
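The "one default per model" behavior described for optimizers and schedulers can be sketched without any framework. This is an illustration only, not the solver's code: `OptimizerSpec`, `SchedulerSpec`, `as_list`, and `resolve_per_model` are hypothetical names, the spec classes merely record the documented default values, and raising an error on a length mismatch is an assumption of this sketch.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class OptimizerSpec:
    """Hypothetical stand-in recording the documented optimizer default."""
    name: str = "Adam"
    lr: float = 0.001


@dataclass(frozen=True)
class SchedulerSpec:
    """Hypothetical stand-in recording the documented scheduler default."""
    name: str = "ConstantLR"
    factor: float = 1.0


def as_list(value):
    """Wrap a single object in a list; copy lists and tuples into a list."""
    return list(value) if isinstance(value, (list, tuple)) else [value]


def resolve_per_model(models, given, default_factory):
    """Return one entry per model, creating defaults when `given` is None."""
    models = as_list(models)
    if given is None:
        return [default_factory() for _ in models]
    given = as_list(given)
    # Assumption of this sketch: a length mismatch is treated as an error.
    if len(given) != len(models):
        raise ValueError("expected one entry per model")
    return given


# Placeholders for torch.nn.Module objects.
models = ["model_a", "model_b", "model_c"]

optimizers = resolve_per_model(models, None, OptimizerSpec)
schedulers = resolve_per_model(models, None, SchedulerSpec)
print(optimizers)
print(schedulers)
```

Accepting either a single object or a list mirrors the `X | list[X]` parameter types above: a single model with a single optimizer resolves to one-element lists.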