EquivariantGraphNeuralOperator
- class EquivariantGraphNeuralOperator(n_egno_layers, node_feature_dim, edge_feature_dim, pos_dim, modes, time_steps=2, hidden_dim=64, time_emb_dim=16, max_time_idx=10000, n_message_layers=2, n_update_layers=2, activation=<class 'torch.nn.modules.activation.SiLU'>, aggr='add', node_dim=-2, flow='source_to_target')
Bases: Module
Equivariant Graph Neural Operator (EGNO) for modeling 3D dynamics.
EGNO is a graph-based neural operator that preserves equivariance with respect to 3D rotations and translations while modeling temporal and spatial interactions between nodes. It combines:
- Temporal convolution in the Fourier domain, to capture long-range temporal dependencies efficiently.
- Equivariant Graph Neural Network (EGNN) layers, to model interactions between nodes while respecting geometric symmetries.
This design allows EGNO to learn complex spatiotemporal dynamics of physical systems, molecules, or particles while enforcing physically meaningful constraints.
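The Fourier-domain temporal convolution can be illustrated with a minimal single-channel sketch in plain Python. It assumes an FNO-style spectral layer: transform a time series to the frequency domain, scale only the lowest `modes` coefficients by learnable weights, discard the rest, and transform back. The helpers `rfft`, `irfft`, and `spectral_time_conv` are hypothetical and use a naive O(T²) DFT; the actual layer operates on batched tensors and mixes feature channels with complex weight matrices, so this is not the library's implementation.

```python
import cmath
import math


def rfft(x):
    """Naive real-input DFT: coefficients for frequencies 0..T//2."""
    T = len(x)
    return [
        sum(x[t] * cmath.exp(-2j * math.pi * k * t / T) for t in range(T))
        for k in range(T // 2 + 1)
    ]


def irfft(X, T):
    """Inverse of rfft: rebuild the full Hermitian spectrum, return a real signal."""
    full = [0j] * T
    for k, coeff in enumerate(X):
        full[k] = coeff
        if 0 < k and T - k != k:
            full[T - k] = coeff.conjugate()  # negative frequency = conjugate
    return [
        (sum(full[k] * cmath.exp(2j * math.pi * k * t / T) for k in range(T)) / T).real
        for t in range(T)
    ]


def spectral_time_conv(x, weights, modes):
    """Filter a 1D time series by scaling its lowest `modes` Fourier coefficients.

    Higher frequencies are zeroed, so the layer acts as a learnable low-pass
    filter whose output at each step depends on the whole time window.
    """
    X = rfft(x)
    kept = min(modes, len(X))  # cannot keep more modes than exist
    out = [0j] * len(X)
    for k in range(kept):
        out[k] = X[k] * weights[k]
    return irfft(out, len(x))
```

With unit weights on every available mode the layer reduces to the identity, and keeping only mode 0 returns the time average. Because every retained coefficient depends on all time steps, a single layer couples the entire time window at once.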
See also
Original reference: Xu, M., Han, J., Lou, A., Kossaifi, J., Ramanathan, A., Azizzadenesheli, K., Leskovec, J., Ermon, S., Anandkumar, A. (2024). Equivariant Graph Neural Operator for Modeling 3D Dynamics. arXiv preprint arXiv:2401.11037.
Initialization of the EquivariantGraphNeuralOperator class.
- Parameters:
n_egno_layers (int) – The number of EGNO layers.
node_feature_dim (int) – The dimension of the node features in each EGNO layer.
edge_feature_dim (int) – The dimension of the edge features in each EGNO layer.
pos_dim (int) – The spatial dimension of the node positions and velocities (e.g., 3 for 3D systems).
modes (int) – The number of Fourier modes to use in the temporal convolution.
time_steps (int) – The number of time steps in the temporal convolution, which is also the size of the leading time dimension of the output. Default is 2.
hidden_dim (int) – The dimension of the hidden features in each EGNO layer. Default is 64.
time_emb_dim (int) – The dimension of the sinusoidal time embeddings. Default is 16.
max_time_idx (int) – The maximum time index for the sinusoidal embeddings. Default is 10000.
n_message_layers (int) – The number of layers in the message network of each EGNO layer. Default is 2.
n_update_layers (int) – The number of layers in the update network of each EGNO layer. Default is 2.
activation (torch.nn.Module) – The activation function. Default is torch.nn.SiLU.
aggr (str) – The aggregation scheme to use for message passing. Available options are “add”, “mean”, “min”, “max”, “mul”. See torch_geometric.nn.MessagePassing for more details. Default is “add”.
node_dim (int) – The axis along which to propagate. Default is -2.
flow (str) – The direction of message passing. Available options are “source_to_target” (messages are sent from source nodes to target nodes) and “target_to_source” (messages are sent from target nodes to source nodes). See torch_geometric.nn.MessagePassing for more details. Default is “source_to_target”.
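The `aggr` and `flow` options follow the conventions of `torch_geometric.nn.MessagePassing`. To show how they interact with an equivariant update, here is a dependency-free sketch of one EGNN-style coordinate step. It is an illustration under stated assumptions, not the library's layer: `phi` is a hypothetical fixed function standing in for the learned message and update networks (whose depth `n_message_layers` and `n_update_layers` control), node features and `edge_attr` are ignored, and only the "add" and "mean" aggregations are implemented.

```python
import math


def phi(dist_sq):
    # Hypothetical stand-in for the learned networks: a fixed scalar function
    # of the squared distance, which is rotation- and translation-invariant.
    return 1.0 / (1.0 + dist_sq)


def egnn_position_update(pos, edge_index, aggr="add", flow="source_to_target"):
    """One EGNN-style step: x_i <- x_i + AGG_j (x_i - x_j) * phi(|x_i - x_j|^2)."""
    if aggr not in ("add", "mean"):
        raise ValueError("this sketch only supports aggr='add' or 'mean'")
    row0, row1 = edge_index
    # PyG convention: "source_to_target" sends messages edge_index[0] -> edge_index[1].
    if flow == "source_to_target":
        senders, receivers = row0, row1
    elif flow == "target_to_source":
        senders, receivers = row1, row0
    else:
        raise ValueError(f"unknown flow: {flow!r}")

    dim = len(pos[0])
    acc = [[0.0] * dim for _ in pos]
    count = [0] * len(pos)
    for j, i in zip(senders, receivers):  # message from node j to node i
        diff = [a - b for a, b in zip(pos[i], pos[j])]
        w = phi(sum(d * d for d in diff))
        for c in range(dim):
            acc[i][c] += diff[c] * w
        count[i] += 1

    if aggr == "mean":
        acc = [[v / max(n, 1) for v in row] for row, n in zip(acc, count)]
    # Relative vectors scaled by invariant weights keep the update equivariant.
    return [[p + a for p, a in zip(p_i, a_i)] for p_i, a_i in zip(pos, acc)]
```

Because each update is a sum of relative vectors x_i - x_j scaled by distance-only weights, rotating or translating every input position rotates or translates the output in the same way, which is the symmetry property the EGNN layers are meant to preserve.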
- Raises:
AssertionError – If n_egno_layers is not a positive integer.
AssertionError – If time_emb_dim is not a positive integer.
AssertionError – If max_time_idx is not a positive integer.
AssertionError – If time_steps is not a positive integer.
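The `time_emb_dim` and `max_time_idx` parameters describe a sinusoidal time embedding. The sketch below assumes the common Transformer-style formula, with geometrically spaced frequencies and concatenated sine and cosine halves; the helper `sinusoidal_time_embedding` is hypothetical, and the library's exact ordering of terms may differ.

```python
import math


def sinusoidal_time_embedding(t, time_emb_dim=16, max_time_idx=10000):
    """Transformer-style embedding of an integer time index t.

    Frequencies decay geometrically from 1 toward 1/max_time_idx, so nearby
    time steps get similar vectors while distant ones remain distinguishable.
    """
    half = time_emb_dim // 2
    freqs = [math.exp(-math.log(max_time_idx) * k / half) for k in range(half)]
    emb = [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]
    if time_emb_dim % 2:  # odd size: pad so the length matches exactly
        emb.append(0.0)
    return emb
```

At t = 0 the sine half is all zeros and the cosine half is all ones; larger `max_time_idx` stretches the slowest frequency so that longer time horizons stay distinguishable.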
- forward(graph)
Forward pass of the EquivariantGraphNeuralOperator class.
- Parameters:
graph (Data | Graph) – The input graph object with the following attributes:
  - ‘x’: Node features, shape [num_nodes, node_feature_dim].
  - ‘pos’: Node positions, shape [num_nodes, pos_dim].
  - ‘vel’: Node velocities, shape [num_nodes, pos_dim].
  - ‘edge_index’: Graph connectivity, shape [2, num_edges].
  - ‘edge_attr’: Edge attributes, shape [num_edges, edge_feature_dim].
- Returns:
The output graph object with updated node features, positions, and velocities. A leading time dimension of size time_steps is added to ‘x’, ‘pos’, ‘vel’, and ‘edge_attr’, giving the shapes:
  - ‘x’: [time_steps, num_nodes, node_feature_dim]
  - ‘pos’: [time_steps, num_nodes, pos_dim]
  - ‘vel’: [time_steps, num_nodes, pos_dim]
  - ‘edge_attr’: [time_steps, num_edges, edge_feature_dim]
- Return type:
Data | Graph
- Raises:
ValueError – If the input graph does not have a ‘vel’ attribute.
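The input and output shapes documented for forward() can be summarized in a small dependency-free sketch. The helper `expected_output_shapes` is hypothetical: it represents each tensor attribute by its shape tuple (via `types.SimpleNamespace`) instead of calling the model, and it only reproduces the documented ‘vel’ check and the added leading time dimension.

```python
from types import SimpleNamespace


def expected_output_shapes(graph, time_steps):
    """Mirror the documented forward() contract using shape tuples only.

    `graph` is any object exposing x, pos, vel, edge_index and edge_attr as
    shape tuples, a stand-in for real tensors.
    """
    if not hasattr(graph, "vel"):
        raise ValueError("The input graph must have a 'vel' attribute.")
    num_nodes, node_feature_dim = graph.x
    _, pos_dim = graph.pos
    num_edges = graph.edge_index[1]
    _, edge_feature_dim = graph.edge_attr
    # Every returned field gains a leading time dimension of size time_steps.
    return {
        "x": (time_steps, num_nodes, node_feature_dim),
        "pos": (time_steps, num_nodes, pos_dim),
        "vel": (time_steps, num_nodes, pos_dim),
        "edge_attr": (time_steps, num_edges, edge_feature_dim),
    }


graph = SimpleNamespace(x=(5, 8), pos=(5, 3), vel=(5, 3), edge_index=(2, 12), edge_attr=(12, 4))
print(expected_output_shapes(graph, time_steps=10))
```

For a graph with 5 nodes, 12 edges, 8 node features, 3D positions and 4 edge features, the sketch predicts node outputs of shape (10, 5, 8) and edge outputs of shape (10, 12, 4) when time_steps is 10.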