E(n) Equivariant Network Block
- class EnEquivariantNetworkBlock(node_feature_dim, edge_feature_dim, pos_dim, hidden_dim=64, n_message_layers=2, n_update_layers=2, activation=<class 'torch.nn.modules.activation.SiLU'>, aggr='add', node_dim=-2, flow='source_to_target')
Bases: MessagePassing
Implementation of the E(n) Equivariant Graph Neural Network block. This block performs message passing over the nodes and edges of a graph, following the scheme proposed by Satorras et al. (2021). It is intended as an inner block of a larger graph neural network architecture.
The message sent along an edge is computed by applying a linear transformation to the sender node features and the edge features, together with the squared Euclidean distance between the sender and recipient node positions, followed by a non-linear activation function. Messages are then aggregated at each recipient node using the chosen aggregation scheme (sum, mean, min, max, or product).
The update step applies another MLP to the concatenation of the aggregated messages and the node features. The node positions are also updated, by adding the aggregated position messages divided by the degree of the recipient node.
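To make the scheme above concrete, here is a minimal sketch in plain PyTorch. It is a hypothetical re-implementation, not the library code: the class name `TinyEGNNLayer`, the MLP depths, feeding both endpoint features `h_i` and `h_j` into the message (as in the EGNN paper), the omission of edge attributes, and sum aggregation via `index_add_` are all assumptions. The last lines apply the layer to a graph and to a rotated and translated copy of it; the features should match, while the positions should transform together with the input.

```python
import torch
from torch import nn


class TinyEGNNLayer(nn.Module):
    """Hypothetical minimal E(n)-equivariant layer (not the library class)."""

    def __init__(self, node_dim: int, hidden_dim: int = 16):
        super().__init__()
        # Message MLP (assumption): (h_i, h_j, squared distance) -> hidden message m_ij
        self.message_mlp = nn.Sequential(
            nn.Linear(2 * node_dim + 1, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
        )
        # Maps each m_ij to one scalar weight for the position update
        self.pos_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, 1)
        )
        # Update MLP: (h_i, aggregated m_ij) -> new node features
        self.feat_mlp = nn.Sequential(
            nn.Linear(node_dim + hidden_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, node_dim),
        )

    def forward(self, h, pos, edge_index):
        src, dst = edge_index  # source_to_target: row 0 sends, row 1 receives
        diff = pos[dst] - pos[src]  # x_i - x_j: rotates with the input
        dist2 = diff.pow(2).sum(dim=-1, keepdim=True)  # E(n)-invariant
        m_ij = self.message_mlp(torch.cat([h[dst], h[src], dist2], dim=-1))
        pos_msg = diff * self.pos_mlp(m_ij)  # equivariant per-edge displacement

        n = h.shape[0]
        # Sum aggregation of both messages at the recipient nodes
        agg_m = torch.zeros(n, m_ij.shape[-1], dtype=m_ij.dtype).index_add_(0, dst, m_ij)
        agg_pos = torch.zeros_like(pos).index_add_(0, dst, pos_msg)
        # In-degree of each recipient, clamped so isolated nodes never divide by 0
        deg = torch.bincount(dst, minlength=n).clamp(min=1).unsqueeze(-1).to(pos.dtype)

        h_new = self.feat_mlp(torch.cat([h, agg_m], dim=-1))
        pos_new = pos + agg_pos / deg
        return h_new, pos_new


torch.manual_seed(0)
layer = TinyEGNNLayer(node_dim=4)
h, pos = torch.randn(5, 4), torch.randn(5, 3)
edge_index = torch.tensor([[0, 1, 2, 3, 4, 0], [1, 2, 3, 4, 0, 2]])
Q, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal matrix
t = torch.randn(3)  # random translation
h1, p1 = layer(h, pos, edge_index)
h2, p2 = layer(h, pos @ Q.T + t, edge_index)  # transformed input
```

Using `index_add_` keeps the sketch dependent on PyTorch alone; the library block instead builds on `torch_geometric.nn.MessagePassing`, which also provides the other aggregation schemes.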
See also
Original reference: Satorras, V. G., Hoogeboom, E., Welling, M. (2021). E(n) Equivariant Graph Neural Networks. In International Conference on Machine Learning. DOI: https://doi.org/10.48550/arXiv.2102.09844.
Initialization of the EnEquivariantNetworkBlock class.
- Parameters:
node_feature_dim (int) – The dimension of the node features.
edge_feature_dim (int) – The dimension of the edge features.
pos_dim (int) – The dimension of the position features.
hidden_dim (int) – The dimension of the hidden features. Default is 64.
n_message_layers (int) – The number of layers in the message network. Default is 2.
n_update_layers (int) – The number of layers in the update network. Default is 2.
activation (type[torch.nn.Module]) – The activation function class. Default is torch.nn.SiLU.
aggr (str) – The aggregation scheme to use for message passing. Available options are “add”, “mean”, “min”, “max”, “mul”. See torch_geometric.nn.MessagePassing for more details. Default is “add”.
node_dim (int) – The axis along which to propagate. Default is -2.
flow (str) – The direction of message passing. Available options are “source_to_target” and “target_to_source”. With “source_to_target”, messages are sent from the source node to the target node; with “target_to_source”, they are sent from the target node to the source node. See torch_geometric.nn.MessagePassing for more details. Default is “source_to_target”.
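The two `flow` options only swap which row of `edge_index` plays the role of the recipient `i` and which plays the sender `j`. The sketch below illustrates this with a hypothetical helper, `endpoint_roles`, following the convention of `torch_geometric.nn.MessagePassing` (row 0 holds sources, row 1 holds targets), and sums the sender features at each recipient.

```python
import torch


def endpoint_roles(edge_index: torch.Tensor, flow: str = "source_to_target"):
    """Hypothetical helper: return (i, j) = (recipient, sender) index vectors."""
    if flow == "source_to_target":
        return edge_index[1], edge_index[0]  # targets receive, sources send
    if flow == "target_to_source":
        return edge_index[0], edge_index[1]  # sources receive, targets send
    raise ValueError(f"unknown flow: {flow}")


edge_index = torch.tensor([[0, 0], [1, 2]])  # edges 0 -> 1 and 0 -> 2
x = torch.tensor([[1.0], [10.0], [100.0]])
i, j = endpoint_roles(edge_index, "source_to_target")
# Sum the sender features x_j at each recipient i
out = torch.zeros_like(x).index_add_(0, i, x[j])
```

With the default flow, node 0 sends its feature to nodes 1 and 2; with `"target_to_source"`, node 0 instead collects the features of nodes 1 and 2.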
- Raises:
AssertionError – If node_feature_dim is not a positive integer.
AssertionError – If edge_feature_dim is a negative integer.
AssertionError – If pos_dim is not a positive integer.
AssertionError – If hidden_dim is not a positive integer.
AssertionError – If n_message_layers is not a positive integer.
AssertionError – If n_update_layers is not a positive integer.
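The conditions listed above can be expressed as a few plain assertions. This is a hypothetical stand-alone helper written for illustration (the name `check_block_args` is not part of the library); it reproduces only the documented rules: `edge_feature_dim` must be a non-negative integer, while every other size and layer count must be a positive integer.

```python
def check_block_args(node_feature_dim, edge_feature_dim, pos_dim,
                     hidden_dim=64, n_message_layers=2, n_update_layers=2):
    """Hypothetical helper mirroring the documented AssertionError conditions."""
    # edge_feature_dim = 0 is allowed (presumably when no edge attributes are used)
    assert isinstance(edge_feature_dim, int) and edge_feature_dim >= 0, \
        "edge_feature_dim must be a non-negative integer"
    # Every other argument must be strictly positive
    for name, value in [
        ("node_feature_dim", node_feature_dim),
        ("pos_dim", pos_dim),
        ("hidden_dim", hidden_dim),
        ("n_message_layers", n_message_layers),
        ("n_update_layers", n_update_layers),
    ]:
        assert isinstance(value, int) and value > 0, f"{name} must be a positive integer"


check_block_args(node_feature_dim=3, edge_feature_dim=0, pos_dim=2)  # passes
```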
- forward(x, pos, edge_index, edge_attr=None)
Forward pass of the block, triggering the message-passing routine.
- Parameters:
x (torch.Tensor | LabelTensor) – The node features.
pos (torch.Tensor | LabelTensor) – The Euclidean coordinates of the nodes.
edge_index (torch.Tensor) – The edge indices.
edge_attr (torch.Tensor | LabelTensor) – The edge attributes. Default is None.
- Returns:
The updated node features and node positions.
- Return type:
tuple(torch.Tensor, torch.Tensor)
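`forward` expects the connectivity as `edge_index`, a `[2, num_edges]` integer tensor in the COO layout used by torch_geometric, where each column is one `(source, target)` edge. One common choice for small point clouds is a fully connected graph; a minimal sketch of building one without self-loops follows (the helper name `fully_connected_edge_index` is hypothetical).

```python
import torch


def fully_connected_edge_index(num_nodes: int) -> torch.Tensor:
    """Hypothetical helper: all ordered pairs (src, dst) with src != dst, shape [2, E]."""
    src, dst = torch.meshgrid(
        torch.arange(num_nodes), torch.arange(num_nodes), indexing="ij"
    )
    mask = src != dst  # drop self-loops
    return torch.stack([src[mask], dst[mask]], dim=0)


edge_index = fully_connected_edge_index(3)
# edges: (0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)
```

The result can be passed as the `edge_index` argument of `forward`, which returns the updated node features and node positions. For `n` nodes this graph has `n * (n - 1)` edges, so sparser graphs (for example radius graphs) are usually preferable for large inputs.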
- message(x_i, x_j, pos_i, pos_j, edge_attr)
Compute the messages to be passed along each edge, from the sender node to the recipient node.
- Parameters:
x_i (torch.Tensor | LabelTensor) – The node features of the recipient nodes.
x_j (torch.Tensor | LabelTensor) – The node features of the sender nodes.
pos_i (torch.Tensor | LabelTensor) – The node coordinates of the recipient nodes.
pos_j (torch.Tensor | LabelTensor) – The node coordinates of the sender nodes.
edge_attr (torch.Tensor | LabelTensor) – The edge attributes.
- Returns:
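The messages to be passed, as a tuple (message for position updates, message for feature updates).
- Return type:
tuple(torch.Tensor, torch.Tensor)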
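The per-edge inputs of `message` can be gathered from node-level tensors by indexing with `edge_index`. The sketch below is a hypothetical helper (`message_inputs`), assuming the default `"source_to_target"` flow; the order in which `x_i`, `x_j`, the squared distance, and the optional edge attributes are concatenated is an assumption, not necessarily the order the library uses. It also returns the difference vector `pos_i - pos_j`, which rotates with the input, whereas the squared distance is invariant under rotations and translations.

```python
import torch


def message_inputs(x, pos, edge_index, edge_attr=None):
    """Hypothetical helper: per-edge input of the message MLP (source_to_target flow)."""
    j, i = edge_index  # row 0: senders j, row 1: recipients i
    x_i, x_j = x[i], x[j]
    diff = pos[i] - pos[j]  # relative position, equivariant
    dist2 = diff.pow(2).sum(dim=-1, keepdim=True)  # squared Euclidean distance, invariant
    parts = [x_i, x_j, dist2]
    if edge_attr is not None:  # edge attributes are optional (default None)
        parts.append(edge_attr)
    return torch.cat(parts, dim=-1), diff


x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
pos = torch.tensor([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [0.0, 0.0, 1.0]])
edge_index = torch.tensor([[0, 2], [1, 1]])  # edges 0 -> 1 and 2 -> 1
edge_attr = torch.tensor([[7.0], [8.0]])
inp, diff = message_inputs(x, pos, edge_index, edge_attr)
# inp rows [x_i, x_j, dist2, edge_attr]: [3, 4, 1, 2, 25, 7] and [3, 4, 5, 6, 26, 8]
```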
- aggregate(inputs, index, ptr=None, dim_size=None)
Aggregate the messages at the nodes during message passing.
This method receives a tuple of two tensors, the messages to be aggregated, and aggregates each of them separately according to the specified aggregation scheme.
- Parameters:
inputs (tuple(torch.Tensor)) – Tuple containing two messages to aggregate.
index (torch.Tensor | LabelTensor) – The indices of target nodes for each message. This tensor specifies which node each message is aggregated into.
ptr (torch.Tensor | LabelTensor) – Optional tensor to specify the slices of messages for each node (used in some aggregation strategies). Default is None.
dim_size (int) – Optional size of the output dimension, i.e., number of nodes. Default is None.
- Returns:
Tuple of aggregated tensors corresponding to (aggregated messages for position updates, aggregated messages for feature updates).
- Return type:
tuple(torch.Tensor, torch.Tensor)
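Aggregating the two messages separately amounts to running the same reduction twice, once per tensor in the tuple. The sketch below is a hypothetical stand-alone helper built on `torch.Tensor.scatter_reduce`; the name `aggregate_pair` and the mapping from the `aggr` strings to `scatter_reduce` modes are assumptions for illustration, since torch_geometric handles this step internally.

```python
import torch

# Hypothetical mapping from the aggr names to torch.Tensor.scatter_reduce modes
REDUCE = {"add": "sum", "mean": "mean", "min": "amin", "max": "amax", "mul": "prod"}


def aggregate_pair(inputs, index, dim_size, aggr="add"):
    """Aggregate each message in the tuple separately at its target node."""
    out = []
    for msg in inputs:
        base = torch.zeros(dim_size, msg.shape[-1], dtype=msg.dtype)
        idx = index.unsqueeze(-1).expand_as(msg)  # one target index per entry
        # include_self=False: nodes without incoming messages keep the zero fill
        out.append(base.scatter_reduce(0, idx, msg, reduce=REDUCE[aggr], include_self=False))
    return tuple(out)


index = torch.tensor([1, 1, 2])  # target node of each of the 3 edges
pos_msg = torch.tensor([[1.0, 0.0], [3.0, 0.0], [5.0, 5.0]])
feat_msg = torch.tensor([[2.0], [4.0], [6.0]])
agg_pos, agg_feat = aggregate_pair((pos_msg, feat_msg), index, dim_size=3)
```

Swapping `aggr="add"` for `"mean"` or `"max"` changes only the reduction mode; both tensors are always reduced with the same scheme.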
- update(aggregated_inputs, x, pos, edge_index)
Update the node features and the node coordinates with the received messages.
- Parameters:
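aggregated_inputs (tuple(torch.Tensor)) – The aggregated messages returned by aggregate, as a tuple (aggregated messages for position updates, aggregated messages for feature updates).
x (torch.Tensor | LabelTensor) – The node features.
pos (torch.Tensor | LabelTensor) – The Euclidean coordinates of the nodes.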
edge_index (torch.Tensor) – The edge indices.
- Returns:
The updated node features and node positions.
- Return type:
tuple(torch.Tensor, torch.Tensor)
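The position update divides the aggregated position message of each node by that node's degree before adding it. A minimal sketch, assuming the default `"source_to_target"` flow (so the recipients are the entries of `edge_index[1]`) and clamping the degree to at least 1 so that nodes without incoming edges are left unchanged; the helper name `update_positions` is hypothetical.

```python
import torch


def update_positions(pos, agg_pos_msg, edge_index):
    """Hypothetical helper: pos + aggregated position message / in-degree."""
    num_nodes = pos.shape[0]
    # In-degree of each node, counted over the recipient row edge_index[1]
    deg = torch.bincount(edge_index[1], minlength=num_nodes).to(pos.dtype)
    deg = deg.clamp(min=1).unsqueeze(-1)  # avoid dividing by zero for isolated nodes
    return pos + agg_pos_msg / deg


pos = torch.zeros(3, 2)
edge_index = torch.tensor([[0, 2], [1, 1]])  # node 1 receives two messages
agg_pos_msg = torch.tensor([[0.0, 0.0], [4.0, 2.0], [0.0, 0.0]])
new_pos = update_positions(pos, agg_pos_msg, edge_index)
```

Normalizing by the degree keeps the size of the position step comparable across nodes with many or few neighbors, instead of letting densely connected nodes move much further.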