Radial Field Network Block

class RadialFieldNetworkBlock(node_feature_dim, hidden_dim=64, n_layers=2, activation=<class 'torch.nn.modules.activation.Tanh'>, aggr='add', node_dim=-2, flow='source_to_target')

Bases: MessagePassing

Implementation of the Radial Field Network block.

This block performs message passing between the nodes of a graph along its edges, following the scheme proposed by Köhler et al. (2020). It is intended as an inner block of a larger graph neural network architecture.

The message between two nodes connected by an edge is computed from the difference between the sender and recipient node features: the norm of this difference (the radial distance between the two nodes) is passed through a network of linear transformations followed by non-linear activation functions, and the result is combined with the difference itself. Messages are then aggregated using the chosen aggregation scheme (sum, mean, min, max, or product).
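Read in the radial-field form of Köhler et al., and assuming (this is a reading of the description above, not a statement of the block's exact code) that a learned scalar function φ of the distance scales the difference vector, the message from a sender j to a recipient i and its aggregation at node i can be written as:

```latex
m_{ij} = \phi\big(\lVert x_i - x_j \rVert\big)\,(x_i - x_j),
\qquad
\bar{m}_i = \bigoplus_{j \in \mathcal{N}(i)} m_{ij}
```

Here φ presumably corresponds to the internal network configured by hidden_dim, n_layers and activation, N(i) is the set of senders connected to node i, and ⊕ is the reduction selected by aggr (a plain sum for the default "add").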

The update step is performed by a simple addition of the incoming messages to the node features.
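The message, aggregation and update steps above can be sketched end to end in NumPy. This is a minimal illustration, not the block's implementation: the real class is a PyTorch/torch_geometric module, while radial_weight (a hypothetical two-layer network), its parameter tuple, and radial_field_forward are stand-ins introduced here, with the radial-field message form assumed as in Köhler et al. The sketch fixes aggr="add" and flow="source_to_target".

```python
import numpy as np


def radial_weight(dist, w1, b1, w2, b2):
    """Hypothetical stand-in for the block's internal network.

    Maps distances of shape (E, 1) to one scalar weight per edge, shape (E, 1).
    """
    hidden = np.tanh(dist @ w1 + b1)  # Tanh mirrors the documented default activation
    return hidden @ w2 + b2


def radial_field_forward(x, edge_index, params):
    """One message-passing step with aggr="add" and flow="source_to_target".

    edge_index[0] holds the senders j and edge_index[1] the recipients i.
    """
    src, dst = edge_index
    x_i, x_j = x[dst], x[src]                            # recipient / sender features per edge
    diff = x_i - x_j                                     # difference vectors, shape (E, D)
    dist = np.linalg.norm(diff, axis=-1, keepdims=True)  # radial distances, shape (E, 1)
    messages = radial_weight(dist, *params) * diff       # assumed radial-field message
    aggregated = np.zeros_like(x)
    np.add.at(aggregated, dst, messages)                 # sum the messages arriving at each node
    return x + aggregated                                # update: add messages to the features


rng = np.random.default_rng(0)
hidden_dim = 8
params = (rng.normal(size=(1, hidden_dim)), np.zeros(hidden_dim),
          rng.normal(size=(hidden_dim, 1)), np.zeros(1))
x = rng.normal(size=(4, 3))                           # 4 nodes, 3 features each
edge_index = np.array([[0, 1, 2, 3], [1, 2, 3, 0]])  # directed cycle 0->1->2->3->0
out = radial_field_forward(x, edge_index, params)
print(out.shape)  # (4, 3)
```

Because the update is residual, a network that outputs zero for every distance leaves the node features unchanged.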

See also

Original reference: Köhler, J., Klein, L., Noé, F. (2020). Equivariant Flows: Exact Likelihood Generative Learning for Symmetric Densities. In International Conference on Machine Learning. DOI: https://doi.org/10.48550/arXiv.2006.02425.

Initialization of the RadialFieldNetworkBlock class.

Parameters:
  • node_feature_dim (int) – The dimension of the node features.

  • hidden_dim (int) – The dimension of the hidden features. Default is 64.

  • n_layers (int) – The number of layers in the network. Default is 2.

  • activation (torch.nn.Module) – The activation function. Default is torch.nn.Tanh.

  • aggr (str) – The aggregation scheme to use for message passing. Available options are “add”, “mean”, “min”, “max”, “mul”. See torch_geometric.nn.MessagePassing for more details. Default is “add”.

  • node_dim (int) – The axis along which to propagate. Default is -2.

  • flow (str) – The direction of message passing. Available options are “source_to_target” and “target_to_source”. The “source_to_target” flow means that messages are sent from the source node to the target node, while the “target_to_source” flow means that messages are sent from the target node to the source node. See torch_geometric.nn.MessagePassing for more details. Default is “source_to_target”.
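To make the aggr and flow options concrete, the following NumPy sketch routes a fixed set of per-edge messages to their recipients. It assumes the torch_geometric convention that row 0 of edge_index holds source nodes and row 1 holds target nodes; the helper names sender_recipient and aggregate are hypothetical, only "add", "mean" and "max" are covered, and nodes without incoming messages are set to zero in this sketch. In the actual block the messages themselves would also be computed with the sender and recipient roles swapped under "target_to_source"; here they are held fixed to isolate the routing.

```python
import numpy as np


def sender_recipient(edge_index, flow="source_to_target"):
    """Return (sender, recipient) index arrays, one entry per edge."""
    if flow == "source_to_target":
        return edge_index[0], edge_index[1]
    if flow == "target_to_source":
        return edge_index[1], edge_index[0]
    raise ValueError(f"unknown flow: {flow!r}")


def aggregate(messages, recipients, num_nodes, aggr="add"):
    """Reduce messages of shape (E, D) onto recipient nodes, giving shape (N, D)."""
    counts = np.bincount(recipients, minlength=num_nodes)
    if aggr == "add":
        out = np.zeros((num_nodes, messages.shape[1]))
        np.add.at(out, recipients, messages)
    elif aggr == "mean":
        out = np.zeros((num_nodes, messages.shape[1]))
        np.add.at(out, recipients, messages)
        out /= np.maximum(counts, 1)[:, None]  # avoid dividing by zero for isolated nodes
    elif aggr == "max":
        out = np.full((num_nodes, messages.shape[1]), -np.inf)
        np.maximum.at(out, recipients, messages)
        out[counts == 0] = 0.0  # nodes without messages default to zero here
    else:
        raise ValueError(f"unsupported aggr in this sketch: {aggr!r}")
    return out


edge_index = np.array([[0, 0, 1], [1, 2, 2]])  # edges 0->1, 0->2, 1->2
messages = np.array([[1.0], [2.0], [4.0]])     # one scalar message per edge

src, dst = sender_recipient(edge_index, "source_to_target")
print(aggregate(messages, dst, 3, "add").ravel())  # [0. 1. 6.]
src, dst = sender_recipient(edge_index, "target_to_source")
print(aggregate(messages, dst, 3, "add").ravel())  # [3. 4. 0.]
```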

Raises:

forward(x, edge_index)

Forward pass of the block, triggering the message-passing routine.

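Parameters:
  • x (torch.Tensor) – The node features.

  • edge_index (torch.Tensor) – The edge indices of the graph, of shape (2, num_edges).
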
Returns:

The updated node features.

Return type:

torch.Tensor
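The forward pass expects the graph connectivity as an edge_index, which torch_geometric represents as an integer tensor of shape (2, num_edges). For a fully connected graph, such as a small particle system in which every pair of nodes interacts, it can be built as below. The helper fully_connected_edge_index is hypothetical and not part of the block; it is written in NumPy, and converting its result with torch.as_tensor(..., dtype=torch.long) would give the format torch_geometric expects.

```python
import numpy as np


def fully_connected_edge_index(num_nodes):
    """Edge index of shape (2, N * (N - 1)) linking every ordered pair of distinct nodes."""
    src, dst = np.meshgrid(np.arange(num_nodes), np.arange(num_nodes), indexing="ij")
    # Drop self-loops: their difference vector is zero, so under the
    # radial-field form they would carry a zero message anyway.
    mask = src != dst
    return np.stack([src[mask], dst[mask]])


edge_index = fully_connected_edge_index(3)
print(edge_index.tolist())  # [[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]]
```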

message(x_i, x_j)

Compute the message passed from a sender node to a recipient node along an edge.

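Parameters:
  • x_i (torch.Tensor) – The features of the recipient nodes, one entry per edge.

  • x_j (torch.Tensor) – The features of the sender nodes, one entry per edge.
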
Returns:

The message to be passed.

Return type:

torch.Tensor
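Building the message only from the difference vector and its norm makes it equivariant to rotations and reflections of the node features and invariant to their translations. The NumPy check below illustrates this under two assumptions: the radial-field message form φ(‖x_i − x_j‖)(x_i − x_j), and a hypothetical fixed function phi standing in for the learned network. Rotating (or reflecting) and translating both nodes rotates the message by the same matrix and leaves the translation out.

```python
import numpy as np


def phi(dist):
    """Hypothetical scalar function standing in for the learned network."""
    return np.tanh(1.5 * dist) - 0.3


def radial_message(x_i, x_j):
    """Assumed radial-field message phi(||x_i - x_j||) * (x_i - x_j), batched over rows."""
    diff = x_i - x_j
    dist = np.linalg.norm(diff, axis=-1, keepdims=True)
    return phi(dist) * diff


rng = np.random.default_rng(42)
x_i, x_j = rng.normal(size=(5, 3)), rng.normal(size=(5, 3))

q, _ = np.linalg.qr(rng.normal(size=(3, 3)))  # random orthogonal matrix
t = rng.normal(size=3)                        # random translation

# Transform the inputs, then compute messages; compare with transforming the messages.
lhs = radial_message(x_i @ q.T + t, x_j @ q.T + t)
rhs = radial_message(x_i, x_j) @ q.T
print(np.allclose(lhs, rhs))  # True
```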

update(message, x)

Update the node features with the received messages.

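Parameters:
  • message (torch.Tensor) – The aggregated messages received by each node.

  • x (torch.Tensor) – The node features.
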
Returns:

The updated node features.

Return type:

torch.Tensor