RadiusGraph

class RadiusGraph(pos, radius, **kwargs)

Bases: GraphBuilder

Extends the GraphBuilder class to compute edge_index from a radius: each point is connected to every point that lies within the given radius.

Instantiates a Graph by computing edge_index from the provided radius.

Parameters:
  • pos (torch.Tensor | LabelTensor) – A tensor of shape (N, D) representing the positions of N points in D-dimensional space.

  • radius (float) – The radius within which points are connected.

  • kwargs (dict) – Additional keyword arguments passed to the GraphBuilder and Graph classes.
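The roles of pos and radius can be illustrated without PINA itself. The sketch below assumes Euclidean distance and an inclusive comparison (neither is stated on this page) and uses a hypothetical helper, `neighbour_counts`. It computes the pairwise distances of an (N, D) tensor with `torch.cdist` and counts how many points fall within the radius of each point. Whether a point counts as its own neighbour is exposed as a flag because this page does not specify it.

```python
import torch

def neighbour_counts(pos: torch.Tensor, radius: float, include_self: bool = False) -> torch.Tensor:
    """Hypothetical helper: number of points within `radius` of each point.

    pos has shape (N, D): N points in D-dimensional space.
    Assumes Euclidean distance and an inclusive (<=) threshold.
    """
    if pos.dim() != 2:
        raise ValueError("pos must have shape (N, D)")
    dist = torch.cdist(pos, pos, p=2)  # (N, N) pairwise Euclidean distances
    within = dist <= radius            # boolean mask of pairs inside the radius
    if not include_self:
        # Mask out the diagonal, i.e. each point's zero distance to itself
        eye = torch.eye(pos.shape[0], dtype=torch.bool, device=pos.device)
        within = within & ~eye
    return within.sum(dim=1)           # (N,) neighbour count per point

# Four points on a line in 2-D space (N=4, D=2)
pos = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.4, 0.0], [10.0, 0.0]])
print(neighbour_counts(pos, radius=1.5))  # tensor([1, 2, 1, 0])
```

With the documented constructor, the same inputs would be passed as `RadiusGraph(pos=pos, radius=1.5)`, with any further options forwarded through `kwargs`.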

Returns:

A Graph instance with the computed edge_index.

Return type:

Graph

static compute_radius_graph(points, radius)

Computes edge_index from the radius: each point is connected to every point that lies within the given radius.

Parameters:
  • points (torch.Tensor | LabelTensor) – A tensor of shape (N, D) representing the positions of N points in D-dimensional space.

  • radius (float) – The radius within which points are connected.
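Below is a minimal sketch of the computation compute_radius_graph is described as performing. It assumes Euclidean distance and an inclusive (<=) comparison. The function name `radius_edge_index` is hypothetical, and the code is not presented as PINA's implementation. It builds the full N-by-N distance matrix, keeps the index pairs within the radius, and transposes them to the (2, E) layout. Self-loops (i, i) appear because every point is at distance 0 from itself. This is also an assumption about the intended behaviour.

```python
import torch

def radius_edge_index(points: torch.Tensor, radius: float) -> torch.Tensor:
    """Sketch of a radius-graph edge_index (hypothetical, not PINA's code).

    Returns a (2, E) tensor whose columns are (source, target) index pairs
    of points at Euclidean distance <= radius from each other.
    """
    dist = torch.cdist(points, points, p=2)  # (N, N) pairwise distances
    pairs = torch.nonzero(dist <= radius)    # (E, 2) index pairs within the radius
    return pairs.t().contiguous()            # transpose to the (2, E) layout

points = torch.tensor([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]])
edge_index = radius_edge_index(points, radius=1.5)
print(edge_index)
# tensor([[0, 0, 1, 1, 2],
#         [0, 1, 0, 1, 2]])
```

The dense distance matrix costs O(N²) memory, which is acceptable for small point clouds. For large ones, a spatial index or a dedicated radius-search routine is usually preferable. If self-loops are unwanted, they can be dropped with `edge_index[:, edge_index[0] != edge_index[1]]`.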

Returns:

A tensor of shape (2, E), where E is the number of edges, representing the edge indices of the graph.

Return type:

torch.Tensor
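In the returned (2, E) tensor, each column is one (source, target) pair. The following sketch shows a few common ways to inspect such a tensor: counting edges, measuring edge lengths, and computing per-point degree. It uses hand-written example values, not output produced by PINA.

```python
import torch

# Example (2, E) edge_index: row 0 holds source indices, row 1 target indices.
# Values are hand-written for three points and a radius of 1.5.
radius = 1.5
points = torch.tensor([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]])
edge_index = torch.tensor([[0, 0, 1, 1, 2],
                           [0, 1, 0, 1, 2]])

src, dst = edge_index                                    # unpack the two rows
num_edges = edge_index.shape[1]                          # E
lengths = (points[src] - points[dst]).norm(dim=1)        # Euclidean length of each edge
degree = torch.bincount(src, minlength=points.shape[0])  # outgoing edges per point

print(num_edges)  # 5
print(lengths)    # tensor([0., 1., 1., 0., 0.])
print(degree)     # tensor([2, 2, 1])
```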