Supervised Ensemble Solver

Module for the supervised ensemble-model solver class.

class SupervisedEnsembleSolver(problem, models, optimizers=None, schedulers=None, weighting=None, loss=None, use_lt=True)

Bases: EnsembleSolver

Ensemble-model solver for supervised learning problems.

This solver approximates the mapping between input data and target data using an ensemble of models. It is intended for problems whose conditions are defined by input-target pairs and accepts only InputTargetCondition.

Given input samples \(\mathbf{s}_i\), target values \(\mathbf{u}_i\), and an ensemble of models \(\{\mathcal{M}_j\}_{j=1}^{M}\), the prediction of each model is

\[\hat{\mathbf{u}}_{i}^{(j)} = \mathcal{M}_j(\mathbf{s}_i), \qquad j = 1, \ldots, M.\]

The supervised training objective minimizes the discrepancy between the target values and the ensemble predictions:

\[\mathcal{L}_{\mathrm{problem}} = \frac{1}{M} \sum_{j=1}^{M} \frac{1}{N} \sum_{i=1}^{N} \mathcal{L} \left( \mathbf{u}_i - \hat{\mathbf{u}}_{i}^{(j)} \right),\]

where \(\mathcal{L}\) is the selected loss function, typically the mean squared error.
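As a rough illustration of the objective above, the following plain-PyTorch sketch evaluates the per-model mean squared error and averages it over the ensemble. It does not call the solver itself: the toy data, the network sizes, and the helper name ensemble_loss are hypothetical and chosen only for this example.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy data: N = 8 samples s_i with 2 features, scalar targets u_i.
inputs = torch.rand(8, 2)
targets = inputs.sum(dim=1, keepdim=True)

# Hypothetical ensemble of M = 3 small networks M_j.
models = [
    torch.nn.Sequential(
        torch.nn.Linear(2, 16),
        torch.nn.Tanh(),
        torch.nn.Linear(16, 1),
    )
    for _ in range(3)
]

# MSELoss with the default 'mean' reduction plays the role of the
# inner average over the N samples.
loss_fn = torch.nn.MSELoss()


def ensemble_loss(models, inputs, targets, loss_fn):
    """Return the ensemble-averaged loss and the per-model losses."""
    # One scalar loss per model: (1/N) * sum_i L(u_i - M_j(s_i)).
    per_model = torch.stack([loss_fn(m(inputs), targets) for m in models])
    # Outer average over the M models.
    return per_model.mean(), per_model


total, per_model = ensemble_loss(models, inputs, targets, loss_fn)
```

Here per_model holds one loss value per ensemble member, and total is their arithmetic mean, matching the outer \(\frac{1}{M}\sum_j\) in the formula.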

Initialization of the SupervisedEnsembleSolver class.

Parameters:
  • problem (BaseProblem) – The problem to be solved.

  • models (torch.nn.Module | list[torch.nn.Module]) – The model or list of models used by the solver.

  • optimizers (TorchOptimizer | list[TorchOptimizer]) – The optimizer or list of optimizers used by the solver. If None, the torch.optim.Adam optimizer with a learning rate of 0.001 is used for each model. Default is None.

  • schedulers (TorchScheduler | list[TorchScheduler]) – The scheduler or list of schedulers used by the solver. If None, the torch.optim.lr_scheduler.ConstantLR scheduler with a factor of 1.0 is used for each model. Default is None.

  • weighting (BaseWeighting) – The weighting strategy used to combine condition losses. If None, no weighting is applied. Default is None.

  • loss – The loss function used to compute residual losses. If None, torch.nn.MSELoss is used. Default is None.

  • use_lt (bool) – If True, the solver uses LabelTensors as input. Default is True.
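The documented defaults can be mirrored in plain PyTorch to see what they amount to. The sketch below is an assumption-laden illustration rather than the solver's implementation: it builds raw torch optimizers and schedulers directly instead of the TorchOptimizer and TorchScheduler wrappers named above, and the models, data, and loop length are hypothetical.

```python
import torch

torch.manual_seed(0)

# Hypothetical ensemble and toy input-target data.
models = [torch.nn.Linear(2, 1) for _ in range(3)]
inputs = torch.rand(8, 2)
targets = inputs.sum(dim=1, keepdim=True)

# Defaults described in the parameter list, one instance per model:
# Adam with a learning rate of 0.001 ...
optimizers = [torch.optim.Adam(m.parameters(), lr=0.001) for m in models]
# ... a ConstantLR scheduler with factor 1.0, which leaves the rate unchanged ...
schedulers = [
    torch.optim.lr_scheduler.ConstantLR(opt, factor=1.0) for opt in optimizers
]
# ... and the mean squared error as the loss.
loss_fn = torch.nn.MSELoss()

# A few illustrative training steps, updating each model independently.
for _ in range(5):
    for model, opt, sched in zip(models, optimizers, schedulers):
        opt.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        opt.step()
        sched.step()

# Learning rate seen by each optimizer after the loop.
final_lrs = [sched.get_last_lr()[0] for sched in schedulers]
```

With a factor of 1.0 the scheduler is effectively a no-op, so each entry of final_lrs stays at the initial learning rate.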