E(n) Equivariant Network Block
- class EnEquivariantNetworkBlock(node_feature_dim, edge_feature_dim, pos_dim, hidden_dim=64, n_message_layers=2, n_update_layers=2, activation=<class 'torch.nn.modules.activation.SiLU'>, aggr='add', node_dim=-2, flow='source_to_target')
Bases: MessagePassing
Implementation of the E(n) Equivariant Graph Neural Network block. This block performs message passing between the nodes of a graph along its edges, following the scheme proposed by Satorras et al. (2021). It serves as an inner block in a larger graph neural network architecture.
The message between two nodes connected by an edge is computed by the message network, a multilayer perceptron with non-linear activations, applied to the features of the sender and recipient nodes, the edge features, and the squared Euclidean distance between the sender and recipient node positions. Messages are then aggregated using an aggregation scheme (e.g., sum, mean, min, max, or product).
The update step applies another MLP to the concatenation of the node features and the aggregated incoming messages. The node positions are also updated in this step, by adding the aggregated position messages divided by the degree of the recipient node.
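The message, aggregation, and update steps described above can be illustrated with a minimal NumPy sketch. It is not the PINA implementation: the helper names (`init_mlp`, `mlp`, `egnn_layer`), the concatenation order of the message inputs, the separate scalar network that weights the relative position `pos_i - pos_j`, and the use of "add" aggregation are all assumptions made for illustration. The last lines apply a random orthogonal transformation and a translation to the input positions, which should leave the output features unchanged and transform the output positions in the same way.

```python
import numpy as np


def silu(z):
    # SiLU activation, z * sigmoid(z), mirroring the documented default.
    return z / (1.0 + np.exp(-z))


def init_mlp(sizes, rng):
    # Hypothetical helper: random (weight, bias) pairs for a small MLP.
    return [(rng.normal(scale=0.5, size=(a, b)), np.zeros(b))
            for a, b in zip(sizes[:-1], sizes[1:])]


def mlp(params, z):
    # Apply the layers, with SiLU between them (not after the last one).
    for k, (w, b) in enumerate(params):
        z = z @ w + b
        if k < len(params) - 1:
            z = silu(z)
    return z


def egnn_layer(x, pos, edge_index, nets, edge_attr=None):
    # Sketch of one E(n)-equivariant layer, not the PINA implementation.
    src, dst = edge_index                      # j = sender, i = recipient
    diff = pos[dst] - pos[src]                 # pos_i - pos_j, shape (E, pos_dim)
    dist2 = np.sum(diff ** 2, axis=-1, keepdims=True)  # invariant input
    parts = [x[dst], x[src], dist2]            # assumed concatenation order
    if edge_attr is not None:
        parts.append(edge_attr)
    m = mlp(nets["message"], np.concatenate(parts, axis=-1))  # feature message
    pos_msg = diff * mlp(nets["pos"], m)       # relative vector times a scalar
    n = x.shape[0]
    agg_m = np.zeros((n, m.shape[1]))
    agg_pos = np.zeros_like(pos)
    np.add.at(agg_m, dst, m)                   # "add" aggregation
    np.add.at(agg_pos, dst, pos_msg)
    deg = np.bincount(dst, minlength=n)[:, None]  # in-degree of each recipient
    x_new = mlp(nets["update"], np.concatenate([x, agg_m], axis=-1))
    pos_new = pos + agg_pos / np.maximum(deg, 1)  # guard isolated nodes
    return x_new, pos_new


rng = np.random.default_rng(0)
feat_dim, pos_dim, hidden = 3, 2, 8
nets = {
    "message": init_mlp([2 * feat_dim + 1, hidden, hidden], rng),
    "pos": init_mlp([hidden, hidden, 1], rng),
    "update": init_mlp([feat_dim + hidden, hidden, feat_dim], rng),
}
x = rng.normal(size=(4, feat_dim))
pos = rng.normal(size=(4, pos_dim))
edge_index = np.array([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])  # undirected path

# Random orthogonal matrix and translation: an element of E(2).
q, _ = np.linalg.qr(rng.normal(size=(pos_dim, pos_dim)))
t = rng.normal(size=pos_dim)

x_out, pos_out = egnn_layer(x, pos, edge_index, nets)
x_rot, pos_rot = egnn_layer(x, pos @ q.T + t, edge_index, nets)
```

Because the message network only sees the squared distance, and positions enter the update only through relative vectors, the sketched layer commutes with rotations, reflections, and translations of the input positions.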
See also
Original reference: Satorras, V. G., Hoogeboom, E., Welling, M. (2021). E(n) Equivariant Graph Neural Networks. In International Conference on Machine Learning. DOI: https://doi.org/10.48550/arXiv.2102.09844.
Initialization of the EnEquivariantNetworkBlock class.
- Parameters:
node_feature_dim (int) – The dimension of the node features.
edge_feature_dim (int) – The dimension of the edge features.
pos_dim (int) – The dimension of the position features.
hidden_dim (int) – The dimension of the hidden features. Default is 64.
n_message_layers (int) – The number of layers in the message network. Default is 2.
n_update_layers (int) – The number of layers in the update network. Default is 2.
activation (torch.nn.Module) – The activation function. Default is torch.nn.SiLU.
aggr (str) – The aggregation scheme to use for message passing. Available options are “add”, “mean”, “min”, “max”, “mul”. See torch_geometric.nn.MessagePassing for more details. Default is “add”.
node_dim (int) – The axis along which to propagate. Default is -2.
flow (str) – The direction of message passing. Available options are “source_to_target” and “target_to_source”. The “source_to_target” flow means that messages are sent from the source node to the target node, while the “target_to_source” flow means that messages are sent from the target node to the source node. See torch_geometric.nn.MessagePassing for more details. Default is “source_to_target”.
- Raises:
AssertionError – If node_feature_dim is not a positive integer.
AssertionError – If edge_feature_dim is a negative integer.
AssertionError – If pos_dim is not a positive integer.
AssertionError – If hidden_dim is not a positive integer.
AssertionError – If n_message_layers is not a positive integer.
AssertionError – If n_update_layers is not a positive integer.
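The `aggr` and `flow` arguments mirror those of the torch_geometric.nn.MessagePassing base class. The NumPy sketch below uses two hypothetical helpers, `split_edges` and `scatter_aggregate` (not part of PINA or PyTorch Geometric), to illustrate what they control: which end of each edge receives the message, and how messages arriving at the same node are combined. Filling nodes that receive no message with zeros is an assumption of this sketch.

```python
import numpy as np


def split_edges(edge_index, flow="source_to_target"):
    # Returns (recipient, sender) indices for every edge.
    # "source_to_target": edge_index[0] sends, edge_index[1] receives.
    if flow == "source_to_target":
        return edge_index[1], edge_index[0]
    if flow == "target_to_source":
        return edge_index[0], edge_index[1]
    raise ValueError(f"unknown flow: {flow!r}")


# Scatter ufunc and its starting value for each aggregation scheme.
_REDUCERS = {
    "add": (np.add, 0.0),
    "mean": (np.add, 0.0),
    "mul": (np.multiply, 1.0),
    "min": (np.minimum, np.inf),
    "max": (np.maximum, -np.inf),
}


def scatter_aggregate(messages, index, num_nodes, aggr="add"):
    ufunc, start = _REDUCERS[aggr]
    out = np.full((num_nodes,) + messages.shape[1:], start)
    ufunc.at(out, index, messages)          # combine rows sharing a recipient
    count = np.bincount(index, minlength=num_nodes)
    if aggr == "mean":
        out = out / np.maximum(count, 1).reshape((-1,) + (1,) * (out.ndim - 1))
    out[count == 0] = 0.0                   # assumption: empty nodes get zeros
    return out


edge_index = np.array([[0, 0, 1], [1, 2, 2]])   # directed edges 0->1, 0->2, 1->2
messages = np.array([[1.0], [2.0], [3.0]])      # one message per edge

i, _ = split_edges(edge_index, "source_to_target")
summed = scatter_aggregate(messages, i, 3, "add")          # node 2 gets 2 + 3
i_rev, _ = split_edges(edge_index, "target_to_source")
summed_rev = scatter_aggregate(messages, i_rev, 3, "add")  # node 0 gets 1 + 2
```

Reversing the flow only swaps which row of `edge_index` is treated as the recipient; the aggregation itself is unchanged.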
- forward(x, pos, edge_index, edge_attr=None)
Forward pass of the block, triggering the message-passing routine.
- Parameters:
x (torch.Tensor | LabelTensor) – The node features.
pos (torch.Tensor | LabelTensor) – The Euclidean coordinates of the nodes.
edge_index (torch.Tensor) – The edge indices.
edge_attr (torch.Tensor | LabelTensor) – The edge attributes. Default is None.
- Returns:
The updated node features and node positions.
- Return type:
tuple(torch.Tensor, torch.Tensor)
- message(x_i, x_j, pos_i, pos_j, edge_attr)
Compute the message passed from the sender node to the recipient node along each edge.
- Parameters:
x_i (torch.Tensor | LabelTensor) – The node features of the recipient nodes.
x_j (torch.Tensor | LabelTensor) – The node features of the sender nodes.
pos_i (torch.Tensor | LabelTensor) – The node coordinates of the recipient nodes.
pos_j (torch.Tensor | LabelTensor) – The node coordinates of the sender nodes.
edge_attr (torch.Tensor | LabelTensor) – The edge attributes.
- Returns:
The messages to be passed, one for the position update and one for the feature update.
- Return type:
tuple(torch.Tensor, torch.Tensor)
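To make the `_i` / `_j` naming concrete: under the default “source_to_target” flow, PyTorch Geometric gathers `x_j` from `edge_index[0]` (senders) and `x_i` from `edge_index[1]` (recipients). The NumPy sketch below uses a hypothetical `message_inputs` helper to build the per-edge input of the message network from these gathered rows, the squared distance, and the optional edge attributes; the concatenation order is an assumption, not necessarily the one used by PINA. It also shows that this input is unchanged by a rotation plus translation of the positions, while the relative position vector rotates with them.

```python
import numpy as np


def message_inputs(x, pos, edge_index, edge_attr=None):
    # Gather sender (j) and recipient (i) rows for the source_to_target flow.
    x_j, x_i = x[edge_index[0]], x[edge_index[1]]
    pos_j, pos_i = pos[edge_index[0]], pos[edge_index[1]]
    diff = pos_i - pos_j                                 # rotates with the input
    dist2 = np.sum(diff ** 2, axis=-1, keepdims=True)    # E(n)-invariant
    parts = [x_i, x_j, dist2] + ([edge_attr] if edge_attr is not None else [])
    # Assumed order; the width is 2 * node_feature_dim + 1 + edge_feature_dim.
    return np.concatenate(parts, axis=-1), diff


rng = np.random.default_rng(1)
x = rng.normal(size=(5, 4))          # node_feature_dim = 4
pos = rng.normal(size=(5, 3))        # pos_dim = 3
edge_index = np.array([[0, 1, 2, 3], [1, 2, 3, 4]])
edge_attr = rng.normal(size=(4, 2))  # edge_feature_dim = 2
inp, diff = message_inputs(x, pos, edge_index, edge_attr)

# Rotate (or reflect) and translate the positions.
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
inp_rot, diff_rot = message_inputs(x, pos @ q.T + 5.0, edge_index, edge_attr)
```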
- aggregate(inputs, index, ptr=None, dim_size=None)
Aggregate the messages at the nodes during message passing.
This method receives a tuple of tensors corresponding to the messages to be aggregated. Both messages are aggregated separately according to the specified aggregation scheme.
- Parameters:
inputs (tuple(torch.Tensor)) – Tuple containing two messages to aggregate.
index (torch.Tensor | LabelTensor) – The indices of target nodes for each message. This tensor specifies which node each message is aggregated into.
ptr (torch.Tensor | LabelTensor) – Optional tensor to specify the slices of messages for each node (used in some aggregation strategies). Default is None.
dim_size (int) – Optional size of the output dimension, i.e., number of nodes. Default is None.
- Returns:
Tuple of aggregated tensors corresponding to (aggregated messages for position updates, aggregated messages for feature updates).
- Return type:
tuple(torch.Tensor, torch.Tensor)
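A minimal NumPy sketch of aggregating the two messages separately with a shared `index`. The helper `aggregate_pair` is hypothetical, supports only “add” and “mean”, and assumes the input tuple is ordered like the documented output (position messages first, feature messages second). The two messages may have different widths, as long as each has one row per edge.

```python
import numpy as np


def aggregate_pair(inputs, index, dim_size, aggr="add"):
    # Hypothetical helper: aggregate each tensor of the tuple on its own,
    # using the same recipient index for both.
    if aggr not in ("add", "mean"):
        raise ValueError(f"this sketch only supports 'add' and 'mean', got {aggr!r}")
    count = np.bincount(index, minlength=dim_size)

    def reduce(msg):
        out = np.zeros((dim_size, msg.shape[1]))
        np.add.at(out, index, msg)  # sum the rows that share a recipient
        if aggr == "mean":
            out /= np.maximum(count, 1)[:, None]
        return out

    pos_msg, feat_msg = inputs      # assumed order: (position, feature)
    return reduce(pos_msg), reduce(feat_msg)


index = np.array([1, 2, 2])                                  # recipient of each edge
pos_msg = np.array([[1.0, 0.0], [0.0, 2.0], [0.0, 4.0]])     # (E, pos_dim)
feat_msg = np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [4.0, 4.0, 4.0]])  # (E, hidden)
agg_pos, agg_feat = aggregate_pair((pos_msg, feat_msg), index, dim_size=3, aggr="mean")
```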
- update(aggregated_inputs, x, pos, edge_index)
Update the node features and the node coordinates with the received messages.
- Parameters:
aggregated_inputs (tuple(torch.Tensor)) – The aggregated messages, for the position update and for the feature update, as returned by aggregate.
x (torch.Tensor | LabelTensor) – The node features.
pos (torch.Tensor | LabelTensor) – The Euclidean coordinates of the nodes.
edge_index (torch.Tensor) – The edge indices.
- Returns:
The updated node features and node positions.
- Return type:
tuple(torch.Tensor, torch.Tensor)
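The position update can be sketched in NumPy as follows, using a hypothetical `update_positions` helper. It assumes the degree is counted over incoming edges at the recipient, i.e. occurrences in `edge_index[1]` under the default flow, and that isolated nodes keep their position. With “add” aggregation, dividing by the recipient degree turns the summed position messages into their mean. The feature update, an MLP over the concatenated node features and aggregated messages, is omitted here.

```python
import numpy as np


def update_positions(pos, agg_pos, edge_index):
    # Hypothetical helper. Assumption: degree = number of incoming edges at
    # the recipient, i.e. occurrences in edge_index[1] (source_to_target).
    deg = np.bincount(edge_index[1], minlength=pos.shape[0])[:, None]
    # Assumption: isolated nodes (degree 0) keep their position.
    return pos + agg_pos / np.maximum(deg, 1)


rng = np.random.default_rng(2)
pos = rng.normal(size=(3, 2))
edge_index = np.array([[0, 1, 2, 0], [1, 2, 1, 2]])  # node 0 has no incoming edge
pos_msg = rng.normal(size=(4, 2))                     # one position message per edge
agg_pos = np.zeros_like(pos)
np.add.at(agg_pos, edge_index[1], pos_msg)            # "add" aggregation
new_pos = update_positions(pos, agg_pos, edge_index)
```

Here node 1 receives the messages of edges 0 and 2, and node 2 those of edges 1 and 3, so each moves by the mean of its incoming position messages.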