EquivariantGraphNeuralOperator

class EquivariantGraphNeuralOperator(n_egno_layers, node_feature_dim, edge_feature_dim, pos_dim, modes, time_steps=2, hidden_dim=64, time_emb_dim=16, max_time_idx=10000, n_message_layers=2, n_update_layers=2, activation=<class 'torch.nn.modules.activation.SiLU'>, aggr='add', node_dim=-2, flow='source_to_target')

Bases: Module

Equivariant Graph Neural Operator (EGNO) for modeling 3D dynamics.

EGNO is a graph-based neural operator that preserves equivariance with respect to 3D transformations while modeling temporal and spatial interactions between nodes. It combines:

  1. Temporal convolution in the Fourier domain to capture long-range temporal dependencies efficiently.

  2. Equivariant Graph Neural Network (EGNN) layers to model interactions between nodes while respecting geometric symmetries.
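The Fourier-domain temporal convolution in item 1 can be pictured as a 1D spectral convolution along the time axis, in the style of Fourier neural operators: transform to frequency space, keep only the lowest `modes` frequencies, mix channels with one learned complex matrix per retained mode, and transform back. The NumPy sketch below is illustrative only and is not this class's implementation; the function name `temporal_spectral_conv`, the [time, nodes, channels] layout, and the weight shape are assumptions.

```python
import numpy as np


def temporal_spectral_conv(h, weights, modes):
    """Hypothetical sketch of a Fourier-domain convolution along the time axis.

    h:       real array, shape [T, N, C_in] (time steps, nodes, channels)
    weights: complex array, shape [modes, C_in, C_out], one matrix per kept mode
    modes:   number of low-frequency modes to keep (at most T // 2 + 1)
    """
    T = h.shape[0]
    # Real FFT along time: shape [T // 2 + 1, N, C_in]
    h_hat = np.fft.rfft(h, axis=0)
    if modes > h_hat.shape[0]:
        raise ValueError("modes must not exceed T // 2 + 1")
    c_out = weights.shape[-1]
    out_hat = np.zeros((h_hat.shape[0], h.shape[1], c_out), dtype=complex)
    # Mix channels separately for each retained frequency; higher modes stay zero
    out_hat[:modes] = np.einsum("tni,tio->tno", h_hat[:modes], weights)
    # Back to the time domain with the original number of steps
    return np.fft.irfft(out_hat, n=T, axis=0)
```

Because the weights act on whole frequency components, every output time step can depend on every input time step. In this sketch, with `time_steps=2` (the documented default) the real FFT yields only two frequencies, so `modes` could be at most 2.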

This design allows EGNO to learn complex spatiotemporal dynamics of physical systems, molecules, or particles while enforcing physically meaningful constraints.
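The geometric symmetry that item 2 refers to can be checked numerically. The sketch below uses a simplified EGNN-style layer (after Satorras et al., 2021, "E(n) Equivariant Graph Neural Networks"), not this class's layer; the helpers `init_mlp`, `mlp`, and `egnn_layer`, the random weights, and the omission of edge attributes are assumptions made for brevity. It rotates and translates the input, then confirms that node features are unchanged while positions and velocities transform along with the input.

```python
import numpy as np

rng = np.random.default_rng(0)


def init_mlp(d_in, d_hidden, d_out):
    """Random weights for a two-layer perceptron (hypothetical helper)."""
    return (rng.normal(size=(d_in, d_hidden)) * 0.3, np.zeros(d_hidden),
            rng.normal(size=(d_hidden, d_out)) * 0.3, np.zeros(d_out))


def mlp(params, z):
    """Two-layer perceptron with a SiLU hidden activation."""
    w1, b1, w2, b2 = params
    a = z @ w1 + b1
    a = a / (1.0 + np.exp(-a))  # SiLU(a) = a * sigmoid(a)
    return a @ w2 + b2


def egnn_layer(h, x, v, edge_index, params):
    """One simplified EGNN-style update with velocities (hypothetical sketch).

    h: [N, F] invariant features; x, v: [N, 3] positions and velocities;
    edge_index: [2, E], row 0 = source j, row 1 = target i.
    """
    src, dst = edge_index
    diff = x[dst] - x[src]                              # relative positions: rotate with the input
    dist2 = np.sum(diff ** 2, axis=-1, keepdims=True)   # squared distances: invariant
    m = mlp(params["msg"], np.concatenate([h[dst], h[src], dist2], axis=-1))
    # Sum ("add") messages and coordinate updates at the target nodes
    agg_m = np.zeros((h.shape[0], m.shape[1]))
    np.add.at(agg_m, dst, m)
    agg_x = np.zeros_like(x)
    np.add.at(agg_x, dst, diff * mlp(params["coord"], m))
    # Velocities are scaled by an invariant gate, then corrected by the aggregated vectors
    v_new = mlp(params["vel"], h) * v + agg_x
    x_new = x + v_new
    h_new = h + mlp(params["node"], np.concatenate([h, agg_m], axis=-1))
    return h_new, x_new, v_new


N, F, H = 5, 4, 8
params = {
    "msg": init_mlp(2 * F + 1, H, H),
    "coord": init_mlp(H, H, 1),
    "vel": init_mlp(F, H, 1),
    "node": init_mlp(F + H, H, F),
}
src, dst = np.nonzero(~np.eye(N, dtype=bool))  # fully connected graph, no self-loops
edge_index = np.stack([src, dst])
h, x, v = rng.normal(size=(N, F)), rng.normal(size=(N, 3)), rng.normal(size=(N, 3))

# Random orthogonal matrix R (a rotation, possibly with a reflection) and a translation t
R, _ = np.linalg.qr(rng.normal(size=(3, 3)))
t = rng.normal(size=3)

h1, x1, v1 = egnn_layer(h, x, v, edge_index, params)
h2, x2, v2 = egnn_layer(h, x @ R.T + t, v @ R.T, edge_index, params)
print(np.allclose(h1, h2), np.allclose(x1 @ R.T + t, x2), np.allclose(v1 @ R.T, v2))
# expected: True True True
```

The property follows from the construction: messages depend only on invariant quantities (features and squared distances), and every vector update is a sum of relative positions or velocities scaled by invariant scalars.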

See also

Original reference: Xu, M., Han, J., Lou, A., Kossaifi, J., Ramanathan, A., Azizzadenesheli, K., Leskovec, J., Ermon, S., Anandkumar, A. (2024). Equivariant Graph Neural Operator for Modeling 3D Dynamics. arXiv preprint arXiv:2401.11037.

Initialization of the EquivariantGraphNeuralOperator class.

Parameters:
  • n_egno_layers (int) – The number of EGNO layers.

  • node_feature_dim (int) – The dimension of the node features in each EGNO layer.

  • edge_feature_dim (int) – The dimension of the edge features in each EGNO layer.

  • pos_dim (int) – The dimension of the position features in each EGNO layer.

  • modes (int) – The number of Fourier modes to use in the temporal convolution.

  • time_steps (int) – The number of time steps to consider in the temporal convolution. Default is 2.

  • hidden_dim (int) – The dimension of the hidden features in each EGNO layer. Default is 64.

  • time_emb_dim (int) – The dimension of the sinusoidal time embeddings. Default is 16.

  • max_time_idx (int) – The maximum time index for the sinusoidal embeddings. Default is 10000.

  • n_message_layers (int) – The number of layers in the message network of each EGNO layer. Default is 2.

  • n_update_layers (int) – The number of layers in the update network of each EGNO layer. Default is 2.

  • activation (torch.nn.Module) – The activation function. Default is torch.nn.SiLU.

  • aggr (str) – The aggregation scheme to use for message passing. Available options are “add”, “mean”, “min”, “max”, “mul”. See torch_geometric.nn.MessagePassing for more details. Default is “add”.

  • node_dim (int) – The axis along which to propagate. Default is -2.

  • flow (str) – The direction of message passing. Available options are “source_to_target” and “target_to_source”. With “source_to_target”, messages are sent from source nodes to target nodes; with “target_to_source”, they are sent from target nodes to source nodes. See torch_geometric.nn.MessagePassing for more details. Default is “source_to_target”.
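The aggr and flow options refer to torch_geometric.nn.MessagePassing. As a rough NumPy sketch of those semantics, per-edge messages are reduced either onto target nodes (edge_index[1]) for “source_to_target” or onto source nodes (edge_index[0]) for “target_to_source”. The function `aggregate_messages` is hypothetical and is not a PyG method, and filling nodes that receive no messages with zeros is an assumption of this sketch.

```python
import numpy as np


def aggregate_messages(messages, edge_index, num_nodes, aggr="add", flow="source_to_target"):
    """Hypothetical sketch: reduce per-edge messages onto nodes.

    messages:   [E, F] array, one message per edge
    edge_index: [2, E] integer array; row 0 holds sources, row 1 holds targets
    """
    if flow == "source_to_target":
        index = edge_index[1]  # messages travel j -> i and are collected at the target
    elif flow == "target_to_source":
        index = edge_index[0]  # messages are collected at the source instead
    else:
        raise ValueError(f"unknown flow: {flow}")

    out_shape = (num_nodes, messages.shape[1])
    count = np.bincount(index, minlength=num_nodes)[:, None]
    if aggr in ("add", "mean"):
        out = np.zeros(out_shape)
        np.add.at(out, index, messages)
        if aggr == "mean":
            out = out / np.maximum(count, 1)
    elif aggr == "max":
        out = np.full(out_shape, -np.inf)
        np.maximum.at(out, index, messages)
    elif aggr == "min":
        out = np.full(out_shape, np.inf)
        np.minimum.at(out, index, messages)
    elif aggr == "mul":
        out = np.ones(out_shape)
        np.multiply.at(out, index, messages)
    else:
        raise ValueError(f"unknown aggr: {aggr}")
    # Assumed convention: nodes without incoming messages get zeros
    return np.where(count > 0, out, 0.0)


# Three edges: 0 -> 2, 1 -> 2, 2 -> 0
edge_index = np.array([[0, 1, 2], [2, 2, 0]])
messages = np.array([[1.0], [3.0], [5.0]])
print(aggregate_messages(messages, edge_index, 3).ravel())  # node 2 sums 1 + 3
```

With “add”, node 2 receives 1 + 3 = 4 and node 0 receives 5; switching to “target_to_source” sends each message back to the node that emitted it.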

Raises:

forward(graph)

Forward pass of the EquivariantGraphNeuralOperator class.

Parameters:

graph (Data | Graph) – The input graph object, with the following attributes:

    • ‘x’: Node features, shape [num_nodes, node_feature_dim].
    • ‘pos’: Node positions, shape [num_nodes, pos_dim].
    • ‘vel’: Node velocities, shape [num_nodes, pos_dim].
    • ‘edge_index’: Graph connectivity, shape [2, num_edges].
    • ‘edge_attr’: Edge attributes, shape [num_edges, edge_feature_dim].

Returns:

The output graph object with updated node features, positions, and velocities. A leading time dimension is added to ‘x’, ‘pos’, ‘vel’, and ‘edge_attr’, giving the shapes:

    • ‘x’: [time_steps, num_nodes, node_feature_dim]
    • ‘pos’: [time_steps, num_nodes, pos_dim]
    • ‘vel’: [time_steps, num_nodes, pos_dim]
    • ‘edge_attr’: [time_steps, num_edges, edge_feature_dim]

Return type:

Data | Graph

Raises:

ValueError – If the input graph does not have a ‘vel’ attribute.
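One way to picture where the leading time dimension comes from: a single input snapshot is replicated across time_steps, and each copy is tagged with a sinusoidal embedding of its time index, which is the role suggested by time_emb_dim and max_time_idx. The NumPy sketch below is a hypothetical illustration, not this class's code. The helper names, the dict-based graph, the Transformer-style embedding formula, and concatenating the embedding onto the features are all assumptions; since the documented output keeps node_feature_dim, the real model presumably maps the features back to that width. The sketch also mirrors the documented ValueError for a missing ‘vel’ attribute.

```python
import numpy as np


def sinusoidal_time_embedding(t, dim, max_time_idx=10000):
    """Transformer-style sinusoidal embedding of integer time indices (assumed formula).

    t: [T] array of time indices; returns [T, dim] (dim assumed even).
    """
    half = dim // 2
    freqs = np.exp(-np.log(max_time_idx) * np.arange(half) / half)
    angles = t[:, None] * freqs[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)


def lift_to_time(graph, time_steps, time_emb_dim, max_time_idx=10000):
    """Hypothetical helper: replicate one snapshot across `time_steps`
    and append a time embedding to the node features.

    graph: dict with 'x' [N, F], 'pos' [N, D], 'vel' [N, D]
    """
    if "vel" not in graph:
        raise ValueError("input graph must have a 'vel' attribute")
    n = graph["x"].shape[0]
    emb = sinusoidal_time_embedding(np.arange(time_steps), time_emb_dim, max_time_idx)
    # [T, N, F] copies of the features, plus [T, N, time_emb_dim] embeddings (assumed concat)
    x = np.repeat(graph["x"][None], time_steps, axis=0)
    emb_b = np.broadcast_to(emb[:, None, :], (time_steps, n, time_emb_dim))
    x = np.concatenate([x, emb_b], axis=-1)
    pos = np.repeat(graph["pos"][None], time_steps, axis=0)
    vel = np.repeat(graph["vel"][None], time_steps, axis=0)
    return x, pos, vel


g = {"x": np.zeros((4, 3)), "pos": np.ones((4, 3)), "vel": np.zeros((4, 3))}
x, pos, vel = lift_to_time(g, time_steps=2, time_emb_dim=16)
print(x.shape, pos.shape, vel.shape)
```

Each time step thus starts from the same state but carries a distinct embedding, which lets the subsequent temporal convolution produce a different prediction per step.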