"""Module to build Graph objects and perform operations on them."""
import torch
from torch_geometric.data import Data, Batch
from torch_geometric.utils import to_undirected
from .label_tensor import LabelTensor
from .utils import check_consistency, is_function
[docs]
class Graph(Data):
"""
Extends :class:`~torch_geometric.data.Data` class to include additional
checks and functionlities.
"""
def __new__(
cls,
**kwargs,
):
"""
Create a new instance of the :class:`~pina.graph.Graph` class by
checking the consistency of the input data and storing the attributes.
:param dict kwargs: Parameters used to initialize the
:class:`~pina.graph.Graph` object.
:return: A new instance of the :class:`~pina.graph.Graph` class.
:rtype: Graph
"""
# create class instance
instance = Data.__new__(cls)
# check the consistency of types defined in __init__, the others are not
# checked (as in pyg Data object)
instance._check_type_consistency(**kwargs)
return instance
def __init__(
self,
x=None,
edge_index=None,
pos=None,
edge_attr=None,
undirected=False,
**kwargs,
):
"""
Initialize the object by setting the node features, edge index,
edge attributes, and positions. The edge index is preprocessed to make
the graph undirected if required. For more details, see the
:meth:`torch_geometric.data.Data`
:param x: Optional tensor of node features ``(N, F)`` where ``F`` is the
number of features per node.
:type x: torch.Tensor, LabelTensor
:param torch.Tensor edge_index: A tensor of shape ``(2, E)``
representing the indices of the graph's edges.
:param pos: A tensor of shape ``(N, D)`` representing the positions of
``N`` points in ``D``-dimensional space.
:type pos: torch.Tensor | LabelTensor
:param edge_attr: Optional tensor of edge_featured ``(E, F')`` where
``F'`` is the number of edge features
:type edge_attr: torch.Tensor | LabelTensor
:param bool undirected: Whether to make the graph undirected
:param dict kwargs: Additional keyword arguments passed to the
:class:`~torch_geometric.data.Data` class constructor.
"""
# preprocessing
self._preprocess_edge_index(edge_index, undirected)
# calling init
super().__init__(
x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos, **kwargs
)
[docs]
def _check_type_consistency(self, **kwargs):
"""
Check the consistency of the types of the input data.
:param dict kwargs: Attributes to be checked for consistency.
"""
# default types, specified in cls.__new__, by default they are Nont
# if specified in **kwargs they get override
x, pos, edge_index, edge_attr = None, None, None, None
if "pos" in kwargs:
pos = kwargs["pos"]
self._check_pos_consistency(pos)
if "edge_index" in kwargs:
edge_index = kwargs["edge_index"]
self._check_edge_index_consistency(edge_index)
if "x" in kwargs:
x = kwargs["x"]
self._check_x_consistency(x, pos)
if "edge_attr" in kwargs:
edge_attr = kwargs["edge_attr"]
self._check_edge_attr_consistency(edge_attr, edge_index)
if "undirected" in kwargs:
undirected = kwargs["undirected"]
check_consistency(undirected, bool)
[docs]
@staticmethod
def _check_pos_consistency(pos):
"""
Check if the position tensor is consistent.
:param torch.Tensor pos: The position tensor.
:raises ValueError: If the position tensor is not consistent.
"""
if pos is not None:
check_consistency(pos, (torch.Tensor, LabelTensor))
if pos.ndim != 2:
raise ValueError("pos must be a 2D tensor.")
[docs]
@staticmethod
def _check_edge_index_consistency(edge_index):
"""
Check if the edge index is consistent.
:param torch.Tensor edge_index: The edge index tensor.
:raises ValueError: If the edge index tensor is not consistent.
"""
check_consistency(edge_index, (torch.Tensor, LabelTensor))
if edge_index.ndim != 2:
raise ValueError("edge_index must be a 2D tensor.")
if edge_index.size(0) != 2:
raise ValueError("edge_index must have shape [2, num_edges].")
[docs]
@staticmethod
def _check_edge_attr_consistency(edge_attr, edge_index):
"""
Check if the edge attribute tensor is consistent in type and shape
with the edge index.
:param edge_attr: The edge attribute tensor.
:type edge_attr: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: The edge index tensor.
:raises ValueError: If the edge attribute tensor is not consistent.
"""
if edge_attr is not None:
check_consistency(edge_attr, (torch.Tensor, LabelTensor))
if edge_attr.ndim != 2:
raise ValueError("edge_attr must be a 2D tensor.")
if edge_attr.size(0) != edge_index.size(1):
raise ValueError(
"edge_attr must have shape "
"[num_edges, num_edge_features], expected "
f"num_edges {edge_index.size(1)} "
f"got {edge_attr.size(0)}."
)
[docs]
@staticmethod
def _check_x_consistency(x, pos=None):
"""
Check if the input tensor x is consistent with the position tensor
`pos`.
:param x: The input tensor.
:type x: torch.Tensor | LabelTensor
:param pos: The position tensor.
:type pos: torch.Tensor | LabelTensor
:raises ValueError: If the input tensor is not consistent.
"""
if x is not None:
check_consistency(x, (torch.Tensor, LabelTensor))
if x.ndim != 2:
raise ValueError("x must be a 2D tensor.")
if pos is not None:
if x.size(0) != pos.size(0):
raise ValueError("Inconsistent number of nodes.")
[docs]
@staticmethod
def _preprocess_edge_index(edge_index, undirected):
"""
Preprocess the edge index to make the graph undirected (if required).
:param torch.Tensor edge_index: The edge index.
:param bool undirected: Whether the graph is undirected.
:return: The preprocessed edge index.
:rtype: torch.Tensor
"""
if undirected:
edge_index = to_undirected(edge_index)
return edge_index
[docs]
class GraphBuilder:
"""
A class that allows an easy definition of :class:`Graph` instances.
"""
def __new__(
cls,
pos,
edge_index,
x=None,
edge_attr=False,
custom_edge_func=None,
**kwargs,
):
"""
Compute the edge attributes and create a new instance of the
:class:`~pina.graph.Graph` class.
:param pos: A tensor of shape ``(N, D)`` representing the positions of
``N`` points in ``D``-dimensional space.
:type pos: torch.Tensor or LabelTensor
:param edge_index: A tensor of shape ``(2, E)`` representing the indices
of the graph's edges.
:type edge_index: torch.Tensor
:param x: Optional tensor of node features of shape ``(N, F)``, where
``F`` is the number of features per node.
:type x: torch.Tensor | LabelTensor, optional
:param edge_attr: Optional tensor of edge attributes of shape ``(E, F)``
, where ``F`` is the number of features per edge.
:type edge_attr: torch.Tensor, optional
:param custom_edge_func: A custom function to compute edge attributes.
If provided, overrides ``edge_attr``.
:type custom_edge_func: Callable, optional
:param kwargs: Additional keyword arguments passed to the
:class:`~pina.graph.Graph` class constructor.
:return: A :class:`~pina.graph.Graph` instance constructed using the
provided information.
:rtype: Graph
"""
edge_attr = cls._create_edge_attr(
pos, edge_index, edge_attr, custom_edge_func or cls._build_edge_attr
)
return Graph(
x=x,
edge_index=edge_index,
edge_attr=edge_attr,
pos=pos,
**kwargs,
)
[docs]
@staticmethod
def _create_edge_attr(pos, edge_index, edge_attr, func):
"""
Create the edge attributes based on the input parameters.
:param pos: Positions of the points.
:type pos: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: Edge indices.
:param bool edge_attr: Whether to compute the edge attributes.
:param Callable func: Function to compute the edge attributes.
:raises ValueError: If ``func`` is not a function.
:return: The edge attributes.
:rtype: torch.Tensor | LabelTensor | None
"""
check_consistency(edge_attr, bool)
if edge_attr:
if is_function(func):
return func(pos, edge_index)
raise ValueError("custom_edge_func must be a function.")
return None
[docs]
@staticmethod
def _build_edge_attr(pos, edge_index):
"""
Default function to compute the edge attributes.
:param pos: Positions of the points.
:type pos: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: Edge indices.
:return: The edge attributes.
:rtype: torch.Tensor
"""
return (
(pos[edge_index[0]] - pos[edge_index[1]])
.abs()
.as_subclass(torch.Tensor)
)
[docs]
class RadiusGraph(GraphBuilder):
"""
Extends the :class:`~pina.graph.GraphBuilder` class to compute
``edge_index`` based on a radius. Each point is connected to all the points
within the radius.
"""
def __new__(cls, pos, radius, **kwargs):
"""
Instantiate the :class:`~pina.graph.Graph` class by computing the
``edge_index`` based on the radius provided.
:param pos: A tensor of shape ``(N, D)`` representing the positions of
``N`` points in ``D``-dimensional space.
:type pos: torch.Tensor | LabelTensor
:param float radius: The radius within which points are connected.
:param dict kwargs: The additional keyword arguments to be passed to
:class:`GraphBuilder` and :class:`Graph` classes.
:return: A :class:`~pina.graph.Graph` instance with the computed
``edge_index``.
:rtype: Graph
"""
edge_index = cls.compute_radius_graph(pos, radius)
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
[docs]
@staticmethod
def compute_radius_graph(points, radius):
"""
Computes the ``edge_index`` based on the radius. Each point is connected
to all the points within the radius.
:param points: A tensor of shape ``(N, D)`` representing the positions
of ``N`` points in ``D``-dimensional space.
:type points: torch.Tensor | LabelTensor
:param float radius: The radius within which points are connected.
:return: A tensor of shape ``(2, E)``, with ``E`` number of edges,
representing the edge indices of the graph.
:rtype: torch.Tensor
"""
dist = torch.cdist(points, points, p=2)
return (
torch.nonzero(dist <= radius, as_tuple=False)
.t()
.as_subclass(torch.Tensor)
)
[docs]
class KNNGraph(GraphBuilder):
"""
Extends the :class:`~pina.graph.GraphBuilder` class to compute
``edge_index`` based on a K-nearest neighbors algorithm.
"""
def __new__(cls, pos, neighbours, **kwargs):
"""
Instantiate the :class:`~pina.graph.Graph` class by computing the
``edge_index`` based on the K-nearest neighbors algorithm.
:param pos: A tensor of shape ``(N, D)`` representing the positions of
``N`` points in ``D``-dimensional space.
:type pos: torch.Tensor | LabelTensor
:param int neighbours: The number of nearest neighbors to consider when
building the graph.
:param dict kwargs: The additional keyword arguments to be passed to
:class:`GraphBuilder` and :class:`Graph` classes.
:return: A :class:`~pina.graph.Graph` instance with the computed
``edge_index``.
:rtype: Graph
"""
edge_index = cls.compute_knn_graph(pos, neighbours)
return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs)
[docs]
@staticmethod
def compute_knn_graph(points, neighbours):
"""
Computes the ``edge_index`` based on the K-nearest neighbors algorithm.
:param points: A tensor of shape ``(N, D)`` representing the positions
of ``N`` points in ``D``-dimensional space.
:type points: torch.Tensor | LabelTensor
:param int neighbours: The number of nearest neighbors to consider when
building the graph.
:return: A tensor of shape ``(2, E)``, with ``E`` number of edges,
representing the edge indices of the graph.
:rtype: torch.Tensor
"""
dist = torch.cdist(points, points, p=2)
knn_indices = torch.topk(dist, k=neighbours + 1, largest=False).indices[
:, 1:
]
row = torch.arange(points.size(0)).repeat_interleave(neighbours)
col = knn_indices.flatten()
return torch.stack([row, col], dim=0).as_subclass(torch.Tensor)
[docs]
class LabelBatch(Batch):
"""
Extends the :class:`~torch_geometric.data.Batch` class to include
:class:`~pina.label_tensor.LabelTensor` objects.
"""
[docs]
@classmethod
def from_data_list(cls, data_list):
"""
Create a Batch object from a list of :class:`~torch_geometric.data.Data`
or :class:`~pina.graph.Graph` objects.
:param data_list: List of :class:`~torch_geometric.data.Data` or
:class:`~pina.graph.Graph` objects.
:type data_list: list[Data] | list[Graph]
:return: A :class:`~torch_geometric.data.Batch` object containing
the input data.
:rtype: :class:`~torch_geometric.data.Batch`
"""
# Store the labels of Data/Graph objects (all data have the same labels)
# If the data do not contain labels, labels is an empty dictionary,
# therefore the labels are not stored
labels = {
k: v.labels
for k, v in data_list[0].items()
if isinstance(v, LabelTensor)
}
# Create a Batch object from the list of Data objects
batch = super().from_data_list(data_list)
# Put the labels back in the Batch object
for k, v in labels.items():
batch[k].labels = v
return batch