Source code for pina._src.data.manager.graph_data_manager
"""Module for the Graph-Data Manager class."""
import torch
from torch_geometric.data import Data
from torch_geometric.data.batch import Batch
from pina._src.core.label_tensor import LabelTensor
from pina._src.core.graph import Graph, LabelBatch
from pina._src.data.manager.batch_manager import _BatchManager
from pina._src.data.manager.data_manager_interface import _DataManagerInterface
[docs]
class _GraphDataManager(_DataManagerInterface):
"""
Data manager for graph-based data. It handles inputs stored as
:class:`Graph`, :class:`Data`, or lists / tuples of these types.
"""
def __init__(self, **kwargs):
"""
Initialization of the :class:`_GraphDataManager` class.
:param dict kwargs: The keyword arguments for the graph data manager.
"""
# Initialize keys
self.keys = list(kwargs.keys())
# Find graph-based data
self.graph_key = next(
k
for k, v in kwargs.items()
if isinstance(v, (Graph, Data, list, tuple))
)
# Find tensor data
self.keys = [
k
for k in self.keys
if k != self.graph_key
and isinstance(kwargs[k], (torch.Tensor, LabelTensor))
]
# Prepare graphs and assign tensors
self.data = self._prepare_graphs(kwargs)
def __len__(self):
"""
Return the number of samples in the graph data manager.
:return: The number of samples.
:rtype: int
"""
return len(self.data)
def __getitem__(self, idx):
"""
Return the item at the specified indices.
:param idx: The indices of the graphs to retrieve.
:type idx: int | slice | list[int] | torch.Tensor
:raises TypeError: If an index with invalid type is passed.
:return: A new :class:`_GraphDataManager` instance containing the
selected graphs.
:rtype: _GraphDataManager
"""
# Selection for integers or slices
if isinstance(idx, (int, slice)):
selected = self.data[idx]
# Selection for lists or tensors
elif isinstance(idx, (list, torch.Tensor)):
selected = [self.data[i] for i in idx]
# Raise TypeError if index type is invalid
else:
raise TypeError(f"Invalid index type: {type(idx)}")
# Ensure selected is a list
if not isinstance(selected, list):
selected = [selected]
return _GraphDataManager._init_from_graphs_list(
selected, graph_key=self.graph_key, keys=self.keys
)
def __getattr__(self, name):
"""
Provide dynamic access to stored graph and tensor data.
If ``name`` corresponds to the graph key, return the list of graph
objects. If it matches a tensor key, retrieve the corresponding
tensors from all graphs and stack them along the batch dimension.
:param str name: The name of the attribute to access.
:return: The requested graph data or stacked tensor values.
:rtype: torch.Tensor | LabelTensor | list[Graph] | list[Data]
"""
# Stack tensors from all graph if name is a tensor key
if name in self.keys:
tensors = [getattr(g, name) for g in self.data]
batch_fn = (
LabelTensor.stack
if isinstance(tensors[0], LabelTensor)
else torch.stack
)
return batch_fn(tensors)
# Otherwise, return graphs
if name == self.graph_key:
return self.data if len(self.data) > 1 else self.data[0]
return super().__getattribute__(name)
def _prepare_graphs(self, kwargs):
"""
Attach tensor data to the corresponding graph objects.
:param kwargs: The keyword arguments containing graph data and
associated tensor features.
:raises ValueError: If the number of graphs does not match the number of
samples in the tensor of features to associate.
:return: A list of graphs with the corresponding tensors assigned.
:rtype: list[Graph] | list[Data]
"""
# Get graph-based data and store in a list
graphs = kwargs.pop(self.graph_key)
if not isinstance(graphs, (list, tuple)):
graphs = [graphs]
# Iterate of items
for name, tensor in kwargs.items():
# Verify the consistency between the number of graphs and samples
if len(graphs) != tensor.shape[0]:
raise ValueError(
f"Number of graphs ({len(graphs)}) does not match "
f"number of samples for key '{name}' "
f"({kwargs[name].shape[0]})."
)
# Assign tensors to graphs
for i, g in enumerate(graphs):
setattr(g, name, tensor[i])
return graphs
[docs]
def to_batch(self):
"""
Create a batch from the current graph data manager.
:return: A new instance of :class:`_BatchManager` with batched data.
:rtype: _BatchManager
"""
# Define the batch function
batching_fn = (
LabelBatch.from_data_list
if isinstance(self.data[0], Graph)
else Batch.from_data_list
)
# Create the batch manager
batch_data = _BatchManager()
batched_graph = batching_fn(self.data)
for k in self.keys:
if k == self.graph_key:
continue
batch_data[k] = getattr(batched_graph, k)
delattr(batched_graph, k)
batch_data[self.graph_key] = batched_graph
return batch_data
[docs]
@staticmethod
def create_batch(items):
"""
Create a batch from a list of :class:`_GraphDataManager` items.
:param list[_GraphDataManager] items: A list of
:class:`_GraphDataManager` items to batch.
:return: A new instance of :class:`_BatchManager` containing the batched
data.
:rtype: _BatchManager
"""
# Return None if no items are provided
if not items:
return None
# Retrieve the first _GraphDataManager of the list and corresponding key
first = items[0]
graph_key = first.graph_key
# Initialize the batch manager
batch_data = _BatchManager()
# Define batch function
batching_fn = (
LabelBatch.from_data_list
if isinstance(first.data[0], Graph)
else Batch.from_data_list
)
# Batch over graphs
batched_graph = batching_fn([item.data[0] for item in items])
# Use a set for O(1) lookups if keys are large
keys_to_transfer = set(first.keys)
if graph_key in keys_to_transfer:
keys_to_transfer.remove(graph_key)
# Iterate over the keys of the _GraphDataManager
for k in keys_to_transfer:
# Extract values
val = getattr(batched_graph, k, None)
if val is not None:
batch_data[k] = val
delattr(batched_graph, k)
# Assign key to batch
batch_data[graph_key] = batched_graph
return batch_data
@classmethod
def _init_from_graphs_list(cls, graphs, graph_key, keys):
"""
Create a :class:`_GraphDataManager` instance directly from a list of
graph objects.
This method bypasses the standard initialization logic and is used
internally to construct new instances (e.g., subsets) from already
processed graph data.
:param list graphs: A list of graph objects.
:param str graph_key: The name of the attribute used to store the
graphs.
:param list keys: A list of tensor keys associated with the graphs.
:return: A new instance of :class:`_GraphDataManager`.
:rtype: _GraphDataManager
"""
# Create a new instance without calling __init__
obj = _GraphDataManager.__new__(_GraphDataManager)
obj.graph_key = graph_key
obj.keys = keys
obj.data = graphs
return obj