"""Solver module."""

from abc import ABCMeta, abstractmethod
import lightning
import torch

from torch._dynamo import OptimizedModule
from ..problem import AbstractProblem, InverseProblem
from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
from ..loss import WeightingInterface
from ..loss.scalar_weighting import _NoWeighting
from ..utils import check_consistency, labelize_forward


class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
    """
    Abstract base class for PINA solvers. All specific solvers must inherit
    from this interface.

    This class extends :class:`~lightning.pytorch.core.LightningModule`,
    providing additional functionalities for defining and optimizing Deep
    Learning models. By inheriting from this base class, solvers gain access
    to built-in training loops, logging utilities, and optimization
    techniques.
    """

    def __init__(self, problem, weighting, use_lt):
        """
        Initialization of the :class:`SolverInterface` class.

        :param AbstractProblem problem: The problem to be solved.
        :param WeightingInterface weighting: The weighting schema to be used.
            If ``None``, no weighting schema is used. Default is ``None``.
        :param bool use_lt: If ``True``, the solver uses LabelTensors as
            input.
        """
        super().__init__()

        # check consistency of the problem
        check_consistency(problem, AbstractProblem)
        self._check_solver_consistency(problem)
        self._pina_problem = problem

        # check consistency of the weighting and hook the condition names
        if weighting is None:
            weighting = _NoWeighting()
        check_consistency(weighting, WeightingInterface)
        self._pina_weighting = weighting
        weighting.condition_names = list(self._pina_problem.conditions.keys())

        # check consistency of use_lt
        check_consistency(use_lt, bool)
        self._use_lt = use_lt

        # if use_lt is True, wrap the forward pass to extract the labeled
        # input variables and to label the output variables
        if use_lt is True:
            self.forward = labelize_forward(
                forward=self.forward,
                input_variables=problem.input_variables,
                output_variables=problem.output_variables,
            )

        # PINA private attributes (some are overridden by derived classes)
        self._pina_problem = problem
        self._pina_models = None
        self._pina_optimizers = None
        self._pina_schedulers = None

        # inverse problem handling
        if isinstance(self.problem, InverseProblem):
            self._params = self.problem.unknown_parameters
            self._clamp_params = self._clamp_inverse_problem_params
        else:
            self._params = None
            self._clamp_params = lambda: None
    @abstractmethod
    def forward(self, *args, **kwargs):
        """
        Abstract method for the forward pass implementation.

        :param args: The input tensor.
        :type args: torch.Tensor | LabelTensor | Data | Graph
        :param dict kwargs: Additional keyword arguments.
        """
    @abstractmethod
    def optimization_cycle(self, batch):
        """
        The optimization cycle for the solvers.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is
            a tuple containing a condition name and a dictionary of points.
        :return: The losses computed for all conditions in the batch, cast to
            a subclass of :class:`torch.Tensor`. The returned dictionary maps
            each condition name to the associated scalar loss.
        :rtype: dict
        """
    def training_step(self, batch, **kwargs):
        """
        Solver training step. It computes the optimization cycle and
        aggregates the losses using the ``weighting`` attribute.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is
            a tuple containing a condition name and a dictionary of points.
        :param dict kwargs: Additional keyword arguments passed to
            ``optimization_cycle``.
        :return: The loss of the training step.
        :rtype: torch.Tensor
        """
        loss = self._optimization_cycle(batch=batch, **kwargs)
        self.store_log("train_loss", loss, self.get_batch_size(batch))
        return loss
    def validation_step(self, batch, **kwargs):
        """
        Solver validation step. It computes the optimization cycle and
        averages the losses. No aggregation using the ``weighting`` attribute
        is performed.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is
            a tuple containing a condition name and a dictionary of points.
        :param dict kwargs: Additional keyword arguments passed to
            ``optimization_cycle``.
        :return: The loss of the validation step.
        :rtype: torch.Tensor
        """
        losses = self.optimization_cycle(batch=batch, **kwargs)
        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
        self.store_log("val_loss", loss, self.get_batch_size(batch))
        return loss
    def test_step(self, batch, **kwargs):
        """
        Solver test step. It computes the optimization cycle and averages the
        losses. No aggregation using the ``weighting`` attribute is performed.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is
            a tuple containing a condition name and a dictionary of points.
        :param dict kwargs: Additional keyword arguments passed to
            ``optimization_cycle``.
        :return: The loss of the test step.
        :rtype: torch.Tensor
        """
        losses = self.optimization_cycle(batch=batch, **kwargs)
        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
        self.store_log("test_loss", loss, self.get_batch_size(batch))
        return loss
    def store_log(self, name, value, batch_size):
        """
        Store the log of the solver.

        :param str name: The name of the log.
        :param torch.Tensor value: The value of the log.
        :param int batch_size: The size of the batch.
        """
        self.log(
            name=name,
            value=value,
            batch_size=batch_size,
            **self.trainer.logging_kwargs,
        )
    def setup(self, stage):
        """
        This method is called at the start of the fit and test stages. It
        compiles the model if the ``compile`` flag of the
        :class:`~pina.trainer.Trainer` is ``True``.
        """
        if stage == "fit" and self.trainer.compile:
            self._setup_compile()
        if stage == "test" and (
            self.trainer.compile and not self._is_compiled()
        ):
            self._setup_compile()
        return super().setup(stage)
    def _is_compiled(self):
        """
        Check if the model is compiled.

        :return: ``True`` if the model is compiled, ``False`` otherwise.
        :rtype: bool
        """
        for model in self._pina_models:
            if not isinstance(model, OptimizedModule):
                return False
        return True

    def _setup_compile(self):
        """
        Compile all models in the solver using ``torch.compile``.

        This method iterates through each model stored in the solver list and
        attempts to compile it for optimized execution. It supports models of
        type ``torch.nn.Module`` and ``torch.nn.ModuleDict``. For models
        stored in a ``ModuleDict``, each submodule is compiled individually.
        Models on Apple Silicon (MPS) use the ``eager`` backend, while others
        use ``inductor``.

        :raises RuntimeError: If a model is neither ``torch.nn.Module`` nor
            ``torch.nn.ModuleDict``.
        """
        for i, model in enumerate(self._pina_models):
            if isinstance(model, torch.nn.ModuleDict):
                for name, module in model.items():
                    self._pina_models[i][name] = self._compile_modules(module)
            elif isinstance(model, torch.nn.Module):
                self._pina_models[i] = self._compile_modules(model)
            else:
                raise RuntimeError(
                    "Compilation available only for "
                    "torch.nn.Module or torch.nn.ModuleDict."
                )

    def _check_solver_consistency(self, problem):
        """
        Check the consistency of the solver with the problem formulation.

        :param AbstractProblem problem: The problem to be solved.
        """
        for condition in problem.conditions.values():
            check_consistency(condition, self.accepted_conditions_types)

    def _optimization_cycle(self, batch, **kwargs):
        """
        Aggregate the loss for each condition in the batch.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is
            a tuple containing a condition name and a dictionary of points.
        :param dict kwargs: Additional keyword arguments.
        :return: The aggregated loss, cast to a subclass of
            :class:`torch.Tensor`.
        :rtype: torch.Tensor
        """
        # compute losses
        losses = self.optimization_cycle(batch)

        # clamp unknown parameters in InverseProblem (if needed)
        self._clamp_params()

        # store log
        for name, value in losses.items():
            self.store_log(
                f"{name}_loss", value.item(), self.get_batch_size(batch)
            )

        # aggregate
        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
        return loss

    def _clamp_inverse_problem_params(self):
        """
        Clamp the parameters of the inverse problem solver to the specified
        ranges.
        """
        for v in self._params:
            self._params[v].data.clamp_(
                self.problem.unknown_parameter_domain.range_[v][0],
                self.problem.unknown_parameter_domain.range_[v][1],
            )

    @staticmethod
    def _compile_modules(model):
        """
        Perform the compilation of the model.

        This method attempts to compile the given PyTorch model using
        ``torch.compile`` to improve execution performance. The backend is
        selected based on the device on which the model resides: ``eager`` is
        used for MPS devices (Apple Silicon), and ``inductor`` is used for all
        others. If compilation fails, the method prints the error and returns
        the original, uncompiled model.

        :param torch.nn.Module model: The model to compile.
        :return: The compiled model, or the original model if compilation
            fails.
        :rtype: torch.nn.Module
        """
        model_device = next(model.parameters()).device
        try:
            if model_device == torch.device("mps:0"):
                model = torch.compile(model, backend="eager")
            else:
                model = torch.compile(model, backend="inductor")
        except Exception as e:
            print("Compilation failed, running in normal mode:\n", e)
        return model
    @staticmethod
    def get_batch_size(batch):
        """
        Get the batch size.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is
            a tuple containing a condition name and a dictionary of points.
        :return: The size of the batch.
        :rtype: int
        """
        batch_size = 0
        for data in batch:
            batch_size += len(data[1]["input"])
        return batch_size
    @staticmethod
    def default_torch_optimizer():
        """
        Set the default optimizer to :class:`torch.optim.Adam`.

        :return: The default optimizer.
        :rtype: Optimizer
        """
        return TorchOptimizer(torch.optim.Adam, lr=0.001)
    @staticmethod
    def default_torch_scheduler():
        """
        Set the default scheduler to
        :class:`torch.optim.lr_scheduler.ConstantLR`.

        :return: The default scheduler.
        :rtype: Scheduler
        """
        return TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
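
    # Usage sketch (an assumption, not part of the interface): concrete
    # solvers can replace these defaults by passing wrapped torch objects,
    # built the same way as above, to their constructors, e.g.
    #
    #   optimizer = TorchOptimizer(torch.optim.SGD, lr=0.01, momentum=0.9)
    #   scheduler = TorchScheduler(
    #       torch.optim.lr_scheduler.StepLR, step_size=10
    #   )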

    @property
    def problem(self):
        """
        The problem instance.

        :return: The problem instance.
        :rtype: :class:`~pina.problem.abstract_problem.AbstractProblem`
        """
        return self._pina_problem

    @property
    def use_lt(self):
        """
        Whether LabelTensors are used as input during training.

        :return: The ``use_lt`` attribute.
        :rtype: bool
        """
        return self._use_lt

    @property
    def weighting(self):
        """
        The weighting schema.

        :return: The weighting schema.
        :rtype: :class:`~pina.loss.weighting_interface.WeightingInterface`
        """
        return self._pina_weighting
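

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the PINA API): the data layout expected by
# ``optimization_cycle``, ``training_step`` and ``get_batch_size``. The
# condition names, tensor shapes, and the toy loss below are assumptions made
# only for demonstration.
# ---------------------------------------------------------------------------
def _example_batch_layout():
    """Build a toy batch and show how the interface consumes it."""
    # a batch is a list of (condition_name, points) tuples, where points is a
    # dictionary holding at least the "input" tensor of that condition
    batch = [
        ("physics", {"input": torch.rand(16, 2)}),
        ("boundary", {"input": torch.rand(4, 2)}),
    ]
    # get_batch_size sums the number of input points over all conditions
    assert SolverInterface.get_batch_size(batch) == 20
    # optimization_cycle implementations must return one scalar loss per
    # condition; training_step then aggregates them via the weighting schema
    losses = {name: points["input"].pow(2).mean() for name, points in batch}
    return losses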


class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
    """
    Base class for PINA solvers using a single :class:`torch.nn.Module`.
    """

    def __init__(
        self,
        problem,
        model,
        optimizer=None,
        scheduler=None,
        weighting=None,
        use_lt=True,
    ):
        """
        Initialization of the :class:`SingleSolverInterface` class.

        :param AbstractProblem problem: The problem to be solved.
        :param torch.nn.Module model: The neural network model to be used.
        :param Optimizer optimizer: The optimizer to be used. If ``None``, the
            :class:`torch.optim.Adam` optimizer is used. Default is ``None``.
        :param Scheduler scheduler: The scheduler to be used. If ``None``, the
            :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used.
            Default is ``None``.
        :param WeightingInterface weighting: The weighting schema to be used.
            If ``None``, no weighting schema is used. Default is ``None``.
        :param bool use_lt: If ``True``, the solver uses LabelTensors as
            input.
        """
        if optimizer is None:
            optimizer = self.default_torch_optimizer()

        if scheduler is None:
            scheduler = self.default_torch_scheduler()

        super().__init__(problem=problem, use_lt=use_lt, weighting=weighting)

        # check consistency of the model argument
        check_consistency(model, torch.nn.Module)

        # check consistency of the scheduler argument
        check_consistency(scheduler, Scheduler)

        # check consistency of the optimizer argument
        check_consistency(optimizer, Optimizer)

        # initialize the model (wrapped in a ModuleList so that Lightning can
        # move it to different devices)
        self._pina_models = torch.nn.ModuleList([model])
        self._pina_optimizers = [optimizer]
        self._pina_schedulers = [scheduler]
    def forward(self, x):
        """
        Forward pass implementation.

        :param x: Input tensor.
        :type x: torch.Tensor | LabelTensor | Graph | Data
        :return: Solver solution.
        :rtype: torch.Tensor | LabelTensor | Graph | Data
        """
        return self.model(x)
    def configure_optimizers(self):
        """
        Optimizer configuration for the solver.

        :return: The optimizer and the scheduler
        :rtype: tuple[list[Optimizer], list[Scheduler]]
        """
        self.optimizer.hook(self.model.parameters())
        if isinstance(self.problem, InverseProblem):
            self.optimizer.instance.add_param_group(
                {
                    "params": [
                        self._params[var]
                        for var in self.problem.unknown_variables
                    ]
                }
            )
        self.scheduler.hook(self.optimizer)
        return ([self.optimizer.instance], [self.scheduler.instance])

    @property
    def model(self):
        """
        The model used for training.

        :return: The model used for training.
        :rtype: torch.nn.Module
        """
        return self._pina_models[0]

    @property
    def scheduler(self):
        """
        The scheduler used for training.

        :return: The scheduler used for training.
        :rtype: Scheduler
        """
        return self._pina_schedulers[0]

    @property
    def optimizer(self):
        """
        The optimizer used for training.

        :return: The optimizer used for training.
        :rtype: Optimizer
        """
        return self._pina_optimizers[0]
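

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the PINA API): a minimal concrete solver
# built on SingleSolverInterface. The class name, the "target" key, and the
# plain MSE loss are assumptions made only for demonstration; real conditions
# may expose different fields.
# ---------------------------------------------------------------------------
class _ExampleSupervisedSolver(SingleSolverInterface):
    """Toy solver that fits the model output to supervised targets."""

    # accept any condition type in this sketch; real solvers restrict this
    accepted_conditions_types = (object,)

    def optimization_cycle(self, batch):
        # return one scalar loss per condition, as required by the interface
        return {
            name: torch.nn.functional.mse_loss(
                self.forward(points["input"]), points["target"]
            )
            for name, points in batch
        }


# Usage sketch (hypothetical problem and model): construction mirrors the
# defaults above, with the optimizer overridden through the TorchOptimizer
# wrapper.
#
#   solver = _ExampleSupervisedSolver(
#       problem=my_problem,
#       model=torch.nn.Linear(2, 1),
#       optimizer=TorchOptimizer(torch.optim.Adam, lr=0.01),
#   )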


class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
    """
    Base class for PINA solvers using multiple :class:`torch.nn.Module`.
    """

    def __init__(
        self,
        problem,
        models,
        optimizers=None,
        schedulers=None,
        weighting=None,
        use_lt=True,
    ):
        """
        Initialization of the :class:`MultiSolverInterface` class.

        :param AbstractProblem problem: The problem to be solved.
        :param models: The neural network models to be used.
        :type models: list[torch.nn.Module] | tuple[torch.nn.Module]
        :param list[Optimizer] optimizers: The optimizers to be used.
            If ``None``, the :class:`torch.optim.Adam` optimizer is used for
            all models. Default is ``None``.
        :param list[Scheduler] schedulers: The schedulers to be used.
            If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
            scheduler is used for all models. Default is ``None``.
        :param WeightingInterface weighting: The weighting schema to be used.
            If ``None``, no weighting schema is used. Default is ``None``.
        :param bool use_lt: If ``True``, the solver uses LabelTensors as
            input.
        :raises ValueError: If the models are not a list or tuple with length
            greater than one.

        .. warning::
            :class:`MultiSolverInterface` uses manual optimization by setting
            ``automatic_optimization=False`` in
            :class:`~lightning.pytorch.core.LightningModule`. For more
            information on manual optimization please see `here
            <https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html>`_.
        """
        if not isinstance(models, (list, tuple)) or len(models) < 2:
            raise ValueError(
                "models should be list[torch.nn.Module] or "
                "tuple[torch.nn.Module] with len greater than one."
            )

        if optimizers is None:
            optimizers = [
                self.default_torch_optimizer() for _ in range(len(models))
            ]

        if schedulers is None:
            schedulers = [
                self.default_torch_scheduler() for _ in range(len(models))
            ]

        if any(opt is None for opt in optimizers):
            optimizers = [
                self.default_torch_optimizer() if opt is None else opt
                for opt in optimizers
            ]

        if any(sched is None for sched in schedulers):
            schedulers = [
                self.default_torch_scheduler() if sched is None else sched
                for sched in schedulers
            ]

        super().__init__(problem=problem, use_lt=use_lt, weighting=weighting)

        # check consistency of the models argument
        check_consistency(models, torch.nn.Module)

        # check consistency of the schedulers argument
        check_consistency(schedulers, Scheduler)

        # check consistency of the optimizers argument
        check_consistency(optimizers, Optimizer)

        # check that the number of optimizers matches the number of models
        if len(models) != len(optimizers):
            raise ValueError(
                "You must define one optimizer for each model. "
                f"Got {len(models)} models, and {len(optimizers)}"
                " optimizers."
            )
        if len(schedulers) != len(optimizers):
            raise ValueError(
                "You must define one scheduler for each optimizer. "
                f"Got {len(schedulers)} schedulers, and {len(optimizers)}"
                " optimizers."
            )

        # initialize the models
        self._pina_models = torch.nn.ModuleList(models)
        self._pina_optimizers = optimizers
        self._pina_schedulers = schedulers

        # Set automatic optimization to False.
        # For more information on manual optimization see:
        # https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
        self.automatic_optimization = False
    def on_train_batch_end(self, outputs, batch, batch_idx):
        """
        This method is called at the end of each training batch. It overrides
        the PyTorch Lightning implementation to keep the manual optimization
        step counter in sync, so that loggers and checkpoints are updated
        correctly.

        :param torch.Tensor outputs: The ``model``'s output for the current
            batch.
        :param list[tuple[str, dict]] batch: A batch of data. Each element is
            a tuple containing a condition name and a dictionary of points.
        :param int batch_idx: The index of the current batch.
        """
        # increment the manual optimization step counter so that loggers and
        # checkpoints are saved at the correct step
        epoch_loop = self.trainer.fit_loop.epoch_loop
        epoch_loop.manual_optimization.optim_step_progress.total.completed += 1
        return super().on_train_batch_end(outputs, batch, batch_idx)
    def configure_optimizers(self):
        """
        Optimizer configuration for the solver.

        :return: The optimizer and the scheduler
        :rtype: tuple[list[Optimizer], list[Scheduler]]
        """
        for optimizer, scheduler, model in zip(
            self.optimizers, self.schedulers, self.models
        ):
            optimizer.hook(model.parameters())
            scheduler.hook(optimizer)
        return (
            [optimizer.instance for optimizer in self.optimizers],
            [scheduler.instance for scheduler in self.schedulers],
        )

    @property
    def models(self):
        """
        The models used for training.

        :return: The models used for training.
        :rtype: torch.nn.ModuleList
        """
        return self._pina_models

    @property
    def optimizers(self):
        """
        The optimizers used for training.

        :return: The optimizers used for training.
        :rtype: list[Optimizer]
        """
        return self._pina_optimizers

    @property
    def schedulers(self):
        """
        The schedulers used for training.

        :return: The schedulers used for training.
        :rtype: list[Scheduler]
        """
        return self._pina_schedulers
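

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the PINA API): what a training step of a
# MultiSolverInterface subclass may look like under manual optimization. The
# class name, the toy loss, and the model composition are assumptions made
# only for demonstration; the sketch steps the raw torch optimizers directly
# and ignores Trainer features such as gradient accumulation or mixed
# precision.
# ---------------------------------------------------------------------------
class _ExampleTwoModelSolver(MultiSolverInterface):
    """Toy two-model solver showing the manual optimization pattern."""

    # accept any condition type in this sketch; real solvers restrict this
    accepted_conditions_types = (object,)

    def forward(self, x):
        # compose the two models; real solvers define their own composition
        return self.models[1](self.models[0](x))

    def optimization_cycle(self, batch):
        # one toy scalar loss per condition
        return {
            name: self.forward(points["input"]).pow(2).mean()
            for name, points in batch
        }

    def training_step(self, batch, **kwargs):
        # automatic_optimization is False, so the backward pass and the
        # optimizer steps are performed by hand
        loss = self._optimization_cycle(batch=batch, **kwargs)
        for optimizer in self.optimizers:
            optimizer.instance.zero_grad()
        self.manual_backward(loss)
        for optimizer in self.optimizers:
            optimizer.instance.step()
        self.store_log("train_loss", loss, self.get_batch_size(batch))
        return loss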