"""Module for the Equivariant Graph Neural Operator block."""
import torch
from ....utils import check_positive_integer
from .en_equivariant_network_block import EnEquivariantNetworkBlock
[docs]
class EquivariantGraphNeuralOperatorBlock(torch.nn.Module):
"""
A single block of the Equivariant Graph Neural Operator (EGNO).
This block combines a temporal convolution with an equivariant graph neural
network (EGNN) layer. It preserves equivariance while modeling complex
interactions between nodes in a graph over time.
.. seealso::
**Original reference**
Xu, M., Han, J., Lou, A., Kossaifi, J., Ramanathan, A., Azizzadenesheli,
K., Leskovec, J., Ermon, S., Anandkumar, A. (2024).
*Equivariant Graph Neural Operator for Modeling 3D Dynamics*
DOI: `arXiv preprint arXiv:2401.11037.
<https://arxiv.org/abs/2401.11037>`_
"""
def __init__(
self,
node_feature_dim,
edge_feature_dim,
pos_dim,
modes,
hidden_dim=64,
n_message_layers=2,
n_update_layers=2,
activation=torch.nn.SiLU,
aggr="add",
node_dim=-2,
flow="source_to_target",
):
"""
Initialization of the :class:`EquivariantGraphNeuralOperatorBlock`
class.
:param int node_feature_dim: The dimension of the node features.
:param int edge_feature_dim: The dimension of the edge features.
:param int pos_dim: The dimension of the position features.
:param int modes: The number of Fourier modes to use in the temporal
convolution.
:param int hidden_dim: The dimension of the hidden features.
Default is 64.
:param int n_message_layers: The number of layers in the message
network. Default is 2.
:param int n_update_layers: The number of layers in the update network.
Default is 2.
:param torch.nn.Module activation: The activation function.
Default is :class:`torch.nn.SiLU`.
:param str aggr: The aggregation scheme to use for message passing.
Available options are "add", "mean", "min", "max", "mul".
See :class:`torch_geometric.nn.MessagePassing` for more details.
Default is "add".
:param int node_dim: The axis along which to propagate. Default is -2.
:param str flow: The direction of message passing. Available options
are "source_to_target" and "target_to_source".
The "source_to_target" flow means that messages are sent from
the source node to the target node, while the "target_to_source"
flow means that messages are sent from the target node to the
source node. See :class:`torch_geometric.nn.MessagePassing` for more
details. Default is "source_to_target".
:raises AssertionError: If ``modes`` is not a positive integer.
"""
super().__init__()
# Check consistency
check_positive_integer(modes, strict=True)
# Initialization
self.modes = modes
# Temporal convolution weights - real and imaginary parts
self.weight_scalar_r = torch.nn.Parameter(
torch.rand(node_feature_dim, node_feature_dim, modes)
)
self.weight_scalar_i = torch.nn.Parameter(
torch.rand(node_feature_dim, node_feature_dim, modes)
)
self.weight_vector_r = torch.nn.Parameter(torch.rand(2, 2, modes) * 0.1)
self.weight_vector_i = torch.nn.Parameter(torch.rand(2, 2, modes) * 0.1)
# EGNN block
self.egnn = EnEquivariantNetworkBlock(
node_feature_dim=node_feature_dim,
edge_feature_dim=edge_feature_dim,
pos_dim=pos_dim,
use_velocity=True,
hidden_dim=hidden_dim,
n_message_layers=n_message_layers,
n_update_layers=n_update_layers,
activation=activation,
aggr=aggr,
node_dim=node_dim,
flow=flow,
)
[docs]
def forward(self, x, pos, vel, edge_index, edge_attr=None):
"""
Forward pass of the Equivariant Graph Neural Operator block.
:param x: The node feature tensor of shape
``[time_steps, num_nodes, node_feature_dim]``.
:type x: torch.Tensor | LabelTensor
:param pos: The node position tensor (Euclidean coordinates) of shape
``[time_steps, num_nodes, pos_dim]``.
:type pos: torch.Tensor | LabelTensor
:param vel: The node velocity tensor of shape
``[time_steps, num_nodes, pos_dim]``.
:type vel: torch.Tensor | LabelTensor
:param edge_index: The edge connectivity of shape ``[2, num_edges]``.
:type edge_index: torch.Tensor
:param edge_attr: The edge feature tensor of shape
``[time_steps, num_edges, edge_feature_dim]``. Default is None.
:type edge_attr: torch.Tensor | LabelTensor, optional
:return: The updated node features, positions, and velocities, each with
the same shape as the inputs.
:rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
"""
# Prepare features
center = pos.mean(dim=1, keepdim=True)
vector = torch.stack((pos - center, vel), dim=-1)
# Compute temporal convolution
x = x + self._convolution(
x, "mni, iom -> mno", self.weight_scalar_r, self.weight_scalar_i
)
vector = vector + self._convolution(
vector,
"mndi, iom -> mndo",
self.weight_vector_r,
self.weight_vector_i,
)
# Split position and velocity
pos, vel = vector.unbind(dim=-1)
pos = pos + center
# Reshape to (time * nodes, feature) for egnn
x = x.reshape(-1, x.shape[-1])
pos = pos.reshape(-1, pos.shape[-1])
vel = vel.reshape(-1, vel.shape[-1])
if edge_attr is not None:
edge_attr = edge_attr.reshape(-1, edge_attr.shape[-1])
x, pos, vel = self.egnn(
x=x,
pos=pos,
edge_index=edge_index,
edge_attr=edge_attr,
vel=vel,
)
# Reshape back to (time, nodes, feature)
x = x.reshape(center.shape[0], -1, x.shape[-1])
pos = pos.reshape(center.shape[0], -1, pos.shape[-1])
vel = vel.reshape(center.shape[0], -1, vel.shape[-1])
return x, pos, vel
def _convolution(self, x, einsum_idx, real, img):
"""
Compute the temporal convolution.
:param torch.Tensor x: The input features.
:param str einsum_idx: The indices for the einsum operation.
:param torch.Tensor real: The real part of the convolution weights.
:param torch.Tensor img: The imaginary part of the convolution weights.
:return: The convolved features.
:rtype: torch.Tensor
"""
# Number of modes to use
modes = min(self.modes, (x.shape[0] // 2) + 1)
# Build complex weights
weights = torch.complex(real[..., :modes], img[..., :modes])
# Convolution in Fourier space
fourier = torch.fft.rfftn(x, dim=[0])[:modes]
out = torch.einsum(einsum_idx, fourier, weights)
return torch.fft.irfftn(out, s=x.shape[0], dim=0)