Source code for pina._src.solver.mixin.multi_model_mixin

"""Module for the multi-model mixin class."""

import torch
from pina._src.problem.inverse_problem import InverseProblem


class MultiModelMixin:
    """
    Mixin that defines the forward pass and optimizer configuration for
    solvers backed by multiple models.

    Provides properties to access the models, optimizers, and schedulers.
    Designed to be used in combination with any solver inheriting from
    :class:`~pina._src.solver.base_solver.BaseSolver`.

    The host class must set the ``_pina_models``, ``_pina_optimizers`` and
    ``_pina_schedulers`` attributes, which back the read-only properties
    :attr:`models`, :attr:`optimizers` and :attr:`schedulers`.
    """

    def forward(self, x):
        """
        The forward pass implementation that evaluates all models and returns
        a stacked tensor of their outputs.

        Each model is applied to the same input ``x``, and the outputs are
        stacked with :func:`torch.stack` along a new leading dimension, so
        entry ``i`` of the result is the output of ``self.models[i]``.

        :param x: The input data.
        :type x: torch.Tensor | LabelTensor | Data | Graph
        :return: The output of all models stacked together.
        :rtype: torch.Tensor | LabelTensor | Data | Graph
        """
        return torch.stack([model(x) for model in self.models])

    def configure_optimizers(self):
        """
        Configure the optimizers and schedulers for all models.

        The ``i``-th optimizer is hooked to the parameters of the ``i``-th
        model, and the ``i``-th scheduler is hooked to the ``i``-th
        optimizer. If the problem is an
        :class:`~pina._src.problem.inverse_problem.InverseProblem`, the
        unknown parameters are added to each optimizer as an extra parameter
        group.

        :return: The optimizer and the scheduler
        :rtype: tuple[list[TorchOptimizer], list[TorchScheduler]]
        :raises ValueError: If the number of optimizers or schedulers does
            not match the number of models.
        """
        num_models = self.num_models
        num_optimizers = len(self.optimizers)
        num_schedulers = len(self.schedulers)
        # ``zip`` below would silently drop trailing items on a length
        # mismatch, leaving some models without an optimizer (or some
        # optimizers without a scheduler); fail loudly instead.
        if num_optimizers != num_models or num_schedulers != num_models:
            raise ValueError(
                "Expected one optimizer and one scheduler per model, got "
                f"{num_models} models, {num_optimizers} optimizers and "
                f"{num_schedulers} schedulers."
            )

        # Resolve the unknown parameters of an inverse problem once, rather
        # than on every loop iteration.
        unknown_params = None
        if isinstance(self.problem, InverseProblem):
            unknown_params = [
                self._params[var] for var in self.problem.unknown_variables
            ]

        # Iterate over models, optimizers, and schedulers to hook them together
        for optimizer, scheduler, model in zip(
            self.optimizers, self.schedulers, self.models
        ):
            # Hook the optimizer to the model parameters
            optimizer.hook(model.parameters())

            # Add parameter group for inverse problems if needed.
            # NOTE(review): the same unknown parameters are registered in
            # every optimizer, so each optimizer step updates them; confirm
            # this is intended when there is more than one model.
            if unknown_params is not None:
                # A fresh dict/list per optimizer: add_param_group stores and
                # mutates the group dict it receives.
                optimizer.instance.add_param_group(
                    {"params": list(unknown_params)}
                )

            # Hook the scheduler to the optimizer
            scheduler.hook(optimizer)

        return (
            [optimizer.instance for optimizer in self.optimizers],
            [scheduler.instance for scheduler in self.schedulers],
        )

    @property
    def models(self):
        """
        The models used by the solver.

        :return: The models used by the solver.
        :rtype: list[torch.nn.Module]
        """
        return self._pina_models

    @property
    def optimizers(self):
        """
        The optimizers used by the solver.

        :return: The optimizers used by the solver.
        :rtype: list[TorchOptimizer]
        """
        return self._pina_optimizers

    @property
    def schedulers(self):
        """
        The schedulers used by the solver.

        :return: The schedulers used by the solver.
        :rtype: list[TorchScheduler]
        """
        return self._pina_schedulers

    @property
    def num_models(self):
        """
        The number of models used by the solver.

        :return: The number of models used by the solver.
        :rtype: int
        """
        return len(self.models)