Source code for ezyrb.plugin.scaler

""" Module for Scaler plugin """

from .plugin import Plugin


class DatabaseScaler(Plugin):
    """
    The plugin to rescale the database of the reduced order model. It uses a
    user defined `scaler`, which has to have implemented the `fit`,
    `transform` and `inverse_transform` methods (i.e. `sklearn` interface),
    to rescale either the parameters or the snapshots. It can be applied at
    the full order (`mode='full'`) or at the reduced one (`mode='reduced'`).

    :param obj scaler: a generic object which has to have implemented the
        `fit`, `transform` and `inverse_transform` methods (i.e. `sklearn`
        interface).
    :param {'full', 'reduced'} mode: define if the rescaling has to be
        applied at the full order ('full') or at the reduced one ('reduced').
    :param {'parameters', 'snapshots'} target: define if the rescaling has
        to be applied to the parameters or to the snapshots.

    :Example:

        >>> from ezyrb import ReducedOrderModel as ROM
        >>> from ezyrb import POD, RBF, Database
        >>> from ezyrb.plugin import DatabaseScaler
        >>> from sklearn.preprocessing import StandardScaler
        >>> pod = POD(rank=10)
        >>> rbf = RBF()
        >>> db = Database(params, snapshots)
        >>> scaler = DatabaseScaler(StandardScaler(), 'full', 'snapshots')
        >>> rom = ROM(db, pod, rbf, plugins=[scaler])
        >>> rom.fit()
    """

    def __init__(self, scaler, mode, target) -> None:
        """
        Initialize the DatabaseScaler plugin.

        :param scaler: Scaler object with fit, transform, and
            inverse_transform methods.
        :param str mode: 'full' or 'reduced' - where to apply the scaling.
        :param str target: 'parameters' or 'snapshots' - what to scale.
        :raises ValueError: if `mode` or `target` is not one of the allowed
            values (`mode` is validated first).
        """
        super().__init__()
        self.scaler = scaler
        self.mode = mode
        self.target = target

    @property
    def target(self):
        """
        The matrix of the database that is rescaled: 'parameters' or
        'snapshots'. See class documentation for more info.

        :rtype: str
        """
        return self._target

    @target.setter
    def target(self, new_target):
        if new_target not in ['snapshots', 'parameters']:
            raise ValueError(
                "target must be 'snapshots' or 'parameters', "
                f"got {new_target!r}"
            )
        self._target = new_target

    @property
    def mode(self):
        """
        The level at which the rescaling is applied: 'full' (full order
        database) or 'reduced' (reduced database). See class documentation
        for more info.

        :rtype: str
        """
        return self._mode

    @mode.setter
    def mode(self, new_mode):
        if new_mode not in ['full', 'reduced']:
            raise ValueError(
                f"mode must be 'full' or 'reduced', got {new_mode!r}"
            )
        self._mode = new_mode

    def _select_matrix(self, db):
        """
        Helper function to select the proper matrix to rescale.

        :param Database db: The database object.
        :return: The selected matrix (parameters or snapshots), i.e. the
            `<target>_matrix` attribute of `db`.
        """
        return getattr(db, f'{self.target}_matrix')

    def _rebuild_database(self, db, func):
        """
        Build a new database of the same type as `db`, in which the matrix
        selected by `target` is replaced by `func` applied to it. The other
        matrix is passed through unchanged.

        :param Database db: The database to rebuild.
        :param callable func: Function applied to the selected matrix
            (e.g. `self.scaler.transform` or `self.scaler.inverse_transform`).
        :return: The new database.
        """
        rescaled = func(self._select_matrix(db))
        if self.target == 'parameters':
            return type(db)(rescaled, db.snapshots_matrix)
        return type(db)(db.parameters_matrix, rescaled)

    def rom_preprocessing(self, rom):
        """
        Fit the scaler on the selected matrix of the reduced database and
        replace `rom._reduced_database` with its scaled version. It does
        nothing unless `mode` is 'reduced'.

        :param ReducedOrderModel rom: The ROM instance.
        """
        if self.mode != 'reduced':
            return

        db = rom._reduced_database
        self.scaler.fit(self._select_matrix(db))
        rom._reduced_database = self._rebuild_database(
            db, self.scaler.transform)

    def fom_preprocessing(self, rom):
        """
        Fit the scaler on the selected matrix of the full order database and
        replace `rom._full_database` with its scaled version. It does nothing
        unless `mode` is 'full'.

        :param ReducedOrderModel rom: The ROM instance.
        """
        if self.mode != 'full':
            return

        db = rom._full_database
        self.scaler.fit(self._select_matrix(db))
        rom._full_database = self._rebuild_database(
            db, self.scaler.transform)

    def fom_postprocessing(self, rom):
        """
        Undo the scaling of the full order database. It replaces
        `rom._full_database` with a copy whose selected matrix has been
        passed through `scaler.inverse_transform`. It does nothing unless
        `mode` is 'full'.

        The scaler is not re-fitted here. It uses the state from its last
        `fit` call, which this class makes in `fom_preprocessing`.

        :param ReducedOrderModel rom: The ROM instance.
        """
        if self.mode != 'full':
            return

        rom._full_database = self._rebuild_database(
            rom._full_database, self.scaler.inverse_transform)

    def rom_postprocessing(self, rom):
        """
        Undo the scaling of the reduced database. It replaces
        `rom._reduced_database` with a copy whose selected matrix has been
        passed through `scaler.inverse_transform`. It does nothing unless
        `mode` is 'reduced'.

        The scaler is not re-fitted here. It uses the state from its last
        `fit` call, which this class makes in `rom_preprocessing`.

        :param ReducedOrderModel rom: The ROM instance.
        """
        if self.mode != 'reduced':
            return

        rom._reduced_database = self._rebuild_database(
            rom._reduced_database, self.scaler.inverse_transform)