AutomaticShiftSnapshots
Module for the AutomaticShiftSnapshots plugin.
- class AutomaticShiftSnapshots(shift_network, interp_network, interpolator, parameter_index=0, reference_index=0, barycenter_loss=0)
Bases: Plugin
The plugin implements the automatic “shifting” preprocessing. Using a machine learning framework, it detects the quantity by which each snapshot in the database should be shifted so that the reduction method performs better on the problem at hand.
Reference: Papapicco, D., Demo, N., Girfoglio, M., Stabile, G., & Rozza, G. (2022). The Neural Network shifted-proper orthogonal decomposition: A machine learning approach for non-linear reduction of hyperbolic equations. Computer Methods in Applied Mechanics and Engineering, 392, 114687.
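To see why shifting helps, consider a toy travelling pulse. The NumPy sketch below is illustrative only. The pulse, the exact shift `params - params[reference_index]` and the helper names (`pulse`, `shift_snapshot`, `numerical_rank`) are hypothetical. The known shift stands in for what `shift_network` would learn. Linear interpolation with zero fill plays the role of `Linear(fill_value=0.0)`. Once the snapshots are aligned onto the reference space, the snapshot matrix collapses to numerical rank 1, whereas the unshifted snapshots need several modes.

```python
import numpy as np

# Hypothetical toy problem: a Gaussian pulse travelling with the parameter mu
x = np.linspace(0.0, 2.0, 1001)
params = np.linspace(0.2, 0.8, 20)

def pulse(mu):
    return np.exp(-((x - mu) ** 2) / 0.005)

snapshots = np.array([pulse(mu) for mu in params])  # shape (n_snapshots, n_points)

def shift_snapshot(values, shift, reference_x):
    """Move the snapshot's grid by -shift and interpolate it onto reference_x.

    Points outside the shifted grid get 0.0, similar in spirit to Linear(fill_value=0.0).
    """
    return np.interp(reference_x, x - shift, values, left=0.0, right=0.0)

reference_index = 0
# Exact shift for this toy problem; in the plugin this quantity is learned by shift_network
shifts = params - params[reference_index]
aligned = np.array([shift_snapshot(s, d, x) for s, d in zip(snapshots, shifts)])

def numerical_rank(matrix, tol=1e-2):
    """Count singular values larger than tol times the largest one."""
    sv = np.linalg.svd(matrix, compute_uv=False)
    return int(np.sum(sv / sv[0] > tol))

print("unshifted rank:", numerical_rank(snapshots))
print("aligned rank:  ", numerical_rank(aligned))
```

This low-rank structure after alignment is what allows a POD with very few modes, such as `POD(rank=1)` in the example, to represent a transport-dominated problem.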
- Parameters:
shift_network – the neural network that learns the shifting quantity for each snapshot, given the corresponding input parameter.
interp_network – the neural network used to interpolate the reference snapshot.
interpolator (Approximation) – the interpolator used to evaluate the shifted snapshots on the reference space.
parameter_index (int) – for a multi-dimensional parameter, the index of the parameter component passed to the shift network. Default is 0.
reference_index (int) – the index of the snapshot within the database whose space is used as the reference space. Default is 0.
barycenter_loss (float) – weight of the barycenter loss term. Default is 0.
Example:
>>> import torch
>>> from ezyrb import POD, RBF, Database, Snapshot, Parameter, Linear, ANN
>>> from ezyrb import ReducedOrderModel as ROM
>>> from ezyrb.plugin import AutomaticShiftSnapshots
>>> interp = ANN([10, 10], torch.nn.Softplus(), 1000, frequency_print=50, lr=0.03)
>>> shift = ANN([], torch.nn.LeakyReLU(), [2000, 1e-3], frequency_print=50, l2_regularization=0, lr=0.002)
>>> nnspod = AutomaticShiftSnapshots(shift, interp, Linear(fill_value=0.0), barycenter_loss=10.)
>>> pod = POD(rank=1)
>>> rbf = RBF()
>>> db = Database()
>>> # params (training parameters) and wave(param) -> (space, values) are user-supplied
>>> for param in params:
...     space, values = wave(param)
...     snap = Snapshot(values=values, space=space)
...     db.add(Parameter(param), snap)
>>> rom = ROM(db, pod, rbf, plugins=[nnspod])
>>> rom.fit()
Initialize the AutomaticShiftSnapshots plugin.
- Parameters:
shift_network – Neural network for learning the shift function.
interp_network – Neural network for interpolation.
interpolator (Approximation) – Interpolator for shifted snapshots evaluation.
parameter_index (int) – Index of parameter component. Default is 0.
reference_index (int) – Index of reference snapshot. Default is 0.
barycenter_loss (float) – Weight for barycenter loss term. Default is 0.
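The source does not spell out how the barycenter term is computed. One plausible reading, sketched below purely as an assumption, compares the barycenter (centre of mass) of each snapshot with the barycenter of the reference snapshot moved by the candidate shift. `barycenter_loss` would then weight this penalty. The names `barycenter`, `barycenter_penalty` and the toy `pulse` are hypothetical, not EZyRB functions.

```python
import numpy as np

x = np.linspace(0.0, 2.0, 1001)

def pulse(mu):
    # Hypothetical travelling Gaussian pulse centred at mu
    return np.exp(-((x - mu) ** 2) / 0.005)

def barycenter(values):
    """Centre of mass of a 1D field sampled on the uniform grid x."""
    return np.sum(x * values) / np.sum(values)

def barycenter_penalty(shift, values, ref_values, weight):
    """Assumed barycenter term: the reference moved by `shift` should share
    the snapshot's barycenter. `weight` plays the role of barycenter_loss."""
    return weight * (barycenter(values) - (barycenter(ref_values) + shift)) ** 2

reference = pulse(0.5)
snapshot = pulse(0.7)
print(barycenter_penalty(0.2, snapshot, reference, weight=10.0))  # correct shift
print(barycenter_penalty(0.0, snapshot, reference, weight=10.0))  # no shift
```

With the correct shift (0.2) the penalty is essentially zero; with no shift it equals 10 × 0.2² = 0.4. A term of this kind gives the optimiser a smooth signal even when the pulses barely overlap.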
- _train_interp_network()
Train the interpolation network on the reference snapshot.
- _train_shift_network(db)
Train the shift network using the database snapshots.
- Parameters:
db (Database) – The database containing snapshots.
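The training step can be illustrated with a brute-force stand-in. It assumes the shift is chosen so that the interpolated reference snapshot, evaluated at shifted coordinates, reproduces each snapshot. The sketch below uses `np.interp` in place of the trained `interp_network` and a grid search over candidate shifts in place of gradient-based training of `shift_network`. All names (`pulse`, `reference_interp`, `shift_loss`, `fit_shift`, `CANDIDATES`) are hypothetical.

```python
import numpy as np

x = np.linspace(0.0, 2.0, 1001)

def pulse(mu):
    # Hypothetical travelling Gaussian pulse centred at mu
    return np.exp(-((x - mu) ** 2) / 0.005)

reference_values = pulse(0.5)  # the snapshot selected by reference_index

def reference_interp(points):
    # Stand-in for the trained interp_network: a continuous surrogate of the reference
    return np.interp(points, x, reference_values, left=0.0, right=0.0)

def shift_loss(shift, values):
    # Mismatch between the snapshot and the reference evaluated at shifted coordinates
    return np.mean((reference_interp(x - shift) - values) ** 2)

CANDIDATES = np.linspace(-1.0, 1.0, 2001)  # candidate shifts, spacing 0.001

def fit_shift(values):
    # Brute-force search in place of gradient-based training of shift_network
    losses = [shift_loss(s, values) for s in CANDIDATES]
    return CANDIDATES[int(np.argmin(losses))]

for mu in (0.3, 0.5, 0.72):
    print(mu, fit_shift(pulse(mu)))
```

For each snapshot the recovered shift is close to `mu - 0.5`, the offset of its pulse from the reference. A shift network is meant to learn this quantity as a function of the parameter, so that it can also be evaluated at parameters outside the database.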
- fit_preprocessing(rom)
Execute before the fit process begins.
- Parameters:
rom (ReducedOrderModel) – The ROM instance.
- predict_postprocessing(rom)
Execute after the prediction process completes.
- Parameters:
rom (ReducedOrderModel) – The ROM instance.
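One way to picture the two hooks is as a round trip. Before the reduction, the training snapshots are moved onto the reference space; after the prediction, the result is moved back by the shift of the new parameter. The NumPy sketch below is an assumption about this flow, not the EZyRB source. `translate` and `shift_fn` are hypothetical helpers, the shift is exact rather than learned, and averaging the aligned snapshots stands in for the POD + RBF prediction.

```python
import numpy as np

x = np.linspace(0.0, 2.0, 1001)

def pulse(mu):
    # Hypothetical travelling Gaussian pulse centred at mu
    return np.exp(-((x - mu) ** 2) / 0.005)

def translate(values, shift):
    """Translate a field by +shift on the fixed grid x (zero fill outside)."""
    return np.interp(x - shift, x, values, left=0.0, right=0.0)

mu_ref = 0.5

def shift_fn(mu):
    # Hypothetical learned shift; exact for this toy problem
    return mu - mu_ref

# Preprocessing (sketch): move every training snapshot onto the reference frame
train_params = np.linspace(0.3, 0.7, 9)
aligned = np.array([translate(pulse(mu), -shift_fn(mu)) for mu in train_params])

# Stand-in for POD + RBF: aligned snapshots coincide, so their mean is the prediction
predicted_aligned = aligned.mean(axis=0)

# Postprocessing (sketch): shift the prediction back to the physical frame
mu_new = 0.62
prediction = translate(predicted_aligned, shift_fn(mu_new))
error = np.max(np.abs(prediction - pulse(mu_new)))
print(error)
```

The reduced model itself never sees the transport. It only has to represent a field in the aligned frame, and the shift at the new parameter carries the pulse to its correct position.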