Source code for pina.model.equivariant_graph_neural_operator

"""Module for the Equivariant Graph Neural Operator model."""

import torch
from ..utils import check_positive_integer
from .block.message_passing import EquivariantGraphNeuralOperatorBlock


class EquivariantGraphNeuralOperator(torch.nn.Module):
    """
    Equivariant Graph Neural Operator (EGNO) for modeling 3D dynamics.

    EGNO is a graph-based neural operator that preserves equivariance with
    respect to 3D transformations while modeling temporal and spatial
    interactions between nodes. It combines:

    1. Temporal convolution in the Fourier domain to capture long-range
       temporal dependencies efficiently.
    2. Equivariant Graph Neural Network (EGNN) layers to model interactions
       between nodes while respecting geometric symmetries.

    This design allows EGNO to learn complex spatiotemporal dynamics of
    physical systems, molecules, or particles while enforcing physically
    meaningful constraints.

    .. seealso::

        **Original reference**
        Xu, M., Han, J., Lou, A., Kossaifi, J., Ramanathan, A.,
        Azizzadenesheli, K., Leskovec, J., Ermon, S., Anandkumar, A. (2024).
        *Equivariant Graph Neural Operator for Modeling 3D Dynamics*.
        arXiv preprint arXiv:2401.11037.
        DOI: `10.48550/arXiv.2401.11037 <https://arxiv.org/abs/2401.11037>`_
    """

    def __init__(
        self,
        n_egno_layers,
        node_feature_dim,
        edge_feature_dim,
        pos_dim,
        modes,
        time_steps=2,
        hidden_dim=64,
        time_emb_dim=16,
        max_time_idx=10000,
        n_message_layers=2,
        n_update_layers=2,
        activation=torch.nn.SiLU,
        aggr="add",
        node_dim=-2,
        flow="source_to_target",
    ):
        """
        Initialization of the :class:`EquivariantGraphNeuralOperator` class.

        :param int n_egno_layers: The number of EGNO layers.
        :param int node_feature_dim: The dimension of the node features in
            each EGNO layer.
        :param int edge_feature_dim: The dimension of the edge features in
            each EGNO layer.
        :param int pos_dim: The dimension of the position features in each
            EGNO layer.
        :param int modes: The number of Fourier modes to use in the temporal
            convolution.
        :param int time_steps: The number of time steps to consider in the
            temporal convolution. Default is 2.
        :param int hidden_dim: The dimension of the hidden features in each
            EGNO layer. Default is 64.
        :param int time_emb_dim: The dimension of the sinusoidal time
            embeddings. Default is 16.
        :param int max_time_idx: The maximum time index for the sinusoidal
            embeddings. Default is 10000.
        :param int n_message_layers: The number of layers in the message
            network of each EGNO layer. Default is 2.
        :param int n_update_layers: The number of layers in the update
            network of each EGNO layer. Default is 2.
        :param torch.nn.Module activation: The activation function.
            Default is :class:`torch.nn.SiLU`.
        :param str aggr: The aggregation scheme to use for message passing.
            Available options are "add", "mean", "min", "max", "mul".
            See :class:`torch_geometric.nn.MessagePassing` for more details.
            Default is "add".
        :param int node_dim: The axis along which to propagate.
            Default is -2.
        :param str flow: The direction of message passing. Available options
            are "source_to_target" and "target_to_source". The
            "source_to_target" flow means that messages are sent from the
            source node to the target node, while the "target_to_source"
            flow means that messages are sent from the target node to the
            source node. See :class:`torch_geometric.nn.MessagePassing` for
            more details. Default is "source_to_target".
        :raises AssertionError: If ``n_egno_layers`` is not a positive
            integer.
        :raises AssertionError: If ``time_emb_dim`` is not a positive
            integer.
        :raises AssertionError: If ``max_time_idx`` is not a positive
            integer.
        :raises AssertionError: If ``time_steps`` is not a positive integer.
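
        :Example:
            A minimal construction sketch; the layer sizes below are
            illustrative, not prescribed by the model.

            >>> from pina.model.equivariant_graph_neural_operator import (
            ...     EquivariantGraphNeuralOperator,
            ... )
            >>> model = EquivariantGraphNeuralOperator(
            ...     n_egno_layers=2,
            ...     node_feature_dim=8,
            ...     edge_feature_dim=4,
            ...     pos_dim=3,
            ...     modes=2,
            ...     time_steps=4,
            ... )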
""" super().__init__() # Check consistency check_positive_integer(n_egno_layers, strict=True) check_positive_integer(time_emb_dim, strict=True) check_positive_integer(max_time_idx, strict=True) check_positive_integer(time_steps, strict=True) # Initialize parameters self.time_steps = time_steps self.time_emb_dim = time_emb_dim self.max_time_idx = max_time_idx # Initialize EGNO layers self.egno_layers = torch.nn.ModuleList() for _ in range(n_egno_layers): self.egno_layers.append( EquivariantGraphNeuralOperatorBlock( node_feature_dim=node_feature_dim, edge_feature_dim=edge_feature_dim, pos_dim=pos_dim, modes=modes, hidden_dim=hidden_dim, n_message_layers=n_message_layers, n_update_layers=n_update_layers, activation=activation, aggr=aggr, node_dim=node_dim, flow=flow, ) ) # Linear layer to adjust the scalar feature dimension self.linear = torch.nn.Linear( node_feature_dim + time_emb_dim, node_feature_dim )
    def forward(self, graph):
        """
        Forward pass of the :class:`EquivariantGraphNeuralOperator` class.

        :param graph: The input graph object with the following attributes:

            - 'x': Node features, shape ``[num_nodes, node_feature_dim]``.
            - 'pos': Node positions, shape ``[num_nodes, pos_dim]``.
            - 'vel': Node velocities, shape ``[num_nodes, pos_dim]``.
            - 'edge_index': Graph connectivity, shape ``[2, num_edges]``.
            - 'edge_attr': Edge attributes, shape
              ``[num_edges, edge_feature_dim]``.

        :type graph: Data | Graph
        :return: The output graph object with updated node features,
            positions, and velocities. The output graph adds a leading time
            dimension to 'x', 'pos', 'vel', and 'edge_attr', resulting in
            shapes:

            - 'x': ``[time_steps, num_nodes, node_feature_dim]``
            - 'pos': ``[time_steps, num_nodes, pos_dim]``
            - 'vel': ``[time_steps, num_nodes, pos_dim]``
            - 'edge_attr': ``[time_steps, num_edges, edge_feature_dim]``

        :rtype: Data | Graph
        :raises ValueError: If the input graph does not have a 'vel'
            attribute.
        """
        # Check that the graph has the required attributes
        if "vel" not in graph:
            raise ValueError("The input graph must have a 'vel' attribute.")

        # Compute the temporal embedding
        emb = self._embedding(torch.arange(self.time_steps)).to(graph.x.device)
        emb = emb.unsqueeze(1).repeat(1, graph.x.shape[0], 1)

        # Expand dimensions
        x = graph.x.unsqueeze(0).repeat(self.time_steps, 1, 1)
        x = self.linear(torch.cat((x, emb), dim=-1))
        pos = graph.pos.unsqueeze(0).repeat(self.time_steps, 1, 1)
        vel = graph.vel.unsqueeze(0).repeat(self.time_steps, 1, 1)

        # Manage edge index: replicate the connectivity once per time step,
        # shifting node indices by t * num_nodes so the time-stacked copies
        # form one large block-diagonal graph. E.g., with num_nodes=3 and
        # edge (0, 2), the copy at t=1 becomes edge (3, 5).
        offset = torch.arange(self.time_steps).reshape(-1, 1)
        offset = offset.to(graph.x.device) * graph.x.shape[0]
        src = graph.edge_index[0].unsqueeze(0) + offset
        dst = graph.edge_index[1].unsqueeze(0) + offset
        edge_index = torch.stack([src, dst], dim=0).reshape(2, -1)

        # Manage edge attributes
        if graph.edge_attr is not None:
            edge_attr = graph.edge_attr.unsqueeze(0)
            edge_attr = edge_attr.repeat(self.time_steps, 1, 1)
        else:
            edge_attr = None

        # Iteratively apply EGNO layers
        for layer in self.egno_layers:
            x, pos, vel = layer(
                x=x,
                pos=pos,
                vel=vel,
                edge_index=edge_index,
                edge_attr=edge_attr,
            )

        # Build new graph
        new_graph = graph.clone()
        new_graph.x, new_graph.pos, new_graph.vel = x, pos, vel
        if edge_attr is not None:
            new_graph.edge_attr = edge_attr
        return new_graph
    def _embedding(self, time):
        """
        Generate sinusoidal temporal embeddings.

        :param torch.Tensor time: The time instances.
        :return: The sinusoidal embedding tensor.
        :rtype: torch.Tensor
        """
        # Compute the sinusoidal embeddings
        half_dim = self.time_emb_dim // 2
        logs = torch.log(torch.as_tensor(self.max_time_idx)) / (half_dim - 1)
        freqs = torch.exp(-torch.arange(half_dim) * logs)
        args = torch.as_tensor(time)[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

        # Apply padding if the embedding dimension is odd
        if self.time_emb_dim % 2 == 1:
            emb = torch.nn.functional.pad(emb, (0, 1), mode="constant")

        return emb
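

if __name__ == "__main__":
    # Minimal usage sketch: builds a toy graph and runs one forward pass.
    # The sizes below are illustrative only, and ``torch_geometric`` is
    # assumed to be installed (its ``Data`` class is one of the graph
    # containers accepted by ``forward``). Run with:
    #   python -m pina.model.equivariant_graph_neural_operator
    from torch_geometric.data import Data

    model = EquivariantGraphNeuralOperator(
        n_egno_layers=2,
        node_feature_dim=8,
        edge_feature_dim=4,
        pos_dim=3,
        modes=2,
        time_steps=4,
    )
    toy_graph = Data(
        x=torch.randn(10, 8),  # node features
        pos=torch.randn(10, 3),  # 3D positions
        vel=torch.randn(10, 3),  # 3D velocities
        edge_index=torch.randint(0, 10, (2, 30)),  # random connectivity
        edge_attr=torch.randn(30, 4),  # edge features
    )
    out = model(toy_graph)
    # Per the forward docstring, outputs gain a leading time dimension:
    print(out.pos.shape)  # expected: torch.Size([4, 10, 3])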