Refinement callbacks

class R3Refinement(sample_every, residual_loss=<class 'torch.nn.modules.loss.L1Loss'>, condition_to_update=None)

Bases: RefinementInterface

PINA Implementation of the R3 Refinement Callback.

This callback implements the R3 (Retain-Resample-Release) routine for sampling new points based on adaptive search. The algorithm incrementally accumulates collocation points in regions of high PDE residuals, and releases those with low residuals. Points are sampled uniformly in all regions where sampling is needed.
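The retain/resample/release cycle described above can be sketched in plain Python on a 1-D domain. This is a minimal illustration, not PINA's implementation. It assumes three things: the retention threshold is the mean residual, the total number of points stays constant, and new points are drawn uniformly from an interval `[low, high]`. The helper `r3_step` is hypothetical.

```python
import random
from statistics import mean


def r3_step(points, residuals, low, high, rng):
    """Hypothetical R3 step on the 1-D interval [low, high].

    Retain:   keep points whose residual exceeds the mean residual (assumed threshold).
    Release:  drop the remaining low-residual points.
    Resample: draw as many uniform points as were released, so the count is unchanged.
    """
    threshold = mean(residuals)
    retained = [p for p, r in zip(points, residuals) if r > threshold]
    n_released = len(points) - len(retained)
    resampled = [rng.uniform(low, high) for _ in range(n_released)]
    return retained + resampled


rng = random.Random(0)
points = [0.1, 0.2, 0.3, 0.4]
residuals = [0.0, 1.0, 0.0, 3.0]  # mean is 1.0, so only 0.4 is retained
new_points = r3_step(points, residuals, 0.0, 1.0, rng)
print(new_points)
```

Retained points carry over from one call to the next. Repeated steps therefore accumulate points wherever the residual stays high.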

See also

Original Reference: Daw, Arka, et al. Mitigating Propagation Failures in Physics-informed Neural Networks using Retain-Resample-Release (R3) Sampling. (2023). DOI: 10.48550/arXiv.2207.02338

Example:

>>> r3_callback = R3Refinement(sample_every=5)

Initialization of the R3Refinement callback.

Parameters:
  • sample_every (int) – The sampling frequency.

  • residual_loss (LossInterface | _Loss) – The loss function used to compute the residuals. Default is L1Loss.

  • condition_to_update (list(str) | tuple(str) | str) – The conditions to update during the refinement process. If None, all conditions will be updated. Default is None.

Raises:
  • ValueError – If the condition_to_update is neither a string nor an iterable of strings.

  • TypeError – If the residual_loss is not a subclass of torch.nn.Module.
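The documented rules for condition_to_update can be illustrated with a small validation helper. This is a sketch under assumptions, not PINA's code. The helper name `normalize_conditions` is hypothetical, and the condition names `"interior"` and `"boundary"` are invented for the example. It accepts only the documented types (a str, or a list or tuple of str) and treats None as "all conditions".

```python
def normalize_conditions(condition_to_update, available_conditions):
    """Hypothetical helper mirroring the documented condition_to_update rules."""
    if condition_to_update is None:
        # None means every condition is refined
        return list(available_conditions)
    if isinstance(condition_to_update, str):
        # A single condition name is wrapped in a list
        return [condition_to_update]
    if isinstance(condition_to_update, (list, tuple)) and all(
        isinstance(name, str) for name in condition_to_update
    ):
        # A list or tuple of names is normalized to a list
        return list(condition_to_update)
    # Anything else matches the documented ValueError case
    raise ValueError(
        "condition_to_update must be a string or a list/tuple of strings"
    )


print(normalize_conditions(None, ["interior", "boundary"]))  # ['interior', 'boundary']
print(normalize_conditions("interior", ["interior", "boundary"]))  # ['interior']
```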

sample(current_points, condition_name, solver)

Sample new points based on the R3 refinement strategy.

Parameters:
  • current_points (LabelTensor | torch.Tensor) – The current points in the domain.

  • condition_name (str) – The name of the condition to update.

  • solver (PINNInterface) – The solver using this callback.

Returns:

The new samples generated by the R3 strategy.

Return type:

LabelTensor
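To decide which points to retain, the sampling step needs one residual score per collocation point. The sketch below shows one plausible scoring scheme under an assumption: the default L1Loss is applied point-wise. That is equivalent to `torch.nn.L1Loss(reduction="none")` against a zero target, followed by averaging over the output components. The helper `pointwise_l1_scores` is hypothetical and uses plain lists instead of tensors.

```python
def pointwise_l1_scores(residual_rows):
    """Hypothetical per-point score: mean absolute residual over output components."""
    return [sum(abs(value) for value in row) / len(row) for row in residual_rows]


# Each row holds the PDE residual components at one collocation point
scores = pointwise_l1_scores([[1.0, -3.0], [0.5, 0.5]])
print(scores)  # [2.0, 0.5]
```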