Source code for pina.model.block.gno_block
"""Module for the Graph Neural Operator Block class."""
import torch
from torch_geometric.nn import MessagePassing
class GNOBlock(MessagePassing):
"""
The inner block of the Graph Neural Operator, based on Message Passing.
"""
    def __init__(
        self,
        width,
        edges_features,
        n_layers=2,
        layers=None,
        inner_size=None,
        internal_func=None,
        external_func=None,
    ):
"""
Initialization of the :class:`GNOBlock` class.
:param int width: The width of the kernel.
:param int edge_features: The number of edge features.
:param int n_layers: The number of kernel layers. Default is ``2``.
:param layers: A list specifying the number of neurons for each layer
of the neural network. If not ``None``, it overrides the
``inner_size`` and ``n_layers``parameters. Default is ``None``.
:type layers: list[int] | tuple[int]
:param int inner_size: The size of the inner layer. Default is ``None``.
:param torch.nn.Module internal_func: The activation function applied to
the output of each layer. If ``None``, it uses the
:class:`torch.nn.Tanh` activation. Default is ``None``.
:param torch.nn.Module external_func: The activation function applied to
the output of the block. If ``None``, it uses the
:class:`torch.nn.Tanh`. activation. Default is ``None``.
"""
        from ...model.feed_forward import FeedForward

        super().__init__(aggr="mean")  # Mean aggregation over incoming messages
        self.width = width

        # Default both activations to Tanh, as documented above
        if internal_func is None:
            internal_func = torch.nn.Tanh
        if external_func is None:
            external_func = torch.nn.Tanh

        if layers is None and inner_size is None:
            inner_size = width

        # Kernel network: maps edge features to a [width, width] matrix
        self.dense = FeedForward(
            input_dimensions=edges_features,
            output_dimensions=width**2,
            n_layers=n_layers,
            layers=layers,
            inner_size=inner_size,
            func=internal_func,
        )
        self.W = torch.nn.Linear(width, width)
        self.func = external_func()

    def message_and_aggregate(self, edge_index, x, edge_attr):
        """
        Combine messages and perform aggregation.

        :param torch.Tensor edge_index: The edge index.
        :param torch.Tensor x: The node feature matrix.
        :param torch.Tensor edge_attr: The edge features.
        :return: The aggregated messages.
        :rtype: torch.Tensor
        """
        # Edge features are transformed into a matrix of shape
        # [num_edges, width, width]
        x_ = self.dense(edge_attr).view(-1, self.width, self.width)
        # Each message is the edge-feature kernel matrix applied to the
        # source node features
        messages = torch.einsum("bij,bj->bi", x_, x[edge_index[0]])
        # Aggregation is performed using the mean (set in the constructor)
        return self.aggregate(messages, edge_index[1])

    def edge_update(self, edge_attr):
        """
        Update edge features.

        :param torch.Tensor edge_attr: The edge features.
        :return: The updated edge features.
        :rtype: torch.Tensor
        """
        return edge_attr

    def update(self, aggr_out, x):
        """
        Update node features.

        :param torch.Tensor aggr_out: The aggregated messages.
        :param torch.Tensor x: The node feature matrix.
        :return: The updated node features.
        :rtype: torch.Tensor
        """
        return aggr_out + self.W(x)

    def forward(self, x, edge_index, edge_attr):
        """
        Forward pass of the block.

        :param torch.Tensor x: The node features.
        :param torch.Tensor edge_index: The edge indices.
        :param torch.Tensor edge_attr: The edge features.
        :return: The updated node features.
        :rtype: torch.Tensor
        """
        return self.func(self.propagate(edge_index, x=x, edge_attr=edge_attr))
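

if __name__ == "__main__":
    # Minimal usage sketch: run the block on a toy graph to illustrate the
    # expected tensor shapes. This is illustrative only and assumes the module
    # is executed inside the installed package (e.g. via
    # ``python -m pina.model.block.gno_block``), since :class:`GNOBlock`
    # relies on a relative import of ``FeedForward``.
    width, edges_features = 8, 3
    block = GNOBlock(
        width=width,
        edges_features=edges_features,
        internal_func=torch.nn.Tanh,
        external_func=torch.nn.Tanh,
    )
    x = torch.rand(4, width)  # node features: [num_nodes, width]
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])  # [2, num_edges]
    edge_attr = torch.rand(edge_index.shape[1], edges_features)  # [num_edges, edges_features]
    out = block(x, edge_index, edge_attr)
    print(out.shape)  # torch.Size([4, 8]), i.e. [num_nodes, width]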