Source code for pina.condition.condition_interface
"""Module for the Condition interface."""
from abc import ABCMeta
from torch_geometric.data import Data
from ..label_tensor import LabelTensor
from ..graph import Graph
[docs]
class ConditionInterface(metaclass=ABCMeta):
"""
Abstract class which defines a common interface for all the conditions.
It stores the problem the condition is associated with and provides a
shared helper to validate lists of Data/Graph objects.
"""
def __init__(self):
"""
Initialize the ConditionInterface object.
"""
self._problem = None
@property
def problem(self):
"""
Return the problem to which the condition is associated.
:return: Problem to which the condition is associated.
:rtype: ~pina.problem.abstract_problem.AbstractProblem
"""
return self._problem
@problem.setter
def problem(self, value):
"""
Set the problem to which the condition is associated.
:param pina.problem.abstract_problem.AbstractProblem value: Problem to
which the condition is associated
"""
self._problem = value
@staticmethod
def _check_graph_list_consistency(data_list):
    """
    Check the consistency of a list of Data/Graph objects.

    A single Data or Graph object is accepted as-is, without any check.
    For a list or tuple, the following checks are performed:

    1. Every element must be a Data or Graph instance. Elements are not
       required to share the same class at this step.
    2. The list must contain at least one element.
    3. All elements must have the same keys.
    4. For each key, the stored value must have the same class in every
       element.
    5. For each key holding a LabelTensor, the labels must be the same in
       every element.

    :param data_list: A single Data/Graph object, or a list/tuple of them.
    :type data_list: Data | Graph | list[Data] | list[Graph] |
        tuple[Data] | tuple[Graph]
    :raises ValueError: If any element is not a Data or Graph object.
    :raises ValueError: If ``data_list`` is empty.
    :raises ValueError: If the elements do not all have the same keys.
    :raises ValueError: If the class of a stored value differs across
        elements.
    :raises ValueError: If the labels of a LabelTensor differ across
        elements.
    """
    # A single object has nothing to be compared against.
    if isinstance(data_list, (Graph, Data)):
        return

    # Only membership in (Graph, Data) is checked here, per element.
    if not all(isinstance(item, (Graph, Data)) for item in data_list):
        raise ValueError(
            "Invalid input types. "
            "Please provide either Data or Graph objects."
        )

    # ``all`` is vacuously true on an empty sequence; without this guard the
    # indexing below would fail with an uninformative IndexError.
    if len(data_list) == 0:
        raise ValueError(
            "Invalid input: the list of Data/Graph objects is empty."
        )

    # The first element is the reference every other element is compared to.
    reference = data_list[0]
    reference_keys = sorted(reference.keys())
    reference_types = {
        name: value.__class__ for name, value in reference.items()
    }
    reference_labels = {
        name: value.labels
        for name, value in reference.items()
        if isinstance(value, LabelTensor)
    }

    for data in data_list[1:]:
        if sorted(data.keys()) != reference_keys:
            raise ValueError(
                "All elements in the list must have the same keys."
            )
        for name, tensor in data.items():
            expected_type = reference_types[name]
            if tensor.__class__ is not expected_type:
                raise ValueError(
                    f"Data {name} must be a {expected_type}, got "
                    f"{tensor.__class__}"
                )
            # The value has the same class as the reference value, so if it
            # is a LabelTensor the reference one is too and
            # ``reference_labels[name]`` is guaranteed to exist.
            if (
                isinstance(tensor, LabelTensor)
                and tensor.labels != reference_labels[name]
            ):
                raise ValueError("LabelTensor must have the same labels")