Source code for pina.callback.refinement.r3_refinement

"""Module for the R3Refinement callback."""

import torch
from .refinement_interface import RefinementInterface
from ...label_tensor import LabelTensor
from ...utils import check_consistency
from ...loss import LossInterface


class R3Refinement(RefinementInterface):
    """
    PINA Implementation of the R3 Refinement Callback.

    This callback implements the R3 (Retain-Resample-Release) routine for
    sampling new points based on adaptive search. The algorithm incrementally
    accumulates collocation points in regions of high PDE residuals, and
    releases those with low residuals. Points are sampled uniformly in all
    regions where sampling is needed.

    .. seealso::

        Original Reference: Daw, Arka, et al. *Mitigating Propagation
        Failures in Physics-informed Neural Networks using
        Retain-Resample-Release (R3) Sampling. (2023)*.
        DOI: `10.48550/arXiv.2207.02338
        <https://doi.org/10.48550/arXiv.2207.02338>`_

    :Example:
        >>> r3_callback = R3Refinement(sample_every=5)
    """

    def __init__(
        self,
        sample_every,
        residual_loss=torch.nn.L1Loss,
        condition_to_update=None,
    ):
        """
        Initialization of the :class:`R3Refinement` callback.

        :param int sample_every: The sampling frequency.
        :param residual_loss: The loss *class* (not an instance) used to
            compute the residuals. It is instantiated here with
            ``reduction="none"`` so that one value per entry is kept.
            Default is :class:`~torch.nn.L1Loss`.
        :type residual_loss: LossInterface |
            :class:`~torch.nn.modules.loss._Loss`
        :param condition_to_update: The conditions to update during the
            refinement process. If None, all conditions will be updated.
            Default is None.
        :type condition_to_update: list(str) | tuple(str) | str
        :raises ValueError: If the condition_to_update is neither a string nor
            an iterable of strings.

        .. note::
            ``residual_loss`` is validated through ``check_consistency`` as a
            subclass of ``LossInterface`` or
            :class:`~torch.nn.modules.loss._Loss`; an invalid value makes
            that helper raise.
        """
        super().__init__(sample_every, condition_to_update)

        # Check consistency
        check_consistency(
            residual_loss,
            (LossInterface, torch.nn.modules.loss._Loss),
            subclass=True,
        )

        # Save loss function (unreduced, so residuals stay pointwise)
        self.loss_fn = residual_loss(reduction="none")

    def sample(self, current_points, condition_name, solver):
        """
        Sample new points based on the R3 refinement strategy.

        Points whose residual is strictly above the mean residual are
        retained; the remaining slots, up to the initial population size of
        the condition, are filled with freshly sampled points. If no point is
        above the mean, the whole population is resampled.

        :param current_points: The current points in the domain.
        :type current_points: LabelTensor | torch.Tensor
        :param str condition_name: The name of the condition to update.
        :param PINNInterface solver: The solver using this callback.
        :return: The new samples generated by the R3 strategy.
        :rtype: LabelTensor
        """
        # Retrieve condition and current points
        device = solver.trainer.strategy.root_device
        condition = solver.problem.conditions[condition_name]
        current_points = current_points.to(device).requires_grad_(True)

        # Compute pointwise residuals for the given condition
        target = solver.compute_residual(current_points, condition.equation)
        residuals = self.loss_fn(target, torch.zeros_like(target))

        # Average over all fields, i.e. every dimension but the first.
        # The reduction is skipped for 1-D residuals: an empty ``dim`` tuple
        # would make torch reduce over *all* dimensions, collapsing the
        # per-point residuals into a single scalar and disabling retention.
        if target.ndim > 1:
            residuals = residuals.mean(dim=tuple(range(1, target.ndim)))

        # Retrieve domain and initial population size
        # NOTE(review): ``initial_population_size`` is not set in this class;
        # presumably populated by RefinementInterface -- confirm.
        domain = solver.problem.domains[condition.domain]
        num_old_points = self.initial_population_size[condition_name]

        # Select points with residual above the mean
        mask = (residuals > residuals.mean()).flatten()
        if mask.any():
            high_residual_pts = current_points[mask]
            # Re-attach the labels to the masked selection
            high_residual_pts.labels = current_points.labels
            # Top up with new samples to restore the population size
            samples = domain.sample(num_old_points - len(high_residual_pts))
            return LabelTensor.cat([high_residual_pts, samples.to(device)])

        # No point stands out: release everything and resample from scratch
        return domain.sample(num_old_points, "random")