# Source code for pina._src.data.manager.tensor_data_manager
"""Module for the Tensor-Data Manager class."""
import torch
from pina._src.core.label_tensor import LabelTensor
from pina._src.data.manager.batch_manager import _BatchManager
from pina._src.data.manager.data_manager_interface import _DataManagerInterface
[docs]
class _TensorDataManager(_DataManagerInterface):
    """
    Data manager for tensor-based data. It handles inputs stored as
    :class:`torch.Tensor` or :class:`~pina.label_tensor.LabelTensor`.

    Every keyword argument given at construction is stored in the ``data``
    dictionary under its own name, listed in ``keys`` (in insertion order)
    and additionally exposed as an instance attribute with the same name.
    """

    def __init__(self, **kwargs):
        """
        Initialization of the :class:`_TensorDataManager` class.

        :param dict kwargs: The keyword arguments for the tensor data manager.
            Each name becomes both an entry of ``data`` and an attribute of
            the instance.
        :raises ValueError: If a keyword argument is named ``keys`` or
            ``data``, since these names hold the internal state of the
            manager.
        """
        # The attributes set in the loop below would silently overwrite the
        # internal ``keys`` / ``data`` bookkeeping, leaving every other method
        # of the manager broken. Fail early with an explicit message instead.
        reserved = sorted({"keys", "data"}.intersection(kwargs))
        if reserved:
            raise ValueError(
                f"Invalid keyword argument name(s) {reserved}: 'keys' and "
                "'data' are reserved by _TensorDataManager."
            )

        self.keys = list(kwargs)
        self.data = kwargs

        # Set attributes from kwargs
        for name, value in kwargs.items():
            setattr(self, name, value)

    def __len__(self):
        """
        Return the number of samples in the tensor data manager.

        The length is read from the first dimension of the entry stored under
        the first key (this presumes all entries share the same leading
        dimension; it is not checked here). A manager holding no entries has
        length zero.

        :return: The number of samples.
        :rtype: int
        """
        # No entries at all: nothing to measure
        if not self.keys:
            return 0
        return self.data[self.keys[0]].shape[0]

    def __getitem__(self, idx):
        """
        Return the item at the specified indices.

        Every stored entry is indexed with ``idx`` and the results are wrapped
        in a new manager with the same keys.

        :param idx: The indices of the data point to retrieve.
        :type idx: int | slice | list[int] | torch.Tensor
        :return: A new :class:`_TensorDataManager` instance containing the
            selected data items.
        :rtype: _TensorDataManager
        """
        # Get data at selected indices
        selected = {k: self.data[k][idx] for k in self.keys}
        return _TensorDataManager(**selected)

    def to_batch(self):
        """
        Create a batch from the current tensor data manager.

        The stored entries are copied by reference, key by key, into a new
        batch manager.

        :return: A new instance of :class:`_BatchManager` with batched data.
        :rtype: _BatchManager
        """
        # Create the batch manager
        batch_data = _BatchManager()
        for k in self.keys:
            batch_data[k] = self.data[k]
        return batch_data

    @staticmethod
    def create_batch(items):
        """
        Create a batch from a list of :class:`_TensorDataManager` items.

        The keys of the first item define the keys of the batch. For each key
        the values of all items are stacked: with ``LabelTensor.stack`` when
        the first value is a :class:`~pina.label_tensor.LabelTensor`, with
        :func:`torch.stack` when it is a plain :class:`torch.Tensor`. Values
        of any other type are not combined; the value of the first item is
        used as is.

        :param list[_TensorDataManager] items: A list of
            :class:`_TensorDataManager` items to batch.
        :return: A new instance of :class:`_BatchManager` containing the
            batched data, or ``None`` if ``items`` is empty.
        :rtype: _BatchManager
        """
        # Return None if no items are provided
        if not items:
            return None

        # Initialize the batch manager
        batch_data = _BatchManager()

        # The first _TensorDataManager of the list defines the keys
        for k in items[0].keys:
            # Collect the values; the first one selects the batch function
            vals = [it.data[k] for it in items]
            sample = vals[0]

            # LabelTensor is tested first so that it gets its dedicated stack
            if isinstance(sample, LabelTensor):
                batch_data[k] = LabelTensor.stack(vals)
            elif isinstance(sample, torch.Tensor):
                batch_data[k] = torch.stack(vals)
            # If no tensor is provided, just take the first value
            else:
                batch_data[k] = sample

        return batch_data