Interaction Network Block
- class InteractionNetworkBlock(node_feature_dim, edge_feature_dim=0, hidden_dim=64, n_message_layers=2, n_update_layers=2, activation=<class 'torch.nn.modules.activation.SiLU'>, aggr='add', node_dim=-2, flow='source_to_target')
Bases: MessagePassing
Implementation of the Interaction Network block.
This block performs message passing between nodes along the edges of a graph, following the scheme proposed by Battaglia et al. (2016). It is meant to be used as an inner block of a larger graph neural network architecture.
The message sent along an edge is computed by applying a multi-layer perceptron (MLP) to the concatenation of the sender and recipient node features and, when provided, the edge attributes. The messages received by each node are then combined with the chosen aggregation scheme (sum, mean, min, max, or product).
The update step then applies a second MLP to the concatenation of each node's aggregated messages and its current features.
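The three steps above (message, aggregation, update) can be traced by hand with the dependency-free sketch below. It is not the PINA implementation: the two MLPs are replaced by hypothetical stand-ins (`message_mlp`, `update_mlp`) that simply sum their concatenated input, the aggregation is fixed to "add", edges follow the "source_to_target" convention, and node features are plain Python lists.

```python
# Dependency-free sketch of one Interaction Network round (not the PINA code).
# message_mlp / update_mlp are hypothetical stand-ins for the learned MLPs.


def message_mlp(z):
    # Stand-in for the message MLP: reduces the concatenated input to one value.
    return [sum(z)]


def update_mlp(z):
    # Stand-in for the update MLP: reduces [m_i, x_i] to one value,
    # so the 1-dimensional toy node features keep their width.
    return [sum(z)]


def interaction_step(x, edge_index, edge_attr=None):
    """One message -> "add" aggregation -> update round, flow="source_to_target"."""
    senders, recipients = edge_index  # row 0 sends (j), row 1 receives (i)
    msg_dim = 1  # message_mlp above returns a single value
    agg = [[0] * msg_dim for _ in x]
    for e, (j, i) in enumerate(zip(senders, recipients)):
        e_ij = edge_attr[e] if edge_attr is not None else []
        m = message_mlp(x[i] + x[j] + e_ij)  # concatenation [x_i, x_j, e_ij]
        agg[i] = [a + b for a, b in zip(agg[i], m)]  # sum at the recipient
    # Update: concatenation [m_i, x_i] for every node
    return [update_mlp(agg[i] + x[i]) for i in range(len(x))]


x = [[1], [2], [3]]
edge_index = [[0, 2, 1, 1], [1, 1, 0, 2]]  # edges 0->1, 2->1, 1->0, 1->2
edge_attr = [[10], [20], [30], [40]]
print(interaction_step(x, edge_index, edge_attr))  # [[34], [40], [48]]
```

Dropping `edge_attr` simply shortens each concatenation, mirroring the `edge_feature_dim=0` default.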
See also
Original reference: Battaglia, P. W., et al. (2016). Interaction Networks for Learning about Objects, Relations and Physics. In Advances in Neural Information Processing Systems (NeurIPS 2016). DOI: https://doi.org/10.48550/arXiv.1612.00222.
Initialization of the InteractionNetworkBlock class.
- Parameters:
node_feature_dim (int) – The dimension of the node features.
edge_feature_dim (int) – The dimension of the edge features. If edge_attr is not provided, it is assumed to be 0. Default is 0.
hidden_dim (int) – The dimension of the hidden features. Default is 64.
n_message_layers (int) – The number of layers in the message network. Default is 2.
n_update_layers (int) – The number of layers in the update network. Default is 2.
activation (torch.nn.Module) – The activation function. Default is torch.nn.SiLU.
aggr (str) – The aggregation scheme to use for message passing. Available options are “add”, “mean”, “min”, “max”, “mul”. See torch_geometric.nn.MessagePassing for more details. Default is “add”.
node_dim (int) – The axis along which to propagate. Default is -2.
flow (str) – The direction of message passing. Available options are “source_to_target” and “target_to_source”. The “source_to_target” flow means that messages are sent from the source node to the target node, while the “target_to_source” flow means that messages are sent from the target node to the source node. See torch_geometric.nn.MessagePassing for more details. Default is “source_to_target”.
- Raises:
AssertionError – If node_feature_dim is not a positive integer.
AssertionError – If hidden_dim is not a positive integer.
AssertionError – If n_message_layers is not a positive integer.
AssertionError – If n_update_layers is not a positive integer.
AssertionError – If edge_feature_dim is not a non-negative integer.
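Each `aggr` option reduces the messages that share the same recipient node to a single value. The sketch below illustrates the five documented options on scalar messages using only the standard library. The helper `aggregate` and the `REDUCERS` table are hypothetical names, and nodes without incoming messages are deliberately not handled: what happens in that case is defined by PyTorch Geometric, not by this sketch.

```python
import math

# Hypothetical table mapping each documented `aggr` option to a reduction.
REDUCERS = {
    "add": sum,
    "mean": lambda vals: sum(vals) / len(vals),
    "min": min,
    "max": max,
    "mul": math.prod,
}


def aggregate(messages, index, num_nodes, aggr="add"):
    """Reduce scalar messages that share the same recipient index.

    Assumes every node receives at least one message; empty groups are not handled.
    """
    if aggr not in REDUCERS:
        raise ValueError(f"unsupported aggr: {aggr!r}")
    groups = [[] for _ in range(num_nodes)]
    for m, i in zip(messages, index):
        groups[i].append(m)
    return [REDUCERS[aggr](g) for g in groups]


messages = [13, 25, 33, 45]  # one scalar message per edge
recipients = [1, 1, 0, 2]    # recipient node of each edge
for name in REDUCERS:
    print(name, aggregate(messages, recipients, num_nodes=3, aggr=name))
```

Only node 1 receives two messages, so it is the only node whose result changes across the options.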
- forward(x, edge_index, edge_attr=None)
Forward pass of the block, triggering the message-passing routine.
- Parameters:
x (torch.Tensor | LabelTensor) – The node features.
edge_index (torch.Tensor) – The edge indices.
edge_attr (torch.Tensor | LabelTensor) – The edge attributes. Default is None.
- Returns:
The updated node features.
- Return type:
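Assuming `edge_index` follows the PyTorch Geometric COO layout (a 2 x num_edges integer tensor whose two rows hold the endpoints of each edge), the `flow` argument decides which row plays the sender role (`x_j`) and which the recipient role (`x_i`). The hypothetical helper below mirrors that convention with plain lists.

```python
def senders_and_recipients(edge_index, flow="source_to_target"):
    """Split a 2 x num_edges edge list into (senders j, recipients i)."""
    if flow == "source_to_target":
        senders, recipients = edge_index[0], edge_index[1]
    elif flow == "target_to_source":
        senders, recipients = edge_index[1], edge_index[0]
    else:
        raise ValueError(f"unknown flow: {flow!r}")
    return list(senders), list(recipients)


edge_index = [[0, 0], [1, 2]]  # edges 0->1 and 0->2 (row 0 = source, row 1 = target)
print(senders_and_recipients(edge_index))                      # ([0, 0], [1, 2])
print(senders_and_recipients(edge_index, "target_to_source"))  # ([1, 2], [0, 0])
```

With the default flow, node 0 sends to nodes 1 and 2; reversing the flow makes node 0 the recipient of both messages.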
- message(x_i, x_j, edge_attr)
Compute the message sent from the sender node to the recipient node along each edge.
- Parameters:
x_i (torch.Tensor | LabelTensor) – The node features of the recipient nodes.
x_j (torch.Tensor | LabelTensor) – The node features of the sender nodes.
edge_attr (torch.Tensor | LabelTensor) – The edge attributes, or None if they are not provided.
- Returns:
The message to be passed.
- Return type:
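Because the message MLP acts on a concatenation, its input width follows from the constructor arguments. The sketch below assumes the order recipient, sender, edge attributes (an assumption, as are the helper names `message_input` and `message_input_dim`). Under it, the input width is 2 * node_feature_dim + edge_feature_dim, which collapses to 2 * node_feature_dim when edge_attr is None.

```python
def message_input(x_i, x_j, edge_attr=None):
    """Concatenate recipient, sender and (optional) edge features into one vector."""
    # Assumed order: [x_i, x_j, edge_attr]
    parts = [x_i, x_j] if edge_attr is None else [x_i, x_j, edge_attr]
    return [v for part in parts for v in part]


def message_input_dim(node_feature_dim, edge_feature_dim=0):
    """Input width of the message MLP under the concatenation above."""
    return 2 * node_feature_dim + edge_feature_dim


z = message_input([1.0, 2.0], [3.0, 4.0], [0.5])
print(z)  # [1.0, 2.0, 3.0, 4.0, 0.5]
print(len(z) == message_input_dim(node_feature_dim=2, edge_feature_dim=1))  # True
```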
- update(message, x)
Update the node features with the received messages.
- Parameters:
message (torch.Tensor) – The aggregated messages received by each node.
x (torch.Tensor | LabelTensor) – The node features.
- Returns:
The updated node features.
- Return type:
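The update step works row by row on the concatenation of each node's aggregated message and its current features. In the sketch below, `toy_update_mlp` is a hypothetical stand-in assuming hidden_dim=3 and node_feature_dim=2. It also assumes, without confirmation from this page, that the update network maps back to node_feature_dim values.

```python
HIDDEN_DIM = 3        # assumed width of the aggregated messages
NODE_FEATURE_DIM = 2  # width of the node features


def update(message, x, update_mlp):
    """Apply update_mlp to [m_i, x_i] for every node i."""
    if len(message) != len(x):
        raise ValueError("expected one aggregated message per node")
    return [update_mlp(m_i + x_i) for m_i, x_i in zip(message, x)]


def toy_update_mlp(z):
    """Hypothetical stand-in mapping HIDDEN_DIM + NODE_FEATURE_DIM values to NODE_FEATURE_DIM."""
    if len(z) != HIDDEN_DIM + NODE_FEATURE_DIM:
        raise ValueError("unexpected input width")
    m, x = z[:HIDDEN_DIM], z[HIDDEN_DIM:]
    total = sum(m)
    return [x[0] + total, x[1] - total]


agg = [[1, 1, 1], [0, 0, 2]]  # aggregated messages (HIDDEN_DIM = 3)
x = [[10, 20], [30, 40]]      # node features (NODE_FEATURE_DIM = 2)
print(update(agg, x, toy_update_mlp))  # [[13, 17], [32, 38]]
```

Under this assumption the output keeps the width of `x`, so it could be fed into another block configured with the same node_feature_dim.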