Interaction Network Block

class InteractionNetworkBlock(node_feature_dim, edge_feature_dim=0, hidden_dim=64, n_message_layers=2, n_update_layers=2, activation=torch.nn.SiLU, aggr='add', node_dim=-2, flow='source_to_target')

Bases: MessagePassing

Implementation of the Interaction Network block.

This block is used to perform message-passing between nodes and edges in a graph neural network, following the scheme proposed by Battaglia et al. in 2016. It serves as an inner block in a larger graph neural network architecture.

The message along an edge is computed by applying a multi-layer perceptron (MLP) to the concatenation of the sender and recipient node features, together with the edge features when edge_attr is provided. The messages arriving at each node are then combined using an aggregation scheme (e.g., sum, mean, min, max, or product).

The update step is performed by applying another MLP to the concatenation of the incoming messages and the node features.
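
The message, aggregation, and update steps described above can be sketched in plain PyTorch. This is a minimal illustration, not the library's implementation: TinyInteractionBlock and make_mlp are hypothetical names, and the sketch assumes that edge features are appended to the message input, that messages are summed, and that the update MLP returns features of the same width as the input node features.

```python
import torch
from torch import nn


def make_mlp(in_dim, hidden_dim, out_dim, n_layers, activation=nn.SiLU):
    """Stack n_layers linear layers with an activation between them."""
    dims = [in_dim] + [hidden_dim] * (n_layers - 1) + [out_dim]
    layers = []
    for k in range(n_layers):
        layers.append(nn.Linear(dims[k], dims[k + 1]))
        if k < n_layers - 1:
            layers.append(activation())
    return nn.Sequential(*layers)


class TinyInteractionBlock(nn.Module):
    """Hypothetical stand-in, not the library's InteractionNetworkBlock."""

    def __init__(self, node_feature_dim, edge_feature_dim=0, hidden_dim=64):
        super().__init__()
        # Message MLP input: [x_i, x_j, edge_attr] concatenated per edge.
        self.message_mlp = make_mlp(
            2 * node_feature_dim + edge_feature_dim, hidden_dim, hidden_dim, 2
        )
        # Update MLP input: [aggregated messages, x] concatenated per node.
        # Assumption: the output keeps the width of the input node features.
        self.update_mlp = make_mlp(
            hidden_dim + node_feature_dim, hidden_dim, node_feature_dim, 2
        )

    def forward(self, x, edge_index, edge_attr=None):
        src, dst = edge_index  # "source_to_target": messages go src -> dst
        parts = [x[dst], x[src]]
        if edge_attr is not None:
            parts.append(edge_attr)
        messages = self.message_mlp(torch.cat(parts, dim=-1))
        # "add" aggregation: sum all messages arriving at each target node.
        aggregated = messages.new_zeros(x.size(0), messages.size(-1))
        aggregated = aggregated.index_add(0, dst, messages)
        return self.update_mlp(torch.cat([aggregated, x], dim=-1))


# Usage: 4 nodes with 3 features each, 3 directed edges with 2 features each.
x = torch.randn(4, 3)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
edge_attr = torch.randn(3, 2)
block = TinyInteractionBlock(node_feature_dim=3, edge_feature_dim=2, hidden_dim=8)
out = block(x, edge_index, edge_attr)
```

Node 0 has no incoming edge in this example, so its aggregated message is all zeros and its update depends only on its own features.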

See also

Original reference: Battaglia, P. W., et al. (2016). Interaction Networks for Learning about Objects, Relations and Physics. In Advances in Neural Information Processing Systems (NeurIPS 2016). DOI: https://doi.org/10.48550/arXiv.1612.00222.

Initialization of the InteractionNetworkBlock class.

Parameters:
  • node_feature_dim (int) – The dimension of the node features.

  • edge_feature_dim (int) – The dimension of the edge features. Leave at 0 when no edge_attr is passed to forward. Default is 0.

  • hidden_dim (int) – The dimension of the hidden features. Default is 64.

  • n_message_layers (int) – The number of layers in the message network. Default is 2.

  • n_update_layers (int) – The number of layers in the update network. Default is 2.

  • activation (torch.nn.Module) – The activation function. Default is torch.nn.SiLU.

  • aggr (str) – The aggregation scheme to use for message passing. Available options are “add”, “mean”, “min”, “max”, “mul”. See torch_geometric.nn.MessagePassing for more details. Default is “add”.

  • node_dim (int) – The axis along which to propagate. Default is -2.

  • flow (str) – The direction of message passing. Available options are “source_to_target” and “target_to_source”. The “source_to_target” flow means that messages are sent from the source node to the target node, while the “target_to_source” flow means that messages are sent from the target node to the source node. See torch_geometric.nn.MessagePassing for more details. Default is “source_to_target”.
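
To make the aggr options concrete, the sketch below reduces three per-edge messages onto two receiving nodes with each scheme. It uses torch.Tensor.scatter_reduce purely for illustration; the REDUCE mapping and the aggregate helper are assumptions of this example, and torch_geometric performs aggregation through its own machinery.

```python
import torch

# Assumed correspondence between the documented `aggr` names and the
# `reduce` argument of torch.Tensor.scatter_reduce.
REDUCE = {"add": "sum", "mean": "mean", "min": "amin", "max": "amax", "mul": "prod"}


def aggregate(messages, index, num_nodes, aggr="add"):
    """Reduce one scalar message per edge into one value per receiving node."""
    out = torch.zeros(num_nodes, dtype=messages.dtype)
    # include_self=False keeps the initial zeros out of the reduction
    # for every node that receives at least one message.
    return out.scatter_reduce(
        0, index, messages, reduce=REDUCE[aggr], include_self=False
    )


messages = torch.tensor([1.0, 2.0, 4.0])  # one message per edge
receiver = torch.tensor([0, 0, 1])        # node that receives each message
results = {name: aggregate(messages, receiver, 2, name).tolist() for name in REDUCE}
# Node 0 receives 1.0 and 2.0; node 1 receives 4.0.
# results["add"]  -> [3.0, 4.0]
# results["mean"] -> [1.5, 4.0]
```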

Raises:

forward(x, edge_index, edge_attr=None)

Forward pass of the block, triggering the message-passing routine.

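Parameters:
  • x – The node features.

  • edge_index – The edge indices of the graph.

  • edge_attr – The edge features, if any. Default is None.

Returns: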

The updated node features.

Return type:

torch.Tensor

message(x_i, x_j, edge_attr)

Compute the message to be passed between nodes and edges.

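Parameters:
  • x_i – The features of the node that receives the message.

  • x_j – The features of the neighboring node that sends the message.

  • edge_attr – The edge features.

Returns: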

The message to be passed.

Return type:

torch.Tensor
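
The x_i and x_j arguments follow the torch_geometric naming convention, in which the suffix _i marks the node that aggregates messages and _j marks its neighbor. The sketch below shows how such per-edge pairs can be gathered from x and edge_index for either flow setting. The gather_pairs helper is hypothetical and only mirrors the convention; it is not part of the class.

```python
import torch


def gather_pairs(x, edge_index, flow="source_to_target"):
    """Return (x_i, x_j) per edge: x_i aggregates, x_j is its neighbor."""
    # torch_geometric convention: with "source_to_target", row 0 of edge_index
    # holds the neighbors j and row 1 the aggregating nodes i; the roles swap
    # for "target_to_source".
    i, j = (1, 0) if flow == "source_to_target" else (0, 1)
    return x[edge_index[i]], x[edge_index[j]]


x = torch.tensor([[10.0], [20.0], [30.0]])
edge_index = torch.tensor([[0, 1], [2, 2]])  # edges 0 -> 2 and 1 -> 2

x_i, x_j = gather_pairs(x, edge_index)
# x_i holds node 2 twice (the receiver); x_j holds nodes 0 and 1 (the senders).

x_i_rev, x_j_rev = gather_pairs(x, edge_index, flow="target_to_source")
# With the flow reversed, nodes 0 and 1 aggregate and node 2 sends.
```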

update(message, x)

Update the node features with the received messages.

Parameters:
  • message – The aggregated messages received by each node.

  • x – The node features.

Returns:

The updated node features.

Return type:

torch.Tensor