E(n) Equivariant Network Block

class EnEquivariantNetworkBlock(node_feature_dim, edge_feature_dim, pos_dim, hidden_dim=64, n_message_layers=2, n_update_layers=2, activation=<class 'torch.nn.modules.activation.SiLU'>, aggr='add', node_dim=-2, flow='source_to_target')

Bases: MessagePassing

Implementation of the E(n) Equivariant Graph Neural Network block. This block performs message passing between the nodes of a graph along its edges, following the scheme proposed by Satorras et al. (2021). It is intended as an inner block of a larger graph neural network architecture.

The message between two nodes connected by an edge is computed by applying a linear transformation to the sender node features and the edge features (when provided), together with the squared Euclidean distance between the sender and recipient node positions, followed by a non-linear activation function. The messages are then aggregated according to the chosen aggregation scheme (aggr): sum, mean, min, max, or product.
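The message step can be sketched in plain PyTorch as follows. This is a minimal illustration, not the library's code: the sizes, the two-layer message network `phi_e`, and the scalar network `phi_x` are hypothetical, and the message input follows the EGNN paper, which feeds both endpoint features, the squared distance, and the edge features to the message MLP. The second output, `diff * phi_x(m_ij)`, is the equivariant message used for position updates.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions; only hidden_dim matches the documented default)
node_feature_dim, edge_feature_dim, pos_dim, hidden_dim = 4, 2, 3, 64
num_nodes, num_edges = 5, 8

x = torch.randn(num_nodes, node_feature_dim)              # node features
pos = torch.randn(num_nodes, pos_dim)                      # node positions
edge_index = torch.randint(0, num_nodes, (2, num_edges))   # row 0 = sender, row 1 = recipient
edge_attr = torch.randn(num_edges, edge_feature_dim)       # edge features

# Hypothetical message network phi_e on [h_i, h_j, ||x_i - x_j||^2, a_ij]
phi_e = nn.Sequential(
    nn.Linear(2 * node_feature_dim + 1 + edge_feature_dim, hidden_dim),
    nn.SiLU(),
    nn.Linear(hidden_dim, hidden_dim),
    nn.SiLU(),
)
# Hypothetical network producing one scalar weight per edge for the position message
phi_x = nn.Linear(hidden_dim, 1)

def message(x_i, x_j, pos_i, pos_j, edge_attr):
    diff = pos_i - pos_j                             # relative positions, shape (E, pos_dim)
    dist2 = (diff ** 2).sum(dim=-1, keepdim=True)    # squared Euclidean distance, shape (E, 1)
    m_ij = phi_e(torch.cat([x_i, x_j, dist2, edge_attr], dim=-1))
    pos_msg = diff * phi_x(m_ij)                     # rotates together with the input positions
    return pos_msg, m_ij

# With flow="source_to_target": j = edge_index[0] (sender), i = edge_index[1] (recipient)
src, dst = edge_index
pos_msg, feat_msg = message(x[dst], x[src], pos[dst], pos[src], edge_attr)
print(pos_msg.shape, feat_msg.shape)
```

Only the squared distance, never the raw coordinates, enters `phi_e`, which is what keeps the feature message invariant to rotations and translations.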

In the update step, the node features are updated by applying another MLP to the concatenation of the node features and the aggregated messages. The node positions are updated as well, by adding the aggregated position messages divided by the degree of the recipient node.
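For reference, the computation described above can be written in the notation of Satorras et al. (2021), with $h_i$ the features and $x_i$ the position of node $i$, $a_{ij}$ the edge features, $\mathcal{N}(i)$ the neighbors of $i$, and $\phi_e$, $\phi_x$, $\phi_h$ learned MLPs. The sums correspond to the default `aggr="add"`, and the $1/\deg(i)$ factor follows the degree normalization described above rather than the constant normalization used in the paper.

```latex
\begin{aligned}
m_{ij} &= \phi_e\left(h_i, h_j, \lVert x_i - x_j \rVert^2, a_{ij}\right) \\
x_i'   &= x_i + \frac{1}{\deg(i)} \sum_{j \in \mathcal{N}(i)} (x_i - x_j)\, \phi_x(m_{ij}) \\
m_i    &= \sum_{j \in \mathcal{N}(i)} m_{ij} \\
h_i'   &= \phi_h\left(h_i, m_i\right)
\end{aligned}
```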

See also

Original reference: Satorras, V. G., Hoogeboom, E., Welling, M. (2021). E(n) Equivariant Graph Neural Networks. In International Conference on Machine Learning. DOI: https://doi.org/10.48550/arXiv.2102.09844.

Initialization of the EnEquivariantNetworkBlock class.

Parameters:
  • node_feature_dim (int) – The dimension of the node features.

  • edge_feature_dim (int) – The dimension of the edge features.

  • pos_dim (int) – The dimension of the node positions, i.e., the number of spatial coordinates.

  • hidden_dim (int) – The dimension of the hidden features. Default is 64.

  • n_message_layers (int) – The number of layers in the message network. Default is 2.

  • n_update_layers (int) – The number of layers in the update network. Default is 2.

  • activation (torch.nn.Module) – The activation function class (not an instance). Default is torch.nn.SiLU.

  • aggr (str) – The aggregation scheme to use for message passing. Available options are “add”, “mean”, “min”, “max”, “mul”. See torch_geometric.nn.MessagePassing for more details. Default is “add”.

  • node_dim (int) – The axis along which to propagate. Default is -2.

  • flow (str) – The direction of message passing. Available options are “source_to_target”, where messages are sent from the source node to the target node, and “target_to_source”, where messages are sent from the target node to the source node. See torch_geometric.nn.MessagePassing for more details. Default is “source_to_target”.
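The effect of `flow` and `aggr` can be illustrated with a small plain-PyTorch sketch. It mimics the torch_geometric convention that, with `"source_to_target"`, messages are collected at `edge_index[1]`, and with `"target_to_source"`, at `edge_index[0]`. The helper `aggregate_at` is hypothetical and only covers `"add"` and `"mean"`.

```python
import torch

# Toy graph with edges 0->2, 1->2, 2->0 (row 0 = source, row 1 = target)
edge_index = torch.tensor([[0, 1, 2],
                           [2, 2, 0]])
messages = torch.tensor([[1.0], [3.0], [5.0]])  # one scalar message per edge
num_nodes = 3

def aggregate_at(index, messages, num_nodes, aggr="add"):
    """Hypothetical helper: sum or average the messages per receiving node."""
    out = torch.zeros(num_nodes, messages.shape[-1])
    out.index_add_(0, index, messages)          # sum the messages into their receiving node
    if aggr == "mean":
        count = torch.bincount(index, minlength=num_nodes).clamp(min=1)
        out = out / count.unsqueeze(-1)         # nodes without messages stay at 0
    return out

# flow="source_to_target": messages are collected at the target nodes
s2t_add = aggregate_at(edge_index[1], messages, num_nodes, "add")
s2t_mean = aggregate_at(edge_index[1], messages, num_nodes, "mean")
# flow="target_to_source": messages are collected at the source nodes
t2s_add = aggregate_at(edge_index[0], messages, num_nodes, "add")

print(s2t_add.squeeze(-1).tolist())   # [5.0, 0.0, 4.0]
print(s2t_mean.squeeze(-1).tolist())  # [5.0, 0.0, 2.0]
print(t2s_add.squeeze(-1).tolist())   # [1.0, 3.0, 5.0]
```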

Raises:
forward(x, pos, edge_index, edge_attr=None)

Forward pass of the block, triggering the message-passing routine.

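Parameters:
  • x (torch.Tensor | LabelTensor) – The node features.

  • pos (torch.Tensor | LabelTensor) – The node positions.

  • edge_index (torch.Tensor | LabelTensor) – The edge indices.

  • edge_attr (torch.Tensor | LabelTensor) – The edge features. Default is None.
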
Returns:

The updated node features and node positions.

Return type:

tuple(torch.Tensor, torch.Tensor)
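E(n) equivariance means that rotating, reflecting, or translating the input positions transforms the output positions in the same way, while the output features are unchanged. The sketch below checks this numerically on a hypothetical `TinyEGNNLayer`, a minimal plain-PyTorch re-implementation of the scheme described above (without edge features, with degree-normalized position updates); it is not the EnEquivariantNetworkBlock itself, though the same check could be run on the block's forward outputs.

```python
import torch
from torch import nn

torch.set_default_dtype(torch.float64)  # double precision keeps round-off far below atol
torch.manual_seed(0)

class TinyEGNNLayer(nn.Module):
    """Hypothetical minimal EGNN layer, used only to illustrate the E(n) property."""

    def __init__(self, feat_dim, hidden_dim=16):
        super().__init__()
        self.phi_e = nn.Sequential(nn.Linear(2 * feat_dim + 1, hidden_dim), nn.SiLU())
        self.phi_x = nn.Linear(hidden_dim, 1)
        self.phi_h = nn.Linear(feat_dim + hidden_dim, feat_dim)

    def forward(self, x, pos, edge_index):
        src, dst = edge_index                                  # messages flow src -> dst
        n = x.shape[0]
        diff = pos[dst] - pos[src]
        dist2 = (diff ** 2).sum(-1, keepdim=True)              # rotation/translation invariant
        m_ij = self.phi_e(torch.cat([x[dst], x[src], dist2], dim=-1))
        agg_pos = torch.zeros_like(pos).index_add_(0, dst, diff * self.phi_x(m_ij))
        agg_m = torch.zeros(n, m_ij.shape[-1]).index_add_(0, dst, m_ij)
        deg = torch.bincount(dst, minlength=n).clamp(min=1).unsqueeze(-1)
        x_new = self.phi_h(torch.cat([x, agg_m], dim=-1))
        pos_new = pos + agg_pos / deg
        return x_new, pos_new

layer = TinyEGNNLayer(feat_dim=4)
x = torch.randn(6, 4)
pos = torch.randn(6, 3)
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 1],
                           [1, 2, 3, 4, 5, 0, 3]])

# Random orthogonal matrix (from a QR decomposition) and a random translation
Q, _ = torch.linalg.qr(torch.randn(3, 3))
t = torch.randn(3)

with torch.no_grad():
    x1, pos1 = layer(x, pos, edge_index)
    x2, pos2 = layer(x, pos @ Q.T + t, edge_index)

features_invariant = torch.allclose(x1, x2, atol=1e-8)
positions_equivariant = torch.allclose(pos1 @ Q.T + t, pos2, atol=1e-8)
print(features_invariant, positions_equivariant)
```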

message(x_i, x_j, pos_i, pos_j, edge_attr)

Compute the messages passed along each edge, from the sender node to the recipient node.

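Parameters:
  • x_i (torch.Tensor | LabelTensor) – The features of the recipient node of each edge.

  • x_j (torch.Tensor | LabelTensor) – The features of the sender node of each edge.

  • pos_i (torch.Tensor | LabelTensor) – The position of the recipient node of each edge.

  • pos_j (torch.Tensor | LabelTensor) – The position of the sender node of each edge.

  • edge_attr (torch.Tensor | LabelTensor) – The edge features.
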
Returns:

The messages to be passed, as a tuple (message for position updates, message for feature updates).

Return type:

tuple(torch.Tensor, torch.Tensor)

aggregate(inputs, index, ptr=None, dim_size=None)

Aggregate the messages at the nodes during message passing.

This method receives a tuple of tensors corresponding to the messages to be aggregated. Both messages are aggregated separately according to the specified aggregation scheme.

Parameters:
  • inputs (tuple(torch.Tensor)) – Tuple containing two messages to aggregate.

  • index (torch.Tensor | LabelTensor) – The indices of target nodes for each message. This tensor specifies which node each message is aggregated into.

  • ptr (torch.Tensor | LabelTensor) – Optional tensor to specify the slices of messages for each node (used in some aggregation strategies). Default is None.

  • dim_size (int) – Optional size of the output dimension, i.e., number of nodes. Default is None.

Returns:

Tuple of aggregated tensors corresponding to (aggregated messages for position updates, aggregated messages for feature updates).

Return type:

tuple(torch.Tensor, torch.Tensor)
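Aggregating the two messages separately can be sketched with `Tensor.scatter_reduce` (available in PyTorch 1.12 and later). The function below is a hypothetical stand-in, not the library's method, and the mapping from `aggr` names to `scatter_reduce` modes is an assumption made for this illustration; `ptr` is omitted.

```python
import torch

# Assumed mapping from aggr names to Tensor.scatter_reduce reduction modes
REDUCE = {"add": "sum", "mean": "mean", "min": "amin", "max": "amax", "mul": "prod"}

def aggregate(inputs, index, dim_size, aggr="add"):
    """Hypothetical sketch: aggregate each message tensor separately at its target node."""
    def scatter(src):
        idx = index.unsqueeze(-1).expand_as(src)    # repeat each edge's target index per channel
        out = torch.zeros(dim_size, src.shape[-1], dtype=src.dtype)
        # include_self=False: the initial zeros are ignored for nodes that receive messages,
        # while nodes that receive none keep 0
        return out.scatter_reduce(0, idx, src, REDUCE[aggr], include_self=False)

    pos_msg, feat_msg = inputs
    return scatter(pos_msg), scatter(feat_msg)

index = torch.tensor([0, 0, 2])                     # target node of each of the 3 edges
pos_msg = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
feat_msg = torch.tensor([[1.0], [-1.0], [2.0]])

agg_pos, agg_feat = aggregate((pos_msg, feat_msg), index, dim_size=3, aggr="add")
agg_pos_max, _ = aggregate((pos_msg, feat_msg), index, dim_size=3, aggr="max")
print(agg_pos.tolist())      # [[4.0, 6.0], [0.0, 0.0], [5.0, 6.0]]
print(agg_feat.tolist())     # [[0.0], [0.0], [2.0]]
print(agg_pos_max.tolist())  # [[3.0, 4.0], [0.0, 0.0], [5.0, 6.0]]
```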

update(aggregated_inputs, x, pos, edge_index)

Update the node features and the node coordinates with the received messages.

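Parameters:
  • aggregated_inputs (tuple(torch.Tensor, torch.Tensor)) – The aggregated messages returned by aggregate, as (aggregated messages for position updates, aggregated messages for feature updates).

  • x (torch.Tensor | LabelTensor) – The node features.

  • pos (torch.Tensor | LabelTensor) – The node positions.

  • edge_index (torch.Tensor | LabelTensor) – The edge indices, used to compute the degree of each recipient node.
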
Returns:

The updated node features and node positions.

Return type:

tuple(torch.Tensor, torch.Tensor)
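The update step can be sketched as follows. This is a minimal plain-PyTorch illustration, not the library's code: `phi_h` is a hypothetical feature-update MLP, and clamping the degree to at least 1 is an assumption made here so that nodes without incoming edges keep their position instead of dividing by zero.

```python
import torch
from torch import nn

torch.manual_seed(0)

node_feature_dim, msg_dim, pos_dim = 4, 8, 3
# Hypothetical feature-update MLP phi_h acting on the concatenation [h_i, m_i]
phi_h = nn.Sequential(
    nn.Linear(node_feature_dim + msg_dim, 16),
    nn.SiLU(),
    nn.Linear(16, node_feature_dim),
)

def update(aggregated_inputs, x, pos, edge_index):
    agg_pos_msg, agg_feat_msg = aggregated_inputs
    # In-degree of each recipient node (edge_index[1]); clamping avoids division by zero
    deg = torch.bincount(edge_index[1], minlength=pos.shape[0]).clamp(min=1)
    x_new = phi_h(torch.cat([x, agg_feat_msg], dim=-1))
    pos_new = pos + agg_pos_msg / deg.unsqueeze(-1).to(pos.dtype)
    return x_new, pos_new

edge_index = torch.tensor([[0, 1, 2],
                           [2, 2, 0]])              # node 2 has degree 2, node 1 has none
x = torch.randn(3, node_feature_dim)
pos = torch.zeros(3, pos_dim)
agg_pos_msg = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [2.0, 4.0, 6.0]])
agg_feat_msg = torch.randn(3, msg_dim)

x_new, pos_new = update((agg_pos_msg, agg_feat_msg), x, pos, edge_index)
print(pos_new)  # rows: [1, 0, 0], [0, 0, 0], [1, 2, 3]
```

Normalizing by the degree keeps the size of the position step comparable between densely and sparsely connected nodes.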