Source code for pina.condition.condition_interface
"""Module for the Condition interface."""
from abc import ABCMeta
from torch_geometric.data import Data
from ..label_tensor import LabelTensor
from ..graph import Graph
[docs]
class ConditionInterface(metaclass=ABCMeta):
    """
    Abstract base class for PINA conditions. All specific conditions must
    inherit from this interface.

    Refer to :class:`pina.condition.condition.Condition` for a thorough
    description of all available conditions and how to instantiate them.
    """

    def __init__(self):
        """
        Initialization of the :class:`ConditionInterface` class.

        The condition starts without an associated problem.
        """
        self._problem = None

    @property
    def problem(self):
        """
        Return the problem associated with this condition.

        :return: Problem associated with this condition, or ``None`` if no
            problem has been set yet.
        :rtype: ~pina.problem.abstract_problem.AbstractProblem
        """
        return self._problem

    @problem.setter
    def problem(self, value):
        """
        Set the problem associated with this condition.

        :param pina.problem.abstract_problem.AbstractProblem value: The problem
            to associate with this condition.
        """
        self._problem = value

    @staticmethod
    def _check_graph_list_consistency(data_list):
        """
        Check the consistency of the list of Data | Graph objects.

        The following checks are performed:

        - All elements in the list must be of the same type (either
          :class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`).
        - All elements in the list must have the same keys.
        - The data type of each tensor must be consistent across all elements.
        - If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels
          must also be consistent across all elements.

        A single :class:`~torch_geometric.data.Data` or
        :class:`~pina.graph.Graph` object (not wrapped in a list) is accepted
        as-is and no check is performed.

        :param data_list: The list of Data | Graph objects to check.
        :type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph]
        :raises ValueError: If the input types are invalid or the list is
            empty.
        :raises ValueError: If the elements of the list are not all of the
            same type (all Data or all Graph).
        :raises ValueError: If all elements in the list do not have the same
            keys.
        :raises ValueError: If the type of each tensor is not consistent across
            all elements in the list.
        :raises ValueError: If the labels of the LabelTensors are not consistent
            across all elements in the list.
        """
        # A single Graph or Data object has nothing to be compared against
        if isinstance(data_list, (Graph, Data)):
            return

        # Every element must be either a Data or a Graph object
        if not all(isinstance(i, (Graph, Data)) for i in data_list):
            raise ValueError(
                "Invalid input. Please, provide either Data or Graph objects."
            )

        # An empty sequence passes the check above (``all`` over nothing is
        # True) but would otherwise fail with an opaque IndexError when the
        # reference element is read below.
        if len(data_list) == 0:
            raise ValueError(
                "Invalid input. Please, provide either Data or Graph objects."
            )

        # Use the first element as the reference: store its kind (Graph or
        # Data), its keys, the class of each tensor and, for LabelTensors,
        # their labels.
        reference = data_list[0]
        is_graph = isinstance(reference, Graph)
        keys = sorted(reference.keys())
        data_types = {
            name: tensor.__class__ for name, tensor in reference.items()
        }
        labels = {
            name: tensor.labels
            for name, tensor in reference.items()
            if isinstance(tensor, LabelTensor)
        }

        # Compare every remaining element against the reference
        for data in data_list[1:]:
            # The per-element check above only guarantees each item is a
            # Data or a Graph; mixing the two is not allowed.
            if isinstance(data, Graph) != is_graph:
                raise ValueError(
                    "All elements in the list must be of the same type, "
                    "either Data or Graph."
                )

            # Check that all elements in the list have the same keys
            if sorted(data.keys()) != keys:
                raise ValueError(
                    "All elements in the list must have the same keys."
                )

            # Iterate over the tensors in the current element. Keys are known
            # to match, so ``data_types[name]`` is always defined here.
            for name, tensor in data.items():
                # Check that the type of each tensor is consistent
                if tensor.__class__ is not data_types[name]:
                    raise ValueError(
                        f"Data {name} must be a {data_types[name]}, got "
                        f"{tensor.__class__}"
                    )

                # Check that the labels of each LabelTensor are consistent.
                # The class check above guarantees the reference tensor was a
                # LabelTensor too, so ``labels[name]`` exists.
                if isinstance(tensor, LabelTensor):
                    if tensor.labels != labels[name]:
                        raise ValueError(
                            "LabelTensor must have the same labels"
                        )

    def __getattribute__(self, name):
        """
        Get an attribute from the object.

        Every attribute lookup on a condition instance goes through this
        method. If the resolved attribute is a single
        :class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`
        object, a new one-element list containing it is returned instead. The
        stored attribute itself is left unchanged. Any other attribute is
        returned as-is.

        :param str name: The name of the attribute to get.
        :return: The requested attribute, wrapped in a list if it is a single
            Data | Graph object.
        :rtype: Any
        """
        to_return = super().__getattribute__(name)
        # NOTE(review): presumably this lets downstream code always treat
        # graph data as a list of objects -- confirm against condition
        # subclasses and their consumers.
        if isinstance(to_return, (Graph, Data)):
            to_return = [to_return]
        return to_return