Refinement callbacks
- class R3Refinement(sample_every, residual_loss=<class 'torch.nn.modules.loss.L1Loss'>, condition_to_update=None)
Bases: RefinementInterface
PINA implementation of an R3 refinement callback.
This callback implements the R3 (Retain-Resample-Release) routine, which adaptively updates the collocation points during training. Points in regions of high PDE residual are retained, so they accumulate there over successive refinement steps, while points with low residual are released. The released points are replaced by new points sampled uniformly over the domain.
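The retain/release/resample step described above can be sketched in plain Python. This is a minimal sketch after the algorithm in Daw et al., not PINA's implementation: it assumes the retention threshold is the mean residual (as in the R3 paper), represents points as plain tuples rather than PINA `LabelTensor`s, and uses hypothetical helper names (`r3_resample`, `uniform_box_sampler`, `toy_residual`).

```python
import random
from typing import Callable, List, Sequence, Tuple

Point = Tuple[float, ...]


def r3_resample(
    points: Sequence[Point],
    residuals: Sequence[float],
    sample_uniform: Callable[[int], List[Point]],
) -> List[Point]:
    """One hypothetical Retain-Resample-Release step (sketch, not PINA code).

    Points with residual above the mean residual are retained; all others
    are released and replaced by uniformly sampled points, so the size of
    the population does not change.
    """
    if len(points) != len(residuals):
        raise ValueError("points and residuals must have the same length")
    if not points:
        return []
    threshold = sum(residuals) / len(residuals)  # assumed: mean residual
    # Retain: keep only the high-residual points.
    retained = [p for p, r in zip(points, residuals) if r > threshold]
    # Release + resample: refill the population uniformly over the domain.
    return retained + sample_uniform(len(points) - len(retained))


def uniform_box_sampler(
    bounds: Sequence[Tuple[float, float]], rng: random.Random
) -> Callable[[int], List[Point]]:
    """Hypothetical sampler drawing points uniformly from an axis-aligned box."""

    def sample(n: int) -> List[Point]:
        return [tuple(rng.uniform(lo, hi) for lo, hi in bounds) for _ in range(n)]

    return sample


def toy_residual(p: Point) -> float:
    """Hypothetical residual magnitude with a sharp peak at x = 0.8."""
    return 1.0 / (1e-3 + (p[0] - 0.8) ** 2)


if __name__ == "__main__":
    rng = random.Random(0)
    sampler = uniform_box_sampler([(0.0, 1.0)], rng)
    pts = sampler(200)
    for _ in range(20):  # 20 refinement steps
        pts = r3_resample(pts, [toy_residual(p) for p in pts], sampler)
    near_peak = sum(abs(p[0] - 0.8) < 0.1 for p in pts)
    print(f"{len(pts)} points, {near_peak} within 0.1 of the peak")
```

Because retained points survive each step while released ones are redrawn uniformly, repeated calls should concentrate points near the residual peak. For comparison, uniform sampling alone would place about 20% of the 200 points (roughly 40) within 0.1 of x = 0.8.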
See also
Original reference: Daw, Arka, et al. "Mitigating Propagation Failures in Physics-informed Neural Networks using Retain-Resample-Release (R3) Sampling." (2023). DOI: 10.48550/arXiv.2207.02338
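- Parameters:
sample_every (int) – Refinement frequency: new points are sampled every sample_every training epochs.
residual_loss (type[torch.nn.Module]) – Loss class used to measure the residual at each point. Default: torch.nn.L1Loss.
condition_to_update (str | Iterable[str] | None) – Name, or iterable of names, of the condition(s) whose points are refined. Default: None.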
- Raises:
ValueError – If the condition_to_update is not a string or iterable of strings.
TypeError – If the residual_loss is not a subclass of torch.nn.Module.
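The choice of residual_loss changes which points count as "high residual". The sketch below is an assumption about how a pointwise loss feeds the retention test, not PINA's code: it mirrors what torch.nn.L1Loss(reduction="none") and torch.nn.MSELoss(reduction="none") compute against a zero target, averages over output fields, and keeps points above the mean score as in the R3 paper. It uses plain Python, and all function names are hypothetical.

```python
from typing import Callable, List, Sequence


# Hypothetical stand-ins for torch.nn.L1Loss / torch.nn.MSELoss applied with
# reduction="none" against a zero target: one penalty per residual entry.
def l1_penalty(r: float) -> float:
    return abs(r)


def squared_penalty(r: float) -> float:
    return r * r


def point_scores(
    residuals: Sequence[Sequence[float]], penalty: Callable[[float], float]
) -> List[float]:
    """Average the elementwise penalty over output fields: one score per point."""
    return [sum(penalty(r) for r in row) / len(row) for row in residuals]


def retain_mask(scores: Sequence[float]) -> List[bool]:
    """Mark points whose score exceeds the mean score (assumed R3 threshold)."""
    mean = sum(scores) / len(scores)
    return [s > mean for s in scores]


if __name__ == "__main__":
    # Four points, one output field each (hypothetical residual values).
    residuals = [[0.0], [2.0], [-2.0], [3.0]]
    print(retain_mask(point_scores(residuals, l1_penalty)))       # [False, True, True, True]
    print(retain_mask(point_scores(residuals, squared_penalty)))  # [False, False, False, True]
```

With the L1 penalty three of the four points clear the mean score of 1.75; squaring lets the largest residual dominate the mean (4.25), so only that point is retained.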
Example
>>> r3_callback = R3Refinement(sample_every=5)
- sample(current_points, condition_name, solver)
Sample new points based on the R3 refinement strategy.
- Parameters:
current_points – Current points in the domain.
condition_name – Name of the condition to update.
solver (PINNInterface) – The solver object.
- Returns:
New points sampled based on the R3 strategy.
- Return type: