Adaptive Refinement callbacks
- class R3Refinement(sample_every)
Bases: Callback
PINA implementation of the R3 refinement callback.
This callback implements the R3 (Retain-Resample-Release) routine for adaptively resampling collocation points. The algorithm incrementally accumulates collocation points in regions of high PDE residual and releases those with low residual. New points are sampled uniformly in all regions where sampling is needed.
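A single retain-resample-release step can be sketched in plain Python. This is a minimal, framework-free illustration, not PINA's implementation. The helper `r3_step` and its `residual` and `sample_uniform` callables are hypothetical. The retain criterion (residual strictly above the population mean) follows the R3 paper by Daw et al. rather than anything stated on this page.

```python
import random
import statistics
from typing import Callable, List, Tuple

Point = Tuple[float, ...]


def r3_step(
    points: List[Point],
    residual: Callable[[Point], float],
    sample_uniform: Callable[[], Point],
) -> List[Point]:
    """One Retain-Resample-Release step on a fixed-size population.

    Hypothetical helper: `residual` returns the PDE residual at a point and
    `sample_uniform` draws one point uniformly from the domain.
    """
    residuals = [abs(residual(p)) for p in points]
    threshold = statistics.mean(residuals)
    # Retain: keep only points whose residual exceeds the population mean.
    retained = [p for p, r in zip(points, residuals) if r > threshold]
    # Release the others and resample uniformly so the population size is unchanged.
    # If no residual exceeds the mean, the whole population is resampled.
    resampled = [sample_uniform() for _ in range(len(points) - len(retained))]
    return retained + resampled


if __name__ == "__main__":
    random.seed(0)

    def peak(p: Point) -> float:
        # Toy residual sharply peaked at x = 0.5 on the unit interval.
        return 1.0 / (0.01 + (p[0] - 0.5) ** 2)

    def draw() -> Point:
        return (random.random(),)

    population = [draw() for _ in range(200)]
    for _ in range(20):
        population = r3_step(population, peak, draw)
    near = sum(abs(p[0] - 0.5) < 0.1 for p in population)
    print(f"{near} of {len(population)} points lie within 0.1 of the peak")
```

Because the population size is held fixed, repeated steps concentrate points where the residual is large, while the uniformly resampled points keep covering the rest of the domain.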
See also
Original Reference: Daw, Arka, et al. Mitigating Propagation Failures in Physics-informed Neural Networks using Retain-Resample-Release (R3) Sampling. (2023). DOI: 10.48550/arXiv.2207.02338
- Parameters:
sample_every (int) – Sampling frequency, in epochs. The R3 routine runs at the end of every epoch that is a multiple of sample_every.
- Raises:
ValueError – If sample_every is not an integer.
Example
>>> r3_callback = R3Refinement(sample_every=5)
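The `ValueError` contract for `sample_every` can be illustrated with a minimal stand-in. `R3RefinementSketch` is a hypothetical class, not part of PINA. The plain `isinstance` check, and the choice to reject `bool` as well, are assumptions about how such validation might be written.

```python
class R3RefinementSketch:
    """Hypothetical stand-in that mirrors only the constructor contract."""

    def __init__(self, sample_every):
        # Reject non-integers. bool is rejected too because Python treats
        # True/False as int subclasses (a choice made for this sketch).
        if isinstance(sample_every, bool) or not isinstance(sample_every, int):
            raise ValueError(
                f"sample_every must be an int, got {type(sample_every).__name__}"
            )
        self._sample_every = sample_every


ok = R3RefinementSketch(sample_every=5)  # accepted
try:
    R3RefinementSketch(sample_every=2.5)  # rejected: not an integer
except ValueError as err:
    print(err)
```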
- on_train_start(trainer, _)
Callback function called at the start of training.
This method extracts the sampling locations from the problem conditions and computes the total population of collocation points.
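The start-of-training bookkeeping can be sketched as follows. The representation is hypothetical: each condition optionally names a sampling `location`, and `input_pts` maps condition names to their current points. `Condition` and `collect_population` are illustrative names, not PINA API. Storing one count per location is an assumption that lets a later resampling step restore each location to its original size.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class Condition:
    # Hypothetical condition: `location` names the domain to sample from,
    # or is None for conditions backed by fixed data (never resampled).
    location: Optional[str] = None


def collect_population(
    conditions: Dict[str, Condition],
    input_pts: Dict[str, List[Tuple[float, ...]]],
) -> Tuple[List[str], Dict[str, int]]:
    """Return the sampling locations and the point count to keep for each."""
    locations = [name for name, cond in conditions.items() if cond.location is not None]
    population = {name: len(input_pts[name]) for name in locations}
    return locations, population


conds = {"interior": Condition("D"), "boundary": Condition("gamma"), "data": Condition()}
pts = {"interior": [(0.1,), (0.2,), (0.3,)], "boundary": [(0.0,), (1.0,)], "data": [(0.5,)]}
locs, pop = collect_population(conds, pts)
print(locs, pop, sum(pop.values()))
```

Here the total population is 5: the "data" condition has no sampling location and is left out.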
- Parameters:
trainer (pytorch_lightning.Trainer) – The trainer object managing the training process.
_ – Placeholder argument (not used).
- Returns:
None
- Return type:
None
- on_train_epoch_end(trainer, __)
Callback function called at the end of each training epoch.
This method runs the R3 refinement routine if the current epoch is a multiple of _sample_every.
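The epoch check reduces to a modulo test. The helper `refinement_epochs` is hypothetical. It assumes a 0-based epoch counter, under which the very first epoch (epoch 0) also triggers refinement.

```python
from typing import Iterable, List


def refinement_epochs(sample_every: int, epochs: Iterable[int]) -> List[int]:
    """Epochs at whose end the R3 routine would run (epoch % sample_every == 0)."""
    return [epoch for epoch in epochs if epoch % sample_every == 0]


print(refinement_epochs(5, range(12)))  # epochs 0, 5 and 10
```

With `sample_every=5` over twelve epochs the routine fires three times. A non-positive `sample_every` would not give a sensible schedule (zero raises `ZeroDivisionError` in Python).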
- Parameters:
trainer (pytorch_lightning.Trainer) – The trainer object managing the training process.
__ – Placeholder argument (not used).
- Returns:
None
- Return type:
None