"""Module for the snapshots database collected during the Offline stage."""
import numpy as np
from .parameter import Parameter
from .snapshot import Snapshot
class Database:
    """
    Database class for storing parameter-snapshot pairs.

    :param array_like parameters: the input parameters. Default is None.
    :param array_like snapshots: the input snapshots. Default is None.
    :param space: the input spatial data. Either a single object attached
        to every snapshot, or a dict whose keys are the tuples of parameter
        values and whose values are the spatial data of the corresponding
        snapshot. Default is None.
    :raises ValueError: if `parameters` and `snapshots` have different
        lengths.

    :Example:

        >>> import numpy as np
        >>> from ezyrb import Database, Parameter, Snapshot
        >>> params = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
        >>> snapshots = np.random.rand(3, 100)
        >>> db = Database(params, snapshots)
        >>> print(len(db))
        3
        >>> print(db.parameters_matrix.shape)
        (3, 2)
        >>> print(db.snapshots_matrix.shape)
        (3, 100)
    """

    def __init__(self, parameters=None, snapshots=None, space=None):
        # List of (Parameter, Snapshot) tuples, in insertion order.
        self._pairs = []

        if parameters is None and snapshots is None:
            return

        # When only one of the two inputs is given, pair each item with None.
        if parameters is None:
            parameters = [None] * len(snapshots)
        elif snapshots is None:
            snapshots = [None] * len(parameters)

        if len(parameters) != len(snapshots):
            raise ValueError(
                'parameters and snapshots must have the same length')

        # TODO: eventually improve the `space` assignment in the snapshots,
        # snapshots can have different space coordinates
        for param, snap in zip(parameters, snapshots):
            param = Parameter(param)
            if isinstance(space, dict):
                # Per-snapshot space, looked up by the parameter values;
                # a missing key yields None.
                snap_space = space.get(tuple(param.values))
            else:
                snap_space = space
            self.add(param, Snapshot(snap, space=snap_space))

    @property
    def parameters_matrix(self):
        """
        The matrix containing the input parameters (by row).

        :rtype: numpy.ndarray
        """
        return np.asarray([pair[0].values for pair in self._pairs])

    @property
    def snapshots_matrix(self):
        """
        The matrix containing the snapshots (by row).

        :rtype: numpy.ndarray
        """
        return np.asarray([pair[1].flattened for pair in self._pairs])

    def __getitem__(self, val):
        """
        Return a new Database with the selected parameters and snapshots.

        :param val: an integer, a slice, or a numpy array / list of integer
            indices or booleans (mask).
        :return: the Database holding the selected pairs; a single integer
            selects a one-element Database.
        :rtype: Database
        :raises TypeError: if `val` is of an unsupported type.
        :raises IndexError: if an index is out of range.

        .. warning:: The new parameters and snapshots are a view of the
            original Database.
        """
        view = Database()

        if isinstance(val, slice):
            view._pairs = self._pairs[val]
        elif isinstance(val, (int, np.integer)):
            # Wrap in a list: a bare pair would be mistaken for the list of
            # pairs itself.
            view._pairs = [self._pairs[val]]
        elif isinstance(val, (np.ndarray, list)):
            index = np.asarray(val)
            if index.size:
                # Let numpy resolve masks / fancy indices into positions, so
                # the stored objects are never converted to arrays.
                positions = np.atleast_1d(
                    np.arange(len(self._pairs))[index])
                view._pairs = [self._pairs[i] for i in positions]
        else:
            raise TypeError(
                'Database indices must be integers, slices or arrays, '
                'not {}'.format(type(val).__name__))

        return view

    def __len__(self):
        """
        Return the number of snapshots.

        :rtype: int
        """
        return len(self._pairs)

    def __str__(self):
        """ Print minimal info about the Database """
        if not self._pairs:
            # The matrices of an empty Database have no second dimension.
            return 'Database with 0 snapshots and 0 parameters'
        # NOTE(review): the first figure is the second dimension of
        # `snapshots_matrix` (size of a flattened snapshot), not the number
        # of stored pairs; kept as is for backward compatibility.
        return 'Database with {} snapshots and {} parameters'.format(
            self.snapshots_matrix.shape[1], self.parameters_matrix.shape[1])

    def add(self, parameter, snapshot):
        """
        Add (by row) new sets of snapshots and parameters to the original
        database.

        :param Parameter parameter: the parameter to add.
        :param Snapshot snapshot: the snapshot to add.
        :return: the Database itself, to allow chained calls.
        :rtype: Database
        :raises ValueError: if `parameter` is not a Parameter or `snapshot`
            is not a Snapshot.
        """
        if not isinstance(parameter, Parameter):
            raise ValueError(
                'parameter must be a Parameter instance, not {}'.format(
                    type(parameter).__name__))
        if not isinstance(snapshot, Snapshot):
            raise ValueError(
                'snapshot must be a Snapshot instance, not {}'.format(
                    type(snapshot).__name__))

        self._pairs.append((parameter, snapshot))
        return self

    def split(self, chunks, seed=None):
        """
        Randomly split the Database into several new Databases.

        :param list chunks: either integers, the number of snapshots of
            each new Database (they must sum to the Database length), or
            floats, the ratio of snapshots of each new Database (they must
            sum to 1). With ratios each snapshot is assigned independently,
            so the resulting sizes only approximate the ratios.
        :param int seed: if not None, it is passed to `numpy.random.seed`
            (global random state) before the split. Default is None.
        :return: the new Databases, one for each chunk.
        :rtype: list(Database)
        :raises ValueError: if the chunks are negative, of mixed or
            unsupported types, or inconsistent with the Database.

        :Example:

            >>> db = Database(...)
            >>> train, test = db.split([0.8, 0.2]) # ratio
            >>> train, test = db.split([80, 20]) # n snapshots
        """
        if seed is not None:
            np.random.seed(seed)

        n_pairs = len(self)

        if all(isinstance(n, (int, np.integer)) for n in chunks):
            if any(n < 0 for n in chunks) or sum(chunks) != n_pairs:
                raise ValueError('chunk elements are inconsistent')

            # One chunk label per snapshot, then shuffled.
            ids = [
                label for label, size in enumerate(chunks)
                for _ in range(size)
            ]
            np.random.shuffle(ids)

        elif all(isinstance(n, (float, np.floating)) for n in chunks):
            if any(n < 0 for n in chunks) or not np.isclose(sum(chunks), 1.):
                raise ValueError('chunk elements are inconsistent')

            edges = np.insert(np.cumsum(chunks), 0, 0.0)
            # The ratios sum to 1 only up to round-off: pin the last edge
            # so that every draw in [0, 1) falls into a chunk.
            edges[-1] = 1.0

            ids = np.ones(n_pairs) * -1.
            draws = np.random.uniform(0, 1, size=n_pairs)
            for i in range(len(edges) - 1):
                in_bin = np.logical_and(
                    draws >= edges[i], draws < edges[i + 1])
                ids[in_bin] = i

        else:
            raise ValueError(
                'chunks must be all integers (sizes) or all floats (ratios)')

        ids = np.asarray(ids)
        return [self[np.flatnonzero(ids == i)] for i in range(len(chunks))]

    def get_snapshot_space(self, index):
        """
        Get the space coordinates of a snapshot by its index.

        :param int index: The index of the snapshot; negative indices are
            not accepted.
        :return: The space coordinates of the snapshot.
        :rtype: numpy.ndarray
        :raises IndexError: if `index` is out of range.
        """
        if index < 0 or index >= len(self._pairs):
            raise IndexError("Snapshot index out of range.")

        return self._pairs[index][1].space