Graph
- class Graph(**kwargs)
Bases: Data
Extends the Data class to include additional checks and functionalities.
Initialize the object by setting the node features, edge index, edge attributes, and positions. The edge index is preprocessed to make the graph undirected if required. For more details, see the documentation of torch_geometric.data.Data.
- Parameters:
x (torch.Tensor | LabelTensor) – Optional tensor of node features of shape (N, F), where F is the number of features per node.
edge_index (torch.Tensor) – A tensor of shape (2, E) representing the indices of the graph’s edges.
pos (torch.Tensor | LabelTensor) – A tensor of shape (N, D) representing the positions of N points in D-dimensional space.
edge_attr (torch.Tensor | LabelTensor) – Optional tensor of edge features of shape (E, F'), where F' is the number of edge features.
undirected (bool) – Whether to make the graph undirected.
kwargs (dict) – Additional keyword arguments passed to the Data class constructor.
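The undirected preprocessing mentioned in the class description can be sketched in plain Python. This is a minimal illustration, not PINA's implementation: it assumes that making the graph undirected means adding the reverse of every edge and dropping duplicates (one common convention), and it represents the (2, E) edge index as a pair of lists [sources, targets] so it runs without PyTorch. The helper name make_undirected is hypothetical, and edge attributes are ignored.

```python
def make_undirected(edge_index: list) -> list:
    """Hypothetical helper: return an edge index with every edge in both directions.

    edge_index is a pair of equal-length lists [sources, targets],
    mirroring the (2, E) tensor layout.
    """
    sources, targets = edge_index
    if len(sources) != len(targets):
        raise ValueError("edge_index rows must have the same length")
    # Add the reversed copy of every edge; the set drops duplicates.
    edges = set(zip(sources, targets)) | set(zip(targets, sources))
    # Sort so the (source, target) order is deterministic.
    ordered = sorted(edges)
    return [[s for s, _ in ordered], [t for _, t in ordered]]


# Directed path 0 -> 1 -> 2 with E = 2 edges.
print(make_undirected([[0, 1], [1, 2]]))
```

Sorting only makes the output order deterministic; an edge that is already present in both directions is kept once per direction.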
- _check_type_consistency(**kwargs)
Check the consistency of the types of the input data.
- Parameters:
kwargs (dict) – Attributes to be checked for consistency.
- static _check_pos_consistency(pos)
Check if the position tensor is consistent.
- Parameters:
pos (torch.Tensor) – The position tensor.
- Raises:
ValueError – If the position tensor is not consistent.
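The docstring does not define what "consistent" means for pos. A minimal sketch, assuming it means the tensor is two-dimensional with shape (N, D) as listed in the constructor parameters: the hypothetical check_pos_shape below works on a shape tuple (for a PyTorch tensor, tuple(pos.shape)) so it needs no third-party imports.

```python
def check_pos_shape(pos_shape: tuple) -> None:
    """Hypothetical check: raise ValueError unless pos has shape (N, D)."""
    # Assumption: "consistent" means exactly two dimensions,
    # N points by D coordinates.
    if len(pos_shape) != 2:
        raise ValueError(f"pos must have shape (N, D), got {pos_shape}")


check_pos_shape((100, 3))  # 100 points in 3-D space: passes silently
try:
    check_pos_shape((100,))  # 1-D shape: rejected
except ValueError as err:
    print(err)
```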
- static _check_edge_index_consistency(edge_index)
Check if the edge index is consistent.
- Parameters:
edge_index (torch.Tensor) – The edge index tensor.
- Raises:
ValueError – If the edge index tensor is not consistent.
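A similar sketch for the edge index, assuming consistency means the (2, E) shape given in the constructor parameters. The helper check_edge_index_shape is hypothetical; it takes a shape tuple and returns E so a caller can reuse the edge count.

```python
def check_edge_index_shape(edge_index_shape: tuple) -> int:
    """Hypothetical check: validate a (2, E) edge-index shape and return E."""
    # By the usual PyTorch Geometric convention, row 0 holds source node
    # indices and row 1 holds target node indices, so there must be 2 rows.
    if len(edge_index_shape) != 2 or edge_index_shape[0] != 2:
        raise ValueError(
            f"edge_index must have shape (2, E), got {edge_index_shape}"
        )
    return edge_index_shape[1]


num_edges = check_edge_index_shape((2, 6))  # a graph with E = 6 edges
```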
- static _check_edge_attr_consistency(edge_attr, edge_index)
Check if the edge attribute tensor is consistent in type and shape with the edge index.
- Parameters:
edge_attr (torch.Tensor | LabelTensor) – The edge attribute tensor.
edge_index (torch.Tensor) – The edge index tensor.
- Raises:
ValueError – If the edge attribute tensor is not consistent.
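For edge attributes, the docstring states only that the tensor must be consistent in type and shape with the edge index. Assuming, from the constructor parameters, that edge_attr has shape (E, F') and edge_index has shape (2, E), the hypothetical helper below covers only the shape part: it checks that both are two-dimensional and that they agree on the number of edges E. Since edge_attr is optional, None is accepted.

```python
from typing import Optional


def check_edge_attr_shape(
    edge_attr_shape: Optional[tuple], edge_index_shape: tuple
) -> None:
    """Hypothetical check: edge_attr must be (E, F') when edge_index is (2, E)."""
    if edge_attr_shape is None:
        return  # edge_attr is optional, so there is nothing to check
    if len(edge_index_shape) != 2 or edge_index_shape[0] != 2:
        raise ValueError(f"edge_index must have shape (2, E), got {edge_index_shape}")
    if len(edge_attr_shape) != 2:
        raise ValueError(f"edge_attr must have shape (E, F'), got {edge_attr_shape}")
    num_edges = edge_index_shape[1]
    # One feature row per edge: the leading dimensions must match.
    if edge_attr_shape[0] != num_edges:
        raise ValueError(
            f"edge_attr has {edge_attr_shape[0]} rows "
            f"but edge_index has {num_edges} edges"
        )


check_edge_attr_shape((6, 4), (2, 6))  # 6 edges, 4 features each: passes
```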
- static _check_x_consistency(x, pos=None)
Check if the input tensor x is consistent with the position tensor pos.
- Parameters:
x (torch.Tensor | LabelTensor) – The input tensor.
pos (torch.Tensor | LabelTensor) – The position tensor.
- Raises:
ValueError – If the input tensor is not consistent.
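Likewise for node features: assuming x has shape (N, F) and pos has shape (N, D), consistency plausibly means both describe the same N nodes. The hypothetical check_x_shape below compares the leading dimensions only when pos is given, mirroring the optional pos argument in the signature, and accepts None for the optional x.

```python
from typing import Optional


def check_x_shape(x_shape: Optional[tuple], pos_shape: Optional[tuple] = None) -> None:
    """Hypothetical check: x must be (N, F) and, if pos is given, share its N."""
    if x_shape is None:
        return  # x is optional, so there is nothing to check
    if len(x_shape) != 2:
        raise ValueError(f"x must have shape (N, F), got {x_shape}")
    # Every node needs exactly one position, so the node counts must agree.
    if pos_shape is not None and pos_shape[0] != x_shape[0]:
        raise ValueError(
            f"x describes {x_shape[0]} nodes but pos describes {pos_shape[0]}"
        )


check_x_shape((100, 5), (100, 2))  # 100 nodes, 5 features, 2-D positions: passes
```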
- static _preprocess_edge_index(edge_index, undirected)
Preprocess the edge index to make the graph undirected (if required).
- Parameters:
edge_index (torch.Tensor) – The edge index.
undirected (bool) – Whether the graph is undirected.
- Returns:
The preprocessed edge index.
- Return type: