Source code for pina._src.solver.base_solver

"""Module for the base solver class."""

import copy
from abc import ABCMeta

import lightning
import torch

from pina._src.core.utils import labelize_forward, check_consistency
from pina._src.solver.solver_interface import SolverInterface
from pina._src.weighting.base_weighting import BaseWeighting
from pina._src.problem.inverse_problem import InverseProblem
from pina._src.optim.torch_optimizer import TorchOptimizer
from pina._src.optim.torch_scheduler import TorchScheduler
from pina._src.weighting.no_weighting import _NoWeighting
from pina._src.problem.base_problem import BaseProblem
from pina._src.loss.base_dual_loss import BaseDualLoss


class BaseSolver(SolverInterface, metaclass=ABCMeta):
    """
    Base class for all solvers, implementing common functionality.

    All solvers must inherit from this class and implement abstract methods
    defined in :class:`~pina.solver.solver_interface.SolverInterface`.

    This class is not meant to be instantiated directly.
    """

    # Define the available reductions for loss computation
    _AVAILABLE_REDUCTIONS = {
        "none": lambda x: x,
        "mean": lambda x: x.mean(),
        "sum": lambda x: x.sum(),
    }

    def __init__(self, problem, use_lt=True):
        """
        Initialization of the :class:`BaseSolver` class.

        :param BaseProblem problem: The problem to be solved.
        :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
            Default is ``True``.
        :raises ValueError: If ``use_lt`` is not a boolean.
        :raises ValueError: If ``problem`` is not an instance of
            :class:`~pina.problem.base_problem.BaseProblem`.
        :raises ValueError: If one or more problem conditions are not
            supported by the solver.
        """
        # Reset the solver state. This runs before the LightningModule
        # initializer, which is fine because reset() only assigns ``None``.
        self.reset()

        # Call the parent class initializer
        lightning.pytorch.LightningModule.__init__(self)

        # Check consistency
        check_consistency(use_lt, bool)
        check_consistency(problem, BaseProblem)
        for condition in problem.conditions.values():
            check_consistency(condition, self.accepted_conditions_types)

        # Initialize the solver components
        self._pina_problem = problem
        self._use_lt = use_lt

        # Manage InverseProblem parameters if needed
        if isinstance(self.problem, InverseProblem):
            self._params = self.problem.unknown_parameters
            self._clamp_params = self._clamp_inverse_problem_params
        else:
            self._params = None
            self._clamp_params = lambda: None

        # Labelize the forward method if using LabelTensors
        if self.use_lt:
            self.forward = labelize_forward(
                forward=self.forward,
                input_variables=problem.input_variables,
                output_variables=problem.output_variables,
            )

    def reset(self):
        """
        Reset the internal solver state, clearing the stored problem, models,
        optimizers and schedulers.
        """
        self._pina_problem = None
        self._pina_models = None
        self._pina_optimizers = None
        self._pina_schedulers = None

    def _clamp_inverse_problem_params(self):
        """
        Clamp the unknown parameters of an inverse problem.

        Each unknown parameter is constrained in place to lie within the
        corresponding bounds defined by the inverse problem parameter domain.
        """
        # The bounds do not change inside the loop, so look them up once
        bounds = self.problem.unknown_parameter_domain.range
        for name in self._params:
            self._params[name].data.clamp_(bounds[name][0], bounds[name][1])

    def _init_weighting_and_loss(self, weighting=None, loss=None):
        """
        Initialize the weighting strategy and loss function.

        The solver computes element-wise losses and applies the reduction
        itself. If the loss exposes a ``reduction`` attribute, its value is
        stored and a private copy of the loss is configured with
        ``reduction="none"``. The caller's loss object is never modified.

        :param BaseWeighting weighting: The weighting strategy used to combine
            condition losses. If ``None``, no weighting is applied.
            Default is ``None``.
        :param loss: The loss function used to compute residual losses.
            If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``.
        :type loss: torch.nn.Module | BaseDualLoss
        :raises ValueError: If ``weighting`` is not an instance of
            :class:`~pina.weighting.base_weighting.BaseWeighting`.
        :raises ValueError: If ``loss`` is not a valid PyTorch loss or
            :class:`~pina.loss.base_dual_loss.BaseDualLoss`.
        """
        # If no weighting schema is provided, use a default no-weighting schema
        if weighting is None:
            weighting = _NoWeighting()

        # Set default loss function to MSE if not provided
        if loss is None:
            loss = torch.nn.MSELoss()

        # Check consistency
        check_consistency(weighting, BaseWeighting)
        check_consistency(loss, (BaseDualLoss, torch.nn.modules.loss._Loss))

        # Remember the requested reduction; the solver applies it itself
        self._reduction = getattr(loss, "reduction", "mean")

        # Switch a private copy to element-wise output. Mutating the caller's
        # instance instead would make any later solver built from the same
        # loss read "none" as its reduction and return non-scalar losses.
        if hasattr(loss, "reduction"):
            loss = copy.deepcopy(loss)
            loss.reduction = "none"

        # Store the weighting and loss function for use in the solver
        self._pina_weighting = weighting
        weighting._solver = self
        self._loss_fn = loss

    def _init_solver_components(
        self,
        models,
        optimizers=None,
        schedulers=None,
    ):
        """
        Initialize the solver models, optimizers and schedulers.

        :param models: The model or list of models used by the solver.
        :type models: torch.nn.Module | list[torch.nn.Module]
        :param optimizers: The optimizer or list of optimizers used by the
            solver. If ``None``, the ``torch.optim.Adam`` optimizer with a
            learning rate of ``0.001`` is used for each model.
            Default is ``None``.
        :type optimizers: TorchOptimizer | list[TorchOptimizer]
        :param schedulers: The scheduler or list of schedulers used by the
            solver. If ``None``, the ``torch.optim.lr_scheduler.ConstantLR``
            scheduler with a factor of ``1.0`` is used for each model.
            Default is ``None``.
        :type schedulers: TorchScheduler | list[TorchScheduler]
        :raises ValueError: If ``models`` are not instances of
            :class:`torch.nn.Module`.
        :raises ValueError: If ``optimizers`` are not instances of
            :class:`~pina.optim.torch_optimizer.TorchOptimizer`.
        :raises ValueError: If ``schedulers`` are not instances of
            :class:`~pina.optim.torch_scheduler.TorchScheduler`.
        :raises ValueError: If the number of optimizers does not match that of
            models.
        :raises ValueError: If the number of schedulers does not match that of
            models.
        """

        def _to_list(value):
            # Wrap a single item in a list; lists and tuples pass through
            return value if isinstance(value, (list, tuple)) else [value]

        # Map models to list if a single model is provided
        models = _to_list(models)

        # Set default optimizers to Adam if not provided
        if optimizers is None:
            optimizers = [
                TorchOptimizer(torch.optim.Adam, lr=0.001)
                for _ in range(len(models))
            ]

        # Set default schedulers to ConstantLR if not provided
        if schedulers is None:
            schedulers = [
                TorchScheduler(torch.optim.lr_scheduler.ConstantLR, factor=1.0)
                for _ in range(len(models))
            ]

        # Map optimizers and schedulers to lists if single items are provided
        optimizers = _to_list(optimizers)
        schedulers = _to_list(schedulers)

        # Check consistency
        check_consistency(optimizers, TorchOptimizer)
        check_consistency(schedulers, TorchScheduler)
        check_consistency(models, torch.nn.Module)

        # Check that the number of optimizers matches the number of models
        if len(optimizers) != len(models):
            raise ValueError(
                "You must define one optimizer for each model. "
                f"Got {len(models)} models, and {len(optimizers)} optimizers."
            )

        # Check that the number of schedulers matches the number of models
        if len(schedulers) != len(models):
            raise ValueError(
                "You must define one scheduler for each model. "
                f"Got {len(models)} models, and {len(schedulers)} schedulers."
            )

        # Initialize the solver components
        self._pina_models = torch.nn.ModuleList(models)
        self._pina_optimizers = optimizers
        self._pina_schedulers = schedulers

    def _evaluate_and_log(self, batch, batch_idx, log_name):
        """
        Evaluate the loss on a batch and log it under the given name.

        Shared implementation of the training, validation and test steps.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
            tuple containing a condition name and a dictionary of points.
        :param int batch_idx: The index of the current batch.
        :param str log_name: The name under which the loss is logged.
        :return: The loss of the batch.
        :rtype: torch.Tensor
        """
        loss = self.batch_evaluation_step(batch=batch, batch_idx=batch_idx)
        self.log(
            name=log_name,
            value=loss.item(),
            batch_size=self.get_batch_size(batch),
            **self.trainer.logging_kwargs,
        )
        return loss

    def training_step(self, batch, batch_idx):
        """
        Solver training step.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
            tuple containing a condition name and a dictionary of points.
        :param int batch_idx: The index of the current batch.
        :return: The loss of the training step.
        :rtype: torch.Tensor
        """
        return self._evaluate_and_log(batch, batch_idx, "train_loss")

    def validation_step(self, batch, batch_idx):
        """
        Solver validation step.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
            tuple containing a condition name and a dictionary of points.
        :param int batch_idx: The index of the current batch.
        :return: The loss of the validation step.
        :rtype: torch.Tensor
        """
        return self._evaluate_and_log(batch, batch_idx, "val_loss")

    def test_step(self, batch, batch_idx):
        """
        Solver test step.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
            tuple containing a condition name and a dictionary of points.
        :param int batch_idx: The index of the current batch.
        :return: The loss of the test step.
        :rtype: torch.Tensor
        """
        return self._evaluate_and_log(batch, batch_idx, "test_loss")

    def _compute_condition_loss(self, condition, data, batch_idx):
        """
        Compute the scalar loss for a given condition and its data.

        :param BaseCondition condition: The condition for which to compute the
            loss.
        :param dict data: The data corresponding to the condition.
        :param int batch_idx: The index of the current batch.
        :return: The scalar loss for the condition.
        :rtype: torch.Tensor
        """
        # Clone the input tensor if it exists to avoid in-place modifications;
        # the dict is shallow-copied so the caller's mapping is left intact.
        if "input" in data and hasattr(data["input"], "clone"):
            data = dict(data)
            data["input"] = data["input"].clone()

        # Prepare condition data, e.g. by enabling gradient for regularizations
        data = self._prepare_condition_data(data=data)

        # Compute and store the residual tensor for the condition
        self.residual_tensor = condition.evaluate(data, self)

        # Retrieve condition name for more complex weighting schemes
        condition_name = getattr(condition, "name", None)

        # Compute the tensor loss from the residual tensor
        condition_tensor_loss = self._loss_from_residual(condition_name)

        # Optional regularization hook, e.g. gradient-enhanced or residual-based
        condition_tensor_loss = self._regularize_condition_loss(
            condition_tensor_loss=condition_tensor_loss,
            condition_name=condition_name,
            data=data,
            batch_idx=batch_idx,
        )

        # Compute the scalar loss from the tensor loss and return it
        return self._apply_reduction(condition_tensor_loss)

    def _prepare_condition_data(self, data):
        """
        Prepare the condition data for loss computation.

        This method can be overridden by mixins to implement specific data
        preparation steps, such as enabling gradient tracking for inputs in
        gradient-enhanced solvers. The base implementation returns ``data``
        unchanged.

        :param dict data: The original condition data.
        :return: The prepared condition data.
        :rtype: dict
        """
        return data

    def _regularize_condition_loss(
        self,
        condition_tensor_loss,
        condition_name,
        data,
        batch_idx,
    ):
        """
        Regularize the condition loss if needed.

        This method can be overridden by mixins to implement specific
        regularization strategies, such as adding a gradient penalty in
        gradient-enhanced solvers or applying residual-based attention. The
        base implementation returns the loss unchanged.

        :param condition_tensor_loss: The original tensor loss for the
            condition.
        :type condition_tensor_loss: torch.Tensor | LabelTensor
        :param str condition_name: The name of the condition.
        :param dict data: The data corresponding to the condition.
        :param int batch_idx: The index of the current batch.
        :return: The regularized tensor loss for the condition.
        :rtype: torch.Tensor | LabelTensor
        """
        return condition_tensor_loss

    def _loss_from_residual(self, condition_name=None):
        """
        Compute the tensor loss from the residual tensor.

        The residual is compared against a zero target of the same shape.
        No reduction is applied here; see :meth:`_apply_reduction`.

        :param str condition_name: The name of the condition.
        :return: The tensor loss computed from the residual tensor.
        :rtype: torch.Tensor | LabelTensor
        """
        # Compute the (unreduced) loss tensor against a zero target
        return self._loss_fn(
            self.residual_tensor, torch.zeros_like(self.residual_tensor)
        )

    def _apply_reduction(self, value):
        """
        Apply the specified reduction to the loss tensor.

        :param value: The loss tensor to reduce.
        :type value: torch.Tensor | LabelTensor
        :return: The reduced loss.
        :rtype: torch.Tensor | LabelTensor
        :raises ValueError: If the stored reduction is not one of
            ``_AVAILABLE_REDUCTIONS``.
        """
        # Get the reduction function based on the specified reduction type
        reduction_fn = self._AVAILABLE_REDUCTIONS.get(self._reduction)

        # If the reduction type is not supported, raise an error
        if reduction_fn is None:
            raise ValueError(
                f"Unsupported reduction '{self._reduction}'. "
                f"Available options include {list(self._AVAILABLE_REDUCTIONS)}"
            )

        return reduction_fn(value)

    @staticmethod
    def get_batch_size(batch):
        """
        Get the batch size.

        The batch size is the total number of ``input`` points summed over all
        conditions in the batch.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
            tuple containing a condition name and a dictionary of points.
        :return: The size of the batch.
        :rtype: int
        """
        return sum(len(data[1]["input"]) for data in batch)

    @property
    def problem(self):
        """
        The problem instance.

        :return: The problem instance.
        :rtype: :class:`~pina.problem.base_problem.BaseProblem`
        """
        return self._pina_problem

    @property
    def use_lt(self):
        """
        Using LabelTensors as input during training.

        :return: The use_lt attribute.
        :rtype: bool
        """
        return self._use_lt

    @property
    def weighting(self):
        """
        The weighting schema used by the solver.

        :return: The weighting schema used by the solver.
        :rtype: :class:`~pina.weighting.base_weighting.BaseWeighting`
        """
        return self._pina_weighting

    @property
    def loss(self):
        """
        The element-wise loss module used by the solver.

        If the loss passed at initialization exposes a ``reduction``
        attribute, this is the solver's own copy of it, configured with
        ``reduction="none"``.

        :return: The element-wise loss module used by the solver.
        :rtype: torch.nn.Module | BaseDualLoss
        """
        return self._loss_fn