Source code for pina._src.solver.mixin.ensemble_mixin
"""Module for the ensemble mixin class."""
import torch
from pina._src.solver.base_solver import BaseSolver
from pina._src.solver.mixin.multi_model_mixin import MultiModelMixin
[docs]
class EnsembleMixin(MultiModelMixin):
    """
    Mixin supplying the forward pass and the per-condition loss for solvers
    that train an ensemble of models.

    The mixin reads ``self.models`` and ``self.num_models``; neither is
    defined in this class (presumably both are provided by
    :class:`~pina._src.solver.mixin.multi_model_mixin.MultiModelMixin`).
    It is meant to be combined with a solver inheriting from
    :class:`~pina._src.solver.base_solver.BaseSolver`.
    """

    def forward(self, x):
        """
        Evaluate the ensemble on ``x``.

        When the ``_active_model_idx`` attribute holds an index, only the
        model at that index is evaluated and its output is returned as is.
        When the attribute is missing or ``None``, every model is evaluated
        and the outputs are combined with :func:`torch.stack`.

        :param x: The input data.
        :type x: torch.Tensor | LabelTensor | Data | Graph
        :return: The output of the active model, or the outputs of all models
            stacked together when no model is active.
        :rtype: torch.Tensor | LabelTensor | Data | Graph
        """
        # A missing attribute is treated the same as "no active model"
        selected = getattr(self, "_active_model_idx", None)

        # No model selected: run the whole ensemble and stack the results
        if selected is None:
            outputs = [self.models[i](x) for i in range(self.num_models)]
            return torch.stack(outputs)

        # A model is selected: evaluate that one only
        return self.models[selected](x)

    def _compute_condition_loss(self, condition, data, batch_idx):
        """
        Compute the ensemble-averaged loss for a condition and its data.

        Each model is made the active one in turn and
        ``BaseSolver._compute_condition_loss`` is called for it; the
        resulting losses are stacked and averaged. The value that
        ``_active_model_idx`` had on entry (``None`` if it was unset) is
        written back afterwards, even if a loss computation raises.

        :param BaseCondition condition: The condition for which to compute the
            loss.
        :param dict data: The data corresponding to the condition.
        :param int batch_idx: The index of the current batch.
        :return: The mean of the per-model losses for the condition.
        :rtype: torch.Tensor
        """
        # Remember the caller's selection so it can be put back on exit
        saved_idx = getattr(self, "_active_model_idx", None)
        per_model = []

        try:
            for current in range(self.num_models):
                # Select one model at a time.
                # NOTE(review): this presumably works because the base loss
                # routine calls ``forward``, which dispatches on
                # ``_active_model_idx`` -- confirm against BaseSolver.
                self._active_model_idx = current
                per_model.append(
                    BaseSolver._compute_condition_loss(
                        self, condition, data, batch_idx
                    )
                )
        finally:
            # Runs on success and on error alike
            self._active_model_idx = saved_idx

        return torch.stack(per_model).mean()