# Source code for ezyrb.database

"""Module for the snapshots database collected during the Offline stage."""

import numpy as np

from .parameter import Parameter
from .snapshot import Snapshot

class Database:
    """
    Database class for storing parameter-snapshot pairs.

    Each entry is a ``(Parameter, Snapshot)`` tuple kept in insertion order.

    :param array_like parameters: the input parameters. If None while
        `snapshots` is given, every snapshot is paired with a ``None``
        parameter.
    :param array_like snapshots: the input snapshots. If None while
        `parameters` is given, every parameter is paired with a ``None``
        snapshot.
    :param space: the input spatial data. If it is a dict, it is looked up
        with ``tuple(parameter.values)`` to get the space of each snapshot;
        otherwise the same object is passed to every snapshot.

    :Example:

        >>> import numpy as np
        >>> from ezyrb import Database
        >>> params = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
        >>> snapshots = np.random.rand(3, 100)
        >>> db = Database(params, snapshots)
        >>> print(len(db))
        3
        >>> print(db.parameters_matrix.shape)
        (3, 2)
        >>> print(db.snapshots_matrix.shape)
        (3, 100)
    """

    def __init__(self, parameters=None, snapshots=None, space=None):
        self._pairs = []

        # Nothing to store: leave the database empty.
        if parameters is None and snapshots is None:
            return

        # Pad the missing side with None so that the pairs can be zipped.
        if parameters is None:
            parameters = [None] * len(snapshots)
        elif snapshots is None:
            snapshots = [None] * len(parameters)

        if len(parameters) != len(snapshots):
            raise ValueError(
                'parameters and snapshots must have the same length')

        for param, snap in zip(parameters, snapshots):
            param = Parameter(param)

            if isinstance(space, dict):
                snap_space = space.get(tuple(param.values), None)
            else:
                snap_space = space

            snap = Snapshot(snap, space=snap_space)
            self.add(param, snap)

        # TODO: eventually improve the `space` assignment in the snapshots,
        # snapshots can have different space coordinates

    @property
    def parameters_matrix(self):
        """
        The matrix containing the input parameters (by row).

        :rtype: numpy.ndarray
        """
        return np.asarray([pair[0].values for pair in self._pairs])

    @property
    def snapshots_matrix(self):
        """
        The matrix containing the snapshots (by row).

        :rtype: numpy.ndarray
        """
        return np.asarray([pair[1].flattened for pair in self._pairs])

    def __getitem__(self, val):
        """
        Return a new Database with the selected parameters and snapshots.

        :param val: an integer, a slice, or an array-like (numpy.ndarray,
            list or tuple) of integer indices or of booleans (mask).
        :return: the database containing only the selected pairs.
        :rtype: Database
        :raises TypeError: if `val` is not one of the supported index types.

        .. warning::

            The new parameters and snapshots are a view of the original
            Database: the stored objects are shared, not copied.
        """
        view = Database()

        if isinstance(val, slice):
            view._pairs = self._pairs[val]
        elif isinstance(val, (int, np.integer)):
            # Wrap the single pair in a list so that the result is still a
            # well-formed Database of length one.
            view._pairs = [self._pairs[val]]
        elif isinstance(val, (np.ndarray, list, tuple)):
            # Apply the numpy indexing to the positions rather than to the
            # pairs, so the stored objects are never converted to arrays.
            positions = np.arange(len(self._pairs))[np.asarray(val)]
            view._pairs = [
                self._pairs[i] for i in np.atleast_1d(positions)
            ]
        else:
            raise TypeError(
                'Database indices must be integers, slices or arrays, '
                'not {}'.format(type(val).__name__))

        return view

    def __len__(self):
        """
        Return the number of stored parameter-snapshot pairs.

        :rtype: int
        """
        return len(self._pairs)

    def __str__(self):
        """
        Return minimal info about the Database: the number of stored
        snapshots and the number of columns of the parameters matrix.
        """
        params = self.parameters_matrix
        # An empty database gives a 1-D empty matrix with no column axis.
        n_params = params.shape[1] if params.ndim > 1 else 0
        return 'Database with {} snapshots and {} parameters'.format(
            len(self), n_params)

    def add(self, parameter, snapshot):
        """
        Add (by row) a new parameter-snapshot pair to the database.

        :param Parameter parameter: the parameter to add.
        :param Snapshot snapshot: the snapshot to add.
        :return: the database itself, to allow chained calls.
        :rtype: Database
        :raises ValueError: if `parameter` is not a Parameter or `snapshot`
            is not a Snapshot.
        """
        if not isinstance(parameter, Parameter):
            raise ValueError('parameter must be a Parameter instance')

        if not isinstance(snapshot, Snapshot):
            raise ValueError('snapshot must be a Snapshot instance')

        self._pairs.append((parameter, snapshot))

        return self

    def split(self, chunks, seed=None):
        """
        Randomly split the database into several databases.

        :param chunks: either integers (the exact number of pairs in each
            new database, summing to ``len(self)``) or floats (the
            probability for each pair to fall in each new database, summing
            to 1; the resulting sizes are random).
        :param seed: seed for the random assignment. If None, the global
            numpy random state is used; otherwise a local generator is
            seeded, leaving the global state untouched.
        :return: the new databases, one per element of `chunks`.
        :rtype: list(Database)
        :raises ValueError: if the chunks do not sum to the expected total
            or are not all integers or all floats.

        :Example:

            >>> db = Database(...)
            >>> train, test = db.split([0.8, 0.2]) # ratio
            >>> train, test = db.split([80, 20]) # n snapshots
        """
        # RandomState(seed) yields the same stream as np.random.seed(seed)
        # followed by the module-level functions, without the side effect.
        rng = np.random if seed is None else np.random.RandomState(seed)

        if all(isinstance(n, (int, np.integer)) for n in chunks):
            if sum(chunks) != len(self):
                raise ValueError('chunk elements are inconsistent')

            ids = [j for j, size in enumerate(chunks) for _ in range(size)]
            rng.shuffle(ids)

        elif all(isinstance(n, (float, np.floating)) for n in chunks):
            if not np.isclose(sum(chunks), 1.):
                raise ValueError('chunk elements are inconsistent')

            # Only the interior edges are needed: a draw beyond the last
            # one always lands in the last chunk, even when the cumulative
            # sum is slightly below 1 because of round-off.
            edges = np.cumsum(chunks)[:-1]
            draws = rng.uniform(0, 1, size=len(self))
            ids = np.searchsorted(edges, draws, side='right')

        else:
            raise ValueError(
                'chunk elements must be all integers or all floats')

        new_database = [Database() for _ in range(len(chunks))]
        # The pairs are already validated, so they are stored directly.
        for pair, chunk_id in zip(self._pairs, ids):
            new_database[chunk_id]._pairs.append(pair)

        return new_database

    def get_snapshot_space(self, index):
        """
        Get the space coordinates of a snapshot by its index.

        :param int index: The index of the snapshot. Negative indices are
            not accepted.
        :return: The space coordinates of the snapshot.
        :raises IndexError: if `index` is negative or not smaller than the
            number of stored snapshots.
        """
        if not 0 <= index < len(self._pairs):
            raise IndexError("Snapshot index out of range.")

        return self._pairs[index][1].space