Radial Field Network Block
- class RadialFieldNetworkBlock(node_feature_dim, hidden_dim=64, n_layers=2, activation=<class 'torch.nn.modules.activation.Tanh'>, aggr='add', node_dim=-2, flow='source_to_target')
Bases: MessagePassing
Implementation of the Radial Field Network block.
This block performs message passing between the nodes of a graph along its edges, following the scheme proposed by Köhler et al. (2020). It is intended to be used as an inner block of a larger graph neural network architecture.
The message between two nodes connected by an edge is computed from the radial distance between them, i.e. the norm of the difference between the sender and recipient node features, which is passed through a network of linear transformations followed by non-linear activation functions. Messages are then combined with the chosen aggregation scheme (sum, mean, min, max, or product).
The update step is performed by a simple addition of the incoming messages to the node features.
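The message, aggregation, and update steps described above can be illustrated with a minimal NumPy sketch of one block step. It is not the PINA implementation: the radial network is a hypothetical MLP with two tanh hidden layers and random weights, aggregation is fixed to “add” with the “source_to_target” flow, and the sketch assumes, as in the radial field of Köhler et al., that the scalar computed from each distance scales the difference vector x_i - x_j.

```python
import numpy as np


def radial_mlp(d, weights, activation=np.tanh):
    """Hypothetical stand-in for the radial network.

    Maps distances of shape (E, 1) to one scalar per edge, shape (E, 1).
    """
    h = d
    for W, b in weights[:-1]:
        h = activation(h @ W + b)  # linear transformation + non-linearity
    W, b = weights[-1]
    return h @ W + b  # final linear layer, no activation


def radial_field_block(x, edge_index, weights):
    """One message-passing step with "add" aggregation and
    "source_to_target" flow: row 0 of edge_index holds senders j,
    row 1 holds recipients i."""
    j, i = edge_index
    r = x[i] - x[j]                               # difference vectors, (E, D)
    d = np.linalg.norm(r, axis=1, keepdims=True)  # radial distances, (E, 1)
    messages = radial_mlp(d, weights) * r         # assumed: scalar weight times r
    aggregated = np.zeros_like(x)
    np.add.at(aggregated, i, messages)            # sum the messages per recipient
    return x + aggregated                         # update: add to node features


rng = np.random.default_rng(0)
hidden_dim = 8  # kept small here; the documented default is 64
weights = [
    (rng.normal(size=(1, hidden_dim)), np.zeros(hidden_dim)),
    (rng.normal(size=(hidden_dim, hidden_dim)), np.zeros(hidden_dim)),
    (rng.normal(size=(hidden_dim, 1)), np.zeros(1)),
]
x = rng.normal(size=(4, 3))             # 4 nodes, node_feature_dim = 3
edge_index = np.array([[0, 1, 2, 3],    # senders (sources)
                       [1, 2, 3, 0]])   # recipients (targets)
out = radial_field_block(x, edge_index, weights)
print(out.shape)  # (4, 3)
```

Nodes that receive no message keep their features unchanged, since the update only adds the aggregated messages to the existing features.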
See also
Original reference: Köhler, J., Klein, L., Noé, F. (2020). Equivariant Flows: Exact Likelihood Generative Learning for Symmetric Densities. In International Conference on Machine Learning. DOI: https://doi.org/10.48550/arXiv.2006.02425.
Initialization of the RadialFieldNetworkBlock class.
- Parameters:
node_feature_dim (int) – The dimension of the node features.
hidden_dim (int) – The dimension of the hidden features. Default is 64.
n_layers (int) – The number of layers in the network. Default is 2.
activation (torch.nn.Module) – The activation function. Default is torch.nn.Tanh.
aggr (str) – The aggregation scheme to use for message passing. Available options are “add”, “mean”, “min”, “max”, “mul”. See torch_geometric.nn.MessagePassing for more details. Default is “add”.
node_dim (int) – The axis along which to propagate. Default is -2.
flow (str) – The direction of message passing. Available options are “source_to_target” and “target_to_source”. The “source_to_target” flow means that messages are sent from the source node to the target node, while the “target_to_source” flow means that messages are sent from the target node to the source node. See torch_geometric.nn.MessagePassing for more details. Default is “source_to_target”.
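How aggr and flow decide which node receives each message, and how the messages are combined, can be sketched in NumPy. The sketch follows the edge_index convention of torch_geometric.nn.MessagePassing (row 0 holds sources, row 1 holds targets); the function name aggregate is hypothetical, and filling nodes that receive no message with zeros is a choice made in this sketch, not a documented behavior of the library.

```python
import numpy as np


def aggregate(messages, edge_index, num_nodes, aggr="add",
              flow="source_to_target"):
    """Combine per-edge messages of shape (E, D) at the receiving nodes."""
    # "source_to_target": messages go to edge_index[1];
    # "target_to_source": the roles are swapped and they go to edge_index[0].
    receivers = edge_index[1] if flow == "source_to_target" else edge_index[0]
    counts = np.bincount(receivers, minlength=num_nodes)
    dim = messages.shape[1]
    if aggr in ("add", "mean"):
        out = np.zeros((num_nodes, dim))
        np.add.at(out, receivers, messages)
        if aggr == "mean":
            out /= np.maximum(counts, 1)[:, None]
    else:
        # Unbuffered in-place reductions for "mul", "max" and "min".
        ufunc, init = {"mul": (np.multiply, 1.0),
                       "max": (np.maximum, -np.inf),
                       "min": (np.minimum, np.inf)}[aggr]
        out = np.full((num_nodes, dim), init)
        ufunc.at(out, receivers, messages)
    out[counts == 0] = 0.0  # this sketch gives zeros to nodes with no messages
    return out


messages = np.array([[1.0], [2.0], [4.0]])  # one scalar message per edge
edge_index = np.array([[0, 1, 2],           # sources
                       [2, 2, 0]])          # targets
for aggr in ("add", "mean", "min", "max", "mul"):
    print(aggr, aggregate(messages, edge_index, 3, aggr).ravel())
print("target_to_source",
      aggregate(messages, edge_index, 3, flow="target_to_source").ravel())
```

With the default flow, node 2 receives two messages and node 1 none; reversing the flow sends each message to the other endpoint of its edge instead.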
- Raises:
AssertionError – If node_feature_dim is not a positive integer.
AssertionError – If hidden_dim is not a positive integer.
AssertionError – If n_layers is not a positive integer.
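The checks listed under Raises could look like the following plain-Python sketch. Both helper names are hypothetical, and rejecting bool values (which are int instances in Python) is an assumption of this sketch rather than documented behavior.

```python
def check_positive_int(value, name):
    """Hypothetical helper: assert that value is a strictly positive int."""
    assert isinstance(value, int) and not isinstance(value, bool) and value > 0, (
        f"{name} must be a positive integer, got {value!r}"
    )


def validate_block_args(node_feature_dim, hidden_dim=64, n_layers=2):
    """Hypothetical validation of the integer constructor arguments."""
    for name, value in (("node_feature_dim", node_feature_dim),
                        ("hidden_dim", hidden_dim),
                        ("n_layers", n_layers)):
        check_positive_int(value, name)


validate_block_args(3)  # passes silently
try:
    validate_block_args(3, hidden_dim=0)
except AssertionError as err:
    print(err)  # hidden_dim must be a positive integer, got 0
```

Because these are assert statements, they are skipped entirely when Python runs with the -O flag; an explicit ValueError would survive optimization, at the cost of a different exception type.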
- forward(x, edge_index)
Forward pass of the block, triggering the message-passing routine.
- Parameters:
x (torch.Tensor | LabelTensor) – The node features.
edge_index (torch.Tensor) – The edge indices, of shape [2, num_edges].
- Returns:
The updated node features.
- Return type:
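forward expects node features with nodes along the axis given by node_dim (with the default -2, a tensor of shape (num_nodes, node_feature_dim)) and an edge_index of shape (2, num_edges) in the torch_geometric convention. For a small particle system in which every node interacts with every other one, a fully connected edge_index without self-loops could be built as in this NumPy sketch. The helper name is hypothetical, and the result would still need to be converted to an integer tensor, for example with torch.as_tensor(edge_index, dtype=torch.long), before calling forward.

```python
import numpy as np


def fully_connected_edge_index(num_nodes):
    """Return a (2, N * (N - 1)) array with every directed edge j -> i, i != j.

    Row 0 holds the sources and row 1 the targets.
    """
    src, dst = np.meshgrid(np.arange(num_nodes), np.arange(num_nodes),
                           indexing="ij")
    mask = src != dst  # drop self-loops
    return np.stack([src[mask], dst[mask]])


edge_index = fully_connected_edge_index(3)
print(edge_index)
# [[0 0 1 1 2 2]
#  [1 2 0 2 0 1]]
```

The number of edges grows quadratically with the number of nodes, so for larger systems a sparser graph (for example, neighbors within a cutoff radius) is the usual alternative.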
- message(x_i, x_j)
Compute the message passed from a sender node to a recipient node along an edge.
- Parameters:
x_i (torch.Tensor | LabelTensor) – The node features of the recipient nodes.
x_j (torch.Tensor | LabelTensor) – The node features of the sender nodes.
- Returns:
The message to be passed.
- Return type:
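During message passing, x_i and x_j are the node features gathered per edge: with the default “source_to_target” flow, x_i = x[edge_index[1]] and x_j = x[edge_index[0]]. The NumPy sketch below assumes the radial-field form m_ij = phi(||x_i - x_j||) (x_i - x_j) of Köhler et al., with a hypothetical stand-in for phi, and checks the property that motivates this form: since the message depends on the features only through their difference and its norm, applying the same orthogonal transformation to all inputs transforms the messages in the same way.

```python
import numpy as np


def message(x_i, x_j, phi):
    """Per-edge radial message, assuming m_ij = phi(||x_i - x_j||) * (x_i - x_j).

    x_i, x_j: (E, D) recipient and sender features gathered per edge.
    phi: callable mapping distances of shape (E, 1) to weights of shape (E, 1).
    """
    r = x_i - x_j
    return phi(np.linalg.norm(r, axis=1, keepdims=True)) * r


def phi(d):
    """Hypothetical stand-in for the learned radial network."""
    return np.tanh(1.0 - d)


rng = np.random.default_rng(1)
x_i, x_j = rng.normal(size=(5, 3)), rng.normal(size=(5, 3))

# Random orthogonal matrix (rotation or reflection) from a QR decomposition.
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
m = message(x_i, x_j, phi)
m_rot = message(x_i @ Q.T, x_j @ Q.T, phi)  # transform every row vector by Q
print(np.allclose(m_rot, m @ Q.T))  # True: messages transform with the inputs
```

For the same reason, translating all node features by a common vector leaves the messages unchanged.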
- update(message, x)
Update the node features with the received messages.
- Parameters:
message (torch.Tensor) – The aggregated messages received by each node.
x (torch.Tensor | LabelTensor) – The node features.
- Returns:
The updated node features.
- Return type: