Deep Tensor Network Block
- class DeepTensorNetworkBlock(node_feature_dim, edge_feature_dim, activation=<class 'torch.nn.modules.activation.Tanh'>, aggr='add', node_dim=-2, flow='source_to_target')
Bases: MessagePassing
Implementation of the Deep Tensor Network block.
This block performs message passing between nodes and edges in a graph neural network, following the scheme proposed by Schütt et al. (2017). It serves as an inner block in a larger graph neural network architecture.
The message between two nodes connected by an edge is computed by applying a linear transformation to the sender node features and the edge features, followed by a non-linear activation function. Messages are then aggregated using an aggregation scheme (e.g., sum, mean, min, max, or product).
The update step is performed by a simple addition of the incoming messages to the node features.
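The message, aggregation, and update steps described above can be sketched in plain PyTorch. This is an illustrative sketch, not the block's actual implementation: the layer names `node_layer` and `edge_layer` are hypothetical, and combining the two linear transformations by an element-wise product is an assumption (the description does not say how they are combined). The "add" aggregation is written out with `index_add_`.

```python
import torch

torch.manual_seed(0)

num_nodes, node_feature_dim, edge_feature_dim = 4, 3, 2
x = torch.randn(num_nodes, node_feature_dim)      # node features
edge_index = torch.tensor([[0, 1, 0],             # sender node of each edge
                           [1, 2, 2]])            # receiver node of each edge
edge_attr = torch.randn(edge_index.shape[1], edge_feature_dim)

# Hypothetical layers: both map into the node feature space so that
# messages can be added to the node features in the update step.
node_layer = torch.nn.Linear(node_feature_dim, node_feature_dim)
edge_layer = torch.nn.Linear(edge_feature_dim, node_feature_dim)
activation = torch.nn.Tanh()

with torch.no_grad():
    senders, receivers = edge_index[0], edge_index[1]

    # Message: linear transforms of sender features and edge features,
    # combined element-wise (assumption), then a non-linear activation.
    messages = activation(node_layer(x[senders]) * edge_layer(edge_attr))

    # Aggregation ("add"): sum the messages arriving at each receiver node.
    aggregated = torch.zeros_like(x).index_add_(0, receivers, messages)

    # Update: add the aggregated messages to the node features.
    x_new = x + aggregated

print(x_new.shape)  # torch.Size([4, 3])
```

In this toy graph, nodes 0 and 3 receive no messages, so their features are left unchanged by the update, while node 2 receives the sum of two messages.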
See also
Original reference: Schütt, K., Arbabzadah, F., Chmiela, S. et al. (2017). Quantum-Chemical Insights from Deep Tensor Neural Networks. Nature Communications 8, 13890. DOI: https://doi.org/10.1038/ncomms13890.
Initialization of the DeepTensorNetworkBlock class.
- Parameters:
node_feature_dim (int) – The dimension of the node features.
edge_feature_dim (int) – The dimension of the edge features.
activation (torch.nn.Module) – The activation function. Default is torch.nn.Tanh.
aggr (str) – The aggregation scheme to use for message passing. Available options are “add”, “mean”, “min”, “max”, “mul”. See torch_geometric.nn.MessagePassing for more details. Default is “add”.
node_dim (int) – The axis along which to propagate. Default is -2.
flow (str) – The direction of message passing. Available options are “source_to_target” and “target_to_source”. The “source_to_target” flow means that messages are sent from the source node to the target node, while the “target_to_source” flow means that messages are sent from the target node to the source node. See torch_geometric.nn.MessagePassing for more details. Default is “source_to_target”.
- Raises:
AssertionError – If node_feature_dim is not a positive integer.
AssertionError – If edge_feature_dim is not a positive integer.
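To make the flow parameter concrete, the following minimal sketch shows which row of an edge index acts as the sender under each option. It follows the torch_geometric convention that, for “source_to_target”, the first row holds the sending nodes and the second row the receiving nodes. The helper `incoming_senders` is hypothetical and uses plain Python lists rather than tensors.

```python
def incoming_senders(edge_index, flow="source_to_target"):
    """Map each receiving node to the list of nodes that send it a message.

    `edge_index` is a pair of equal-length lists, one entry per edge.
    Hypothetical helper, written only to illustrate the `flow` option.
    """
    row0, row1 = edge_index
    if flow == "source_to_target":
        senders, receivers = row0, row1
    elif flow == "target_to_source":
        senders, receivers = row1, row0
    else:
        raise ValueError(f"Unknown flow: {flow!r}")

    incoming = {}
    for sender, receiver in zip(senders, receivers):
        incoming.setdefault(receiver, []).append(sender)
    return incoming


# Three edges: 0 -> 1, 0 -> 2, 1 -> 2
edge_index = [[0, 0, 1],
              [1, 2, 2]]

print(incoming_senders(edge_index, "source_to_target"))  # {1: [0], 2: [0, 1]}
print(incoming_senders(edge_index, "target_to_source"))  # {0: [1, 2], 1: [2]}
```

Switching the flow reverses the direction in which every edge carries its message, so the set of nodes that get updated changes as well.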
- forward(x, edge_index, edge_attr)
Forward pass of the block, triggering the message-passing routine.
- Parameters:
x (torch.Tensor | LabelTensor) – The node features.
edge_index (torch.Tensor) – The edge indices.
edge_attr (torch.Tensor | LabelTensor) – The edge attributes.
- Returns:
The updated node features.
- Return type:
- message(x_j, edge_attr)
Compute the message to be passed between nodes and edges.
- Parameters:
x_j (torch.Tensor | LabelTensor) – The node features of the sender nodes.
edge_attr (torch.Tensor | LabelTensor) – The edge attributes.
- Returns:
The message to be passed.
- Return type:
- update(message, x)
Update the node features with the received messages.
- Parameters:
message (torch.Tensor) – The aggregated messages received by each node.
x (torch.Tensor | LabelTensor) – The node features.
- Returns:
The updated node features.
- Return type: