Source code for pina._src.data.condition_subset
"""Utilities for handling condition dataset subsets."""
from torch_geometric.data import Batch
from pina._src.core.graph import LabelBatch, Graph
[docs]
class _ConditionSubset:
"""
Wrapper around a condition dataset restricted to a subset of indices.
The class behaves similarly to :class:`torch.utils.data.Subset` and supports
cyclic indexing together with optional automatic batching.
"""
def __init__(self, condition, indices, automatic_batching):
"""
Initialization of the :class:`_ConditionSubset` class.
:param BaseCondition condition: The underlying condition.
:param list[int] indices: The list of indices identifying the subset
samples.
:param bool automatic_batching: Whether dataset items should be returned
directly or as raw indices.
"""
super().__init__()
# Initialize the class attributes
self.condition = condition
self.indices = indices
self.automatic_batching = automatic_batching
# Actual number of samples contained in the subset
self.dataset_length = len(self.indices)
# Effective iterable length used and modified during batching
self.iterable_length = self.dataset_length
def __len__(self):
"""
Return the effective iterable length of the subset.
:return: The number of accessible elements in the subset.
:rtype: int
"""
return self.iterable_length
def __getitem__(self, idx):
"""
Retrieve an element from the subset.
If the requested index exceeds the actual dataset size, cyclic indexing
is applied through modulo wrapping. When automatic batching is disabled,
the raw dataset index is returned instead of the corresponding sample.
:param int idx: The position of the element inside the subset.
:return: The dataset sample or raw dataset index depending on the
batching configuration.
:rtype: dict | int
"""
# Apply cyclic indexing if the requested index exceeds the subset length
if idx >= self.dataset_length:
idx = idx % self.dataset_length
# Fetch the corresponding dataset index from the list of indices
idx = self.indices[idx]
# Return the raw dataset index if automatic batching is disabled
if not self.automatic_batching:
return idx
return self.condition[idx]
[docs]
def get_all_data(self):
"""
Retrieve and aggregate all subset samples.
If the returned data contains a ``"data"`` field composed of graph
objects, the samples are merged into a single batched graph structure
using the appropriate batching implementation.
:return: The aggregated subset data.
:rtype: dict
"""
# Fetch the data corresponding to the subset indices
data = self.condition[self.indices]
# Data as a list of graph objects merged into a single batched graph
if "data" in data and isinstance(data["data"], list):
# Define the batching function
batch_fn = (
LabelBatch.from_data_list
if isinstance(data["data"][0], Graph)
else Batch.from_data_list
)
# Merge the list of graph objects into a single batched graph
data["data"] = batch_fn(data["data"])
data = {"input": data["data"], "target": data["data"].y}
return data