EquivariantGraphNeuralOperatorBlock
- class EquivariantGraphNeuralOperatorBlock(node_feature_dim, edge_feature_dim, pos_dim, modes, hidden_dim=64, n_message_layers=2, n_update_layers=2, activation=torch.nn.SiLU, aggr='add', node_dim=-2, flow='source_to_target')
Bases: Module
A single block of the Equivariant Graph Neural Operator (EGNO).
This block combines a temporal convolution with an equivariant graph neural network (EGNN) layer. It preserves equivariance while modeling complex interactions between nodes in a graph over time.
See also
Original reference: Xu, M., Han, J., Lou, A., Kossaifi, J., Ramanathan, A., Azizzadenesheli, K., Leskovec, J., Ermon, S., Anandkumar, A. (2024). Equivariant Graph Neural Operator for Modeling 3D Dynamics. arXiv preprint arXiv:2401.11037.
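The equivariance property mentioned above can be checked numerically. The following is a minimal NumPy sketch, not the block's actual implementation: `toy_position_update` is a hypothetical EGNN-style update in which each node moves along relative position vectors weighted by a function of the squared distance. Under these assumptions, rotating and translating the input before the update should give the same result as applying the same transformation after it.

```python
import numpy as np


def toy_position_update(pos, edge_index):
    """Hypothetical EGNN-style coordinate update (illustration only).

    pos: array of shape [num_nodes, pos_dim]
    edge_index: integer array of shape [2, num_edges] (row 0: source, row 1: target)
    """
    src, dst = edge_index
    rel = pos[dst] - pos[src]                        # relative vectors rotate with the input
    dist2 = np.sum(rel**2, axis=-1, keepdims=True)   # squared distances are invariant
    weight = 1.0 / (1.0 + dist2)                     # invariant scalar weight per edge
    out = pos.copy()
    np.add.at(out, dst, weight * rel)                # "add" aggregation onto target nodes
    return out


rng = np.random.default_rng(0)
pos = rng.normal(size=(5, 3))
edge_index = np.array([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]])

# Random orthogonal matrix (from a QR decomposition) and a random translation.
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
shift = rng.normal(size=(1, 3))

transformed_then_updated = toy_position_update(pos @ q.T + shift, edge_index)
updated_then_transformed = toy_position_update(pos, edge_index) @ q.T + shift
equivariance_gap = np.max(np.abs(transformed_then_updated - updated_then_transformed))
```

The gap is zero up to floating-point error because the update uses only relative vectors and distance-based scalars, which is the general design idea behind EGNN-style layers.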
Initialization of the EquivariantGraphNeuralOperatorBlock class.
- Parameters:
node_feature_dim (int) – The dimension of the node features.
edge_feature_dim (int) – The dimension of the edge features.
pos_dim (int) – The dimension of the position features.
modes (int) – The number of Fourier modes to use in the temporal convolution.
hidden_dim (int) – The dimension of the hidden features. Default is 64.
n_message_layers (int) – The number of layers in the message network. Default is 2.
n_update_layers (int) – The number of layers in the update network. Default is 2.
activation (torch.nn.Module) – The activation function. Default is torch.nn.SiLU.
aggr (str) – The aggregation scheme to use for message passing. Available options are “add”, “mean”, “min”, “max”, “mul”. See torch_geometric.nn.MessagePassing for more details. Default is “add”.
node_dim (int) – The axis along which to propagate. Default is -2.
flow (str) – The direction of message passing. Available options are “source_to_target” and “target_to_source”. The “source_to_target” flow means that messages are sent from the source node to the target node, while the “target_to_source” flow means that messages are sent from the target node to the source node. See torch_geometric.nn.MessagePassing for more details. Default is “source_to_target”.
- Raises:
AssertionError – If modes is not a positive integer.
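To illustrate what `modes` controls, here is a minimal NumPy sketch of a temporal convolution in Fourier space truncated to the lowest `modes` frequencies. It is an illustration under stated assumptions, not the block's code: `truncated_temporal_conv` and its `weights` argument are hypothetical names, the real block may parameterise its weights differently, and the upper bound on `modes` is a constraint of this sketch only.

```python
import numpy as np


def truncated_temporal_conv(x, weights, modes):
    """Hypothetical spectral convolution along the time axis.

    x: real array of shape [time_steps, num_nodes, dim]
    weights: complex array of shape [modes, dim], one multiplier per kept frequency
    """
    # Mirrors the documented requirement that modes is a positive integer.
    assert isinstance(modes, int) and modes > 0, "modes must be a positive integer"
    time_steps = x.shape[0]
    n_freq = time_steps // 2 + 1
    # Assumption of this sketch: cannot keep more modes than rfft produces.
    assert modes <= n_freq, "modes exceeds the number of available frequencies"

    x_hat = np.fft.rfft(x, axis=0)                  # [n_freq, num_nodes, dim]
    out_hat = np.zeros_like(x_hat)                  # discarded frequencies stay zero
    out_hat[:modes] = x_hat[:modes] * weights[:, None, :]
    return np.fft.irfft(out_hat, n=time_steps, axis=0)


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4, 3))                      # 8 time steps, 4 nodes, 3 features

# Keeping all 5 frequencies with unit weights reproduces the input.
full = truncated_temporal_conv(x, np.ones((5, 3), dtype=complex), modes=5)

# Keeping only the zero frequency leaves the temporal mean at every step.
dc_only = truncated_temporal_conv(x, np.ones((1, 3), dtype=complex), modes=1)
```

A smaller `modes` keeps only slow temporal variation, so it acts as a low-pass filter along the time axis.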
- forward(x, pos, vel, edge_index, edge_attr=None)
Forward pass of the Equivariant Graph Neural Operator block.
- Parameters:
x (torch.Tensor | LabelTensor) – The node feature tensor of shape [time_steps, num_nodes, node_feature_dim].
pos (torch.Tensor | LabelTensor) – The node position tensor (Euclidean coordinates) of shape [time_steps, num_nodes, pos_dim].
vel (torch.Tensor | LabelTensor) – The node velocity tensor of shape [time_steps, num_nodes, pos_dim].
edge_index (torch.Tensor) – The edge connectivity of shape [2, num_edges].
edge_attr (torch.Tensor | LabelTensor, optional) – The edge feature tensor of shape [time_steps, num_edges, edge_feature_dim]. Default is None.
- Returns:
The updated node features, positions, and velocities, each with the same shape as the inputs.
- Return type:
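The input layouts documented for `forward` can be made concrete with a small sketch. The snippet below only builds NumPy placeholder arrays with the documented shapes and checks that they are mutually consistent. `check_forward_shapes` is a hypothetical helper, not part of the library, and real inputs would be `torch.Tensor` or `LabelTensor` objects.

```python
import numpy as np


def check_forward_shapes(x, pos, vel, edge_index, edge_attr=None):
    """Hypothetical consistency check for the documented input shapes."""
    time_steps, num_nodes, _ = x.shape
    assert pos.shape[:2] == (time_steps, num_nodes), "pos must share time and node axes with x"
    assert vel.shape == pos.shape, "vel must have the same shape as pos"
    assert edge_index.shape[0] == 2, "edge_index must have shape [2, num_edges]"
    assert edge_index.min() >= 0 and edge_index.max() < num_nodes, "edge_index out of range"
    num_edges = edge_index.shape[1]
    if edge_attr is not None:
        assert edge_attr.shape[:2] == (time_steps, num_edges), "edge_attr must be per step, per edge"
    return {"time_steps": time_steps, "num_nodes": num_nodes, "num_edges": num_edges}


# Illustrative sizes (assumptions for this example only).
time_steps, num_nodes = 4, 3
node_feature_dim, edge_feature_dim, pos_dim = 8, 2, 3

x = np.zeros((time_steps, num_nodes, node_feature_dim))
pos = np.zeros((time_steps, num_nodes, pos_dim))
vel = np.zeros((time_steps, num_nodes, pos_dim))

# Fully connected graph without self-loops: row 0 holds sources, row 1 targets.
edge_index = np.array(
    [[i for i in range(num_nodes) for j in range(num_nodes) if i != j],
     [j for i in range(num_nodes) for j in range(num_nodes) if i != j]]
)
edge_attr = np.zeros((time_steps, edge_index.shape[1], edge_feature_dim))

summary = check_forward_shapes(x, pos, vel, edge_index, edge_attr)
```

Note that the time axis comes first and the node axis second, which matches the default `node_dim=-2`.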