# Source code for pina._src.solver.mixin.residual_based_attention_mixin
"""Module for the residual-based attention mixin class."""
import torch
from pina._src.core.utils import check_consistency
from pina._src.condition.domain_equation_condition import (
DomainEquationCondition,
)
# [docs]
class ResidualBasedAttentionMixin:
    """
    Mixin that augments the residual loss with an attention mechanism based on
    the residual values.

    One attention weight is stored per point of every regularized condition.
    At each call the weights of the points in the current batch are updated
    from the normalized residual norms and then used to rescale the
    point-wise loss of that condition, so that points with persistently
    large residuals contribute more to the loss.

    Designed to be used in combination with any solver inheriting from
    :class:`~pina._src.solver.base_solver.BaseSolver`.

    The host class is expected to provide ``self.problem`` (with a
    ``conditions`` mapping), ``self.register_buffer`` (presumably from
    :class:`torch.nn.Module` -- confirm against the solver base class) and
    ``self.residual_tensor`` holding the residuals of the current batch.
    """

    def _init_residual_attention_components(
        self, eta=0.001, gamma=0.999, regularized_conditions=None
    ):
        """
        Initialize the residual-based attention parameters.

        Sets ``self.eta``, ``self.gamma``, ``self.regularized_conditions`` and
        ``self.weight_buffers`` (mapping condition name -> buffer attribute
        name), and registers one zero-initialized buffer of shape
        ``(n_pts, 1)`` named ``weight_<condition>`` per regularized condition.

        :param eta: The learning rate for the residual-based attention weights
            update. Default is ``0.001``.
        :type eta: float | int
        :param float gamma: The decay factor for the residual-based attention
            mechanism. Default is ``0.999``.
        :param regularized_conditions: The names of the conditions that should
            receive attention regularization. If ``None``, all conditions are
            regularized. Default is ``None``.
        :type regularized_conditions: str | list[str] | None
        :raises ValueError: If ``eta`` is not positive.
        :raises ValueError: If ``gamma`` is not in the range (0, 1).
        :raises ValueError: If any of the specified ``regularized_conditions``
            are not present in the ``problem``'s conditions.

        Type validation of the arguments is delegated to
        ``check_consistency`` (presumably raising ``ValueError`` on a type
        mismatch -- confirm against its implementation).
        """
        # Use all conditions if regularized_conditions is None
        if regularized_conditions is None:
            regularized_conditions = list(self.problem.conditions.keys())

        # Check consistency
        check_consistency(eta, (float, int))
        check_consistency(gamma, float)
        check_consistency(regularized_conditions, str)

        # Assert gamma is in range (0, 1)
        if not 0 < gamma < 1:
            raise ValueError("gamma must be in range (0, 1)")

        # Assert eta is positive
        if eta <= 0:
            raise ValueError("eta must be positive")

        # Map conditions to list if a single condition is provided
        if not isinstance(regularized_conditions, (list, tuple)):
            regularized_conditions = [regularized_conditions]

        # Ensure that all regularized conditions are present in the problem
        problem_conditions = set(self.problem.conditions.keys())
        for condition in regularized_conditions:
            if condition not in problem_conditions:
                raise ValueError(
                    f"Condition '{condition}' is not present in the problem."
                )

        # Initialize residual-based attention parameters
        self.regularized_conditions = regularized_conditions
        self.gamma = gamma
        self.eta = eta
        self.weight_buffers = {}

        # Iterate over all conditions to initialize the attention weights
        for cond in self.regularized_conditions:
            # Get the condition object
            condition = self.problem.conditions[cond]

            # Determine the number of points in the condition: equation
            # conditions on a domain read it from the discretised domain,
            # every other condition from its input data
            if isinstance(condition, DomainEquationCondition):
                n_pts = self.problem._discretised_domains[cond].shape[0]
            else:
                n_pts = condition.data.input.shape[0]

            # Register the attention weights as a buffer in the module and
            # remember the attribute name under which they can be retrieved
            buffer_name = f"weight_{cond}"
            self.register_buffer(buffer_name, torch.zeros((n_pts, 1)))
            self.weight_buffers[cond] = buffer_name

    def _regularize_condition_loss(
        self,
        condition_tensor_loss,
        condition_name,
        data,
        batch_idx,
    ):
        """
        Regularize the condition loss if needed. This method can be overridden
        by mixins to implement specific regularization strategies, such as
        adding a gradient penalty in gradient-enhanced solvers or applying
        residual-based attention.

        For a regularized condition, the attention weights ``w`` of the points
        in the current batch are updated in place as
        ``w <- gamma * w + eta * r / (max(r) + 1e-12)``, where ``r`` is the
        row-wise 2-norm of ``self.residual_tensor``, and the loss is multiplied
        by the updated weights. Conditions that are not regularized, and empty
        batches, are returned unchanged.

        :param condition_tensor_loss: The original tensor loss for the
            condition.
        :type condition_tensor_loss: torch.Tensor | LabelTensor
        :param str condition_name: The name of the condition.
        :param dict data: The data corresponding to the condition. Not used by
            this implementation.
        :param int batch_idx: The index of the current batch.
        :return: The regularized tensor loss for the condition.
        :rtype: torch.Tensor | LabelTensor
        """
        # Nothing to do for conditions without attention weights
        if condition_name not in self.regularized_conditions:
            return condition_tensor_loss

        # An empty batch has no residual to normalize (the max-reduction
        # below would fail), so the loss is passed through untouched
        len_residuals = self.residual_tensor.shape[0]
        if len_residuals == 0:
            return condition_tensor_loss

        # Get the weights buffer for the current condition
        weights = getattr(self, self.weight_buffers[condition_name])

        # Locate the rows of the buffer belonging to this batch, wrapping
        # around the end of the buffer with modular indexing.
        # NOTE(review): the offset is batch_idx * (size of the current batch),
        # so it matches the true position only when batches are visited in
        # order and all have the same size; a smaller final batch or shuffled
        # batches land on other rows -- confirm against the data loader.
        total_points = weights.shape[0]
        start = (batch_idx * len_residuals) % total_points
        idx = torch.arange(
            start, start + len_residuals, device=weights.device
        )
        idx = idx % total_points

        # The residual statistics only feed the buffer update, so they are
        # computed without tracking gradients to avoid building an autograd
        # graph that is never back-propagated through
        with torch.no_grad():
            res_abs = torch.linalg.vector_norm(
                self.residual_tensor, ord=2, dim=1, keepdim=True
            )
            res_norm = res_abs / (res_abs.max() + 1e-12)
            weights[idx] = self.gamma * weights[idx] + self.eta * res_norm

        # Weight the condition tensor loss with attention weights
        return condition_tensor_loss * weights[idx]