DatabaseSplitter
- class DatabaseSplitter(train=0.9, test=0.1, validation=0.0, predict=0.0, seed=None)
Bases: Plugin
Plugin for splitting the database into training, test, validation, and prediction sets.
Before the fitting process begins, this plugin automatically splits the database according to the specified ratios (or numbers of samples).
- Parameters:
train (float) – Ratio or number of samples for training set. Default is 0.9.
test (float) – Ratio or number of samples for test set. Default is 0.1.
validation (float) – Ratio or number of samples for validation set. Default is 0.0.
predict (float) – Ratio or number of samples for prediction set. Default is 0.0.
seed (int) – Random seed for reproducibility. Default is None.
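The ratios above determine how many samples land in each split. As an illustration only, here is a minimal NumPy sketch of a seeded, ratio-based split. It is not EZyRB's implementation: the helper `split_indices` is hypothetical, it handles only fractions that sum to 1 (not absolute sample counts), and it assigns any rounding leftovers to the training set.

```python
import numpy as np


def split_indices(n_samples, train=0.9, test=0.1, validation=0.0,
                  predict=0.0, seed=None):
    """Partition range(n_samples) into four disjoint index arrays.

    Hypothetical helper: assumes all four arguments are fractions that
    sum to 1. Samples lost to rounding are given to the training set.
    """
    ratios = np.array([train, test, validation, predict], dtype=float)
    if not np.isclose(ratios.sum(), 1.0):
        raise ValueError("ratios must sum to 1")

    rng = np.random.default_rng(seed)      # seeded for reproducibility
    perm = rng.permutation(n_samples)      # shuffled sample indices

    counts = np.floor(ratios * n_samples).astype(int)
    counts[0] += n_samples - counts.sum()  # rounding leftovers go to train

    # Cut the shuffled indices at the cumulative split sizes.
    bounds = np.cumsum(counts)[:-1]
    return np.split(perm, bounds)


train_idx, test_idx, val_idx, pred_idx = split_indices(
    100, train=0.7, test=0.2, validation=0.1, seed=42)
print(len(train_idx), len(test_idx), len(val_idx), len(pred_idx))
```

Running it prints `70 20 10 0`. Fixing `seed` fixes the shuffle, and therefore the split, across runs.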
- Example:
>>> from ezyrb import ReducedOrderModel as ROM
>>> from ezyrb import POD, RBF, Database
>>> from ezyrb.plugin import DatabaseSplitter
>>> import numpy as np
>>> params = np.random.rand(100, 2)
>>> snapshots = np.random.rand(100, 50)
>>> db = Database(params, snapshots)
>>> pod = POD(rank=5)
>>> rbf = RBF()
>>> splitter = DatabaseSplitter(train=0.7, test=0.2, validation=0.1)
>>> rom = ROM(db, pod, rbf, plugins=[splitter])
>>> rom.fit()
Initialize the DatabaseSplitter plugin.
- fit_preprocessing(rom)
Split the database before fitting begins.
- Parameters:
rom (ReducedOrderModel) – The ROM instance.
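`fit_preprocessing` is a hook that receives the ROM instance and runs before fitting. The toy sketch below illustrates that calling pattern in plain Python. `ToyPlugin`, `ToyHalfSplitter`, `ToyROM`, and the attributes `train_data`/`test_data` are hypothetical stand-ins, not EZyRB classes, and the real ROM's fitting is replaced by a placeholder.

```python
class ToyPlugin:
    """Hypothetical base class: every hook is a no-op by default."""

    def fit_preprocessing(self, rom):
        pass


class ToyHalfSplitter(ToyPlugin):
    """Hypothetical plugin: first half of the data for training, rest for test."""

    def fit_preprocessing(self, rom):
        half = len(rom.database) // 2
        rom.train_data = rom.database[:half]
        rom.test_data = rom.database[half:]


class ToyROM:
    """Hypothetical stand-in for a reduced order model with plugin hooks."""

    def __init__(self, database, plugins=None):
        self.database = database
        self.plugins = plugins or []
        self.train_data = database  # default: train on everything
        self.test_data = []

    def fit(self):
        # Every plugin runs its preprocessing hook before any fitting.
        for plugin in self.plugins:
            plugin.fit_preprocessing(self)
        # Placeholder for the actual fit: just record the training size.
        self.n_fitted = len(self.train_data)
        return self


rom = ToyROM(list(range(10)), plugins=[ToyHalfSplitter()]).fit()
print(rom.n_fitted, rom.test_data)
```

This prints `5 [5, 6, 7, 8, 9]`. Because the hook mutates the ROM in place, a plugin can change which data the subsequent fit sees without the ROM needing to know how the split was made.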
- class DatabaseDictionarySplitter(train_key=None, test_key=None, validation_key=None, predict_key=None)
Bases: Plugin
Plugin for assigning the train, test, validation, and predict databases when the data are already split, i.e. when each split is a Database object stored in a dictionary. Given the dictionary keys of the desired splits, the plugin assigns the corresponding Database objects to the train, test, validation, and predict attributes of the ROM.
- Example:
>>> from ezyrb import ReducedOrderModel as ROM
>>> from ezyrb import POD, RBF, Database
>>> from ezyrb.plugin import DatabaseDictionarySplitter
>>> import numpy as np
>>> train_params, train_snaps = np.random.rand(80, 2), np.random.rand(80, 50)
>>> test_params, test_snaps = np.random.rand(20, 2), np.random.rand(20, 50)
>>> db_dict = {
...     'train': Database(train_params, train_snaps),
...     'test': Database(test_params, test_snaps)
... }
>>> pod = POD(rank=5)
>>> rbf = RBF()
>>> splitter = DatabaseDictionarySplitter(train_key='train', test_key='test')
>>> rom = ROM(db_dict, pod, rbf, plugins=[splitter])
>>> rom.fit()
Initialize the DatabaseDictionarySplitter plugin.
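- Parameters:
train_key – Dictionary key of the training database. Default is None.
test_key – Dictionary key of the test database. Default is None.
validation_key – Dictionary key of the validation database. Default is None.
predict_key – Dictionary key of the prediction database. Default is None.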
- fit_preprocessing(rom)
Assign the database splits from the dictionary before fitting.
- Parameters:
rom (ReducedOrderModel) – The ROM instance.
- Raises:
ValueError – If the database is not a dictionary.
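The assignment and error behaviour documented for `DatabaseDictionarySplitter.fit_preprocessing` can be mimicked with a small standalone function. This is a hedged sketch, not the library code: `assign_splits_from_dict` and the attribute names `train_db`, `test_db`, `validation_db`, and `predict_db` are hypothetical, and skipping splits whose key is `None` is an assumption.

```python
from types import SimpleNamespace


def assign_splits_from_dict(rom, train_key=None, test_key=None,
                            validation_key=None, predict_key=None):
    """Hypothetical mirror of the documented behaviour.

    Copies rom.database[key] into one attribute per split. The attribute
    names (train_db, test_db, ...) are illustrative, not EZyRB's.
    """
    if not isinstance(rom.database, dict):
        raise ValueError("The database must be a dictionary.")
    keys = {
        "train_db": train_key,
        "test_db": test_key,
        "validation_db": validation_key,
        "predict_db": predict_key,
    }
    for attr, key in keys.items():
        if key is not None:  # assumption: splits without a key are skipped
            setattr(rom, attr, rom.database[key])
    return rom


rom = SimpleNamespace(database={"train": [1, 2, 3], "test": [4]})
assign_splits_from_dict(rom, train_key="train", test_key="test")
print(rom.train_db, rom.test_db, hasattr(rom, "validation_db"))
```

This prints `[1, 2, 3] [4] False`; passing an object whose `database` is not a dictionary raises `ValueError`, matching the documented contract.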