GraphBuilder
- class GraphBuilder(pos, edge_index, x=None, edge_attr=False, custom_edge_func=None, **kwargs)
Bases: object
A class that simplifies the definition of Graph instances.
Computes the edge attributes and creates a new instance of the Graph class.
- Parameters:
pos (torch.Tensor | LabelTensor) – A tensor of shape (N, D) representing the positions of N points in D-dimensional space.
edge_index (torch.Tensor) – A tensor of shape (2, E) representing the indices of the graph’s edges.
x (torch.Tensor | LabelTensor, optional) – Optional tensor of node features of shape (N, F), where F is the number of features per node.
edge_attr (torch.Tensor, optional) – Optional tensor of edge attributes of shape (E, F), where F is the number of features per edge.
custom_edge_func (Callable, optional) – A custom function to compute edge attributes. If provided, overrides edge_attr.
kwargs – Additional keyword arguments passed to the Graph class constructor.
- Returns:
A Graph instance constructed from the provided inputs.
- Return type:
Graph
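The shape conventions in the parameter list can be illustrated without PINA or PyTorch. The following minimal sketch uses nested Python lists in place of tensors; check_graph_inputs and euclidean_edge_length are hypothetical names, not part of the library. It also assumes, based on the signature of _build_edge_attr, that a custom_edge_func is called as func(pos, edge_index) and returns one row of features per edge.

```python
import math
from typing import List, Optional, Sequence

# Plain nested lists stand in for tensors: pos is (N, D), edge_index is (2, E).
Matrix = List[List[float]]


def check_graph_inputs(
    pos: Sequence[Sequence[float]],
    edge_index: Sequence[Sequence[int]],
    x: Optional[Sequence[Sequence[float]]] = None,
) -> None:
    """Hypothetical helper: validate the (N, D), (2, E) and (N, F) conventions."""
    n = len(pos)
    # edge_index must have two rows (source, target) of equal length E.
    if len(edge_index) != 2 or len(edge_index[0]) != len(edge_index[1]):
        raise ValueError("edge_index must have shape (2, E)")
    # Every index must point at one of the N nodes.
    if any(not 0 <= i < n for row in edge_index for i in row):
        raise ValueError("edge_index refers to a node outside [0, N)")
    # Node features, if given, need exactly one row per node.
    if x is not None and len(x) != n:
        raise ValueError("x must have one row of features per node")


def euclidean_edge_length(
    pos: Sequence[Sequence[float]], edge_index: Sequence[Sequence[int]]
) -> Matrix:
    """Hypothetical custom_edge_func (assumed signature): shape (E, 1)."""
    src, dst = edge_index
    return [[math.dist(pos[i], pos[j])] for i, j in zip(src, dst)]


# Three points in 2-D (N=3, D=2) and two directed edges (E=2).
pos = [[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]]
edge_index = [[0, 1], [1, 2]]
check_graph_inputs(pos, edge_index, x=[[1.0], [2.0], [3.0]])
print(euclidean_edge_length(pos, edge_index))  # [[5.0], [5.0]]
```

Returning a list of shape (E, 1) keeps one feature row per edge, matching the (E, F) convention described for edge_attr.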
- static _create_edge_attr(pos, edge_index, edge_attr, func)
Create the edge attributes based on the input parameters.
- Parameters:
pos (torch.Tensor | LabelTensor) – Positions of the points.
edge_index (torch.Tensor) – Edge indices.
edge_attr (bool) – Whether to compute the edge attributes.
func (Callable) – Function to compute the edge attributes.
- Raises:
ValueError – If func is not a function.
- Returns:
The edge attributes.
- Return type:
torch.Tensor | LabelTensor | None
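The control flow documented for _create_edge_attr can be restated in a few lines of plain Python. This is a sketch, not PINA's implementation: create_edge_attr_sketch and abs_coordinate_difference are hypothetical names, the built-in callable() stands in for whatever function check the library performs, and the sketch assumes the check only applies when edge_attr is True. The absolute coordinate difference is used purely as an example edge function; this page does not state what the default _build_edge_attr computes.

```python
from typing import Any, Callable, Optional


def create_edge_attr_sketch(
    pos: Any,
    edge_index: Any,
    edge_attr: bool,
    func: Optional[Callable[[Any, Any], Any]],
) -> Optional[Any]:
    """Hypothetical restatement of the documented _create_edge_attr behaviour."""
    if not edge_attr:
        # Edge attributes were not requested, so nothing is computed.
        return None
    if not callable(func):
        # Mirrors the documented "ValueError if func is not a function".
        raise ValueError("func must be a callable")
    # Delegate to the (default or custom) edge function.
    return func(pos, edge_index)


def abs_coordinate_difference(pos, edge_index):
    """Hypothetical edge function: |pos[src] - pos[dst]| per edge, shape (E, D)."""
    src, dst = edge_index
    return [
        [abs(a - b) for a, b in zip(pos[i], pos[j])] for i, j in zip(src, dst)
    ]


pos = [[0.0, 0.0], [3.0, -4.0]]
edge_index = [[0, 1], [1, 0]]
print(create_edge_attr_sketch(pos, edge_index, False, abs_coordinate_difference))  # None
print(create_edge_attr_sketch(pos, edge_index, True, abs_coordinate_difference))   # [[3.0, 4.0], [3.0, 4.0]]
```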
- static _build_edge_attr(pos, edge_index)
Default function to compute the edge attributes.
- Parameters:
pos (torch.Tensor | LabelTensor) – Positions of the points.
edge_index (torch.Tensor) – Edge indices.
- Returns:
The edge attributes.
- Return type: