Source code for pina._src.solver.mixin.gradient_enhanced_mixin

"""Module for the gradient-enhanced mixin class."""

import torch
from pina._src.problem.spatial_problem import SpatialProblem
from pina._src.core.utils import check_consistency
from pina._src.core.operator import grad


[docs] class GradientEnhancedMixin: """ Mixin that augments residual losses with a gradient-based regularization term. The additional penalty is the norm of the residual gradient with respect to the spatial input variables. It is only applied to the conditions whose names are listed in ``regularized_conditions``. Designed to be used in combination with any solver inheriting from :class:`~pina._src.solver.base_solver.BaseSolver` and using :class:`~pina._src.core.tensor.label_tensor.LabelTensor` as input. """ def _init_gradient_enhanced_components( self, regularization_weight=1.0, regularized_conditions=None ): """ Initialize the gradient-enhancement parameters. :param regularization_weight: The weight of the gradient regularization term. Default is ``1.0``. :type regularization_weight: float | int :param regularized_conditions: The names of the conditions that should receive gradient regularization. If ``None``, all conditions are regularized. Default is ``None``. :type regularized_conditions: str | list[str] :raises ValueError: If ``regularization_weight`` is not a float or int. :raises ValueError: If ``regularized_conditions`` is not a string or a list of strings. :raises ValueError: If ``problem`` is not an instance of :class:`~pina._src.problem.spatial_problem.SpatialProblem`. :raises ValueError: If the solver's input data are not instances of :class:`~pina._src.core.tensor.label_tensor.LabelTensor`. :raises ValueError: If any of the specified ``regularized_conditions`` are not present in the ``problem``'s conditions. """ # Use all conditions if regularized_conditions is None if regularized_conditions is None: regularized_conditions = list(self.problem.conditions.keys()) # Check consistency check_consistency(regularization_weight, (float, int)) check_consistency(regularized_conditions, str) # Map conditions to list if a single condition is provided if not isinstance(regularized_conditions, (list, tuple)): regularized_conditions = [regularized_conditions] # Assert the problem is instance of SpatialProblem if not isinstance(self.problem, SpatialProblem): raise ValueError( "Gradient-enhanced regularization requires the problem to be " f"an instance of SpatialProblem. Got {type(self.problem)}." ) # Ensure that the solver is using LabelTensors as input if not self.use_lt: raise ValueError( "Gradient-enhanced regularization requires the solver to use " f"LabelTensors as input. Got use_lt={self.use_lt}." ) # Ensure that all regularized conditions are present in the problem problem_conditions = set(self.problem.conditions.keys()) for condition in regularized_conditions: if condition not in problem_conditions: raise ValueError( f"Condition '{condition}' is not present in the problem." ) # Initialize the gradient-enhancement parameters self.regularization_weight = regularization_weight self.regularized_conditions = regularized_conditions def _prepare_condition_data(self, data): """ Prepare the condition data for loss computation. This method can be overridden by mixins to implement specific data preparation steps, such as enabling gradient tracking for inputs in gradient-enhanced solvers. :param dict data: The original condition data. :return: The prepared condition data. :rtype: dict """ # If data does not require grad, force requires_grad to True if "input" in data and not data["input"].requires_grad: data["input"].requires_grad_(True) return data def _regularize_condition_loss( self, condition_tensor_loss, condition_name, data, batch_idx, ): """ Regularize the condition loss if needed. This method can be overridden by mixins to implement specific regularization strategies, such as adding a gradient penalty in gradient-enhanced solvers or applying residual-based attention. :param condition_tensor_loss: The original tensor loss for the condition. :type condition_tensor_loss: torch.Tensor | LabelTensor :param str condition_name: The name of the condition. :param dict data: The data corresponding to the condition. :param int batch_idx: The index of the current batch. :return: The regularized tensor loss for the condition. :rtype: torch.Tensor | LabelTensor """ # Regularize the loss with the gradient penalty if needed if condition_name in self.regularized_conditions: # Apply labels to the residual tensor for gradient computation self.residual_tensor.labels = [ f"res_{i}" for i in range(self.residual_tensor.shape[1]) ] # Compute the gradient of the residual with respect to spatial input residual_gradient = grad( output_=self.residual_tensor, input_=data["input"], d=self.problem.spatial_variables, ) # Compute the norm of the residual gradient residual_gradient_norm = self._loss_fn( residual_gradient, torch.zeros_like(residual_gradient) ) # Compute the gradient penalty term penalty = self.regularization_weight * residual_gradient_norm # Add the gradient penalty to the original condition tensor loss condition_tensor_loss = condition_tensor_loss + penalty return condition_tensor_loss