Generic Message Passing Neural Network (MPNN) framework.
Implements the message passing paradigm from "Neural Message Passing for Quantum Chemistry" (Gilmer et al., 2017). This module provides the building blocks for constructing graph neural networks where information propagates along edges.
Architecture
Node Features [batch, num_nodes, feature_dim]
Adjacency     [batch, num_nodes, num_nodes]
              |
              v
+----------------------------+
|  For each edge (i,j):      |
|  m_ij = msg_fn(h_i, h_j)   |
+----------------------------+
              |
              v
+----------------------------+
|  Aggregate per node:       |
|  M_i = AGG({m_ij : j->i})  |
+----------------------------+
              |
              v
+----------------------------+
|  Update node features:     |
|  h_i' = update(h_i, M_i)   |
+----------------------------+
              |
              v
Updated Node Features [batch, num_nodes, feature_dim']
Graph Representation
Graphs are represented as dense adjacency matrices since Nx does not support sparse tensors. For large graphs, consider batching subgraphs.
- Adjacency: {batch, num_nodes, num_nodes}, binary or weighted
- Node features: {batch, num_nodes, feature_dim}
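To make the dense layout concrete, here is a minimal sketch that turns an edge list into tensors of the shapes listed above. The toy graph, the variable names, and the choice to store undirected edges symmetrically are assumptions for illustration; only standard Nx calls are used, and nothing here comes from this module.

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical toy graph: 4 nodes, undirected edges 0-1, 1-2, 2-3.
edges = [{0, 1}, {1, 2}, {2, 3}]
num_nodes = 4

# List every edge in both directions so the matrix comes out symmetric.
indices =
  edges
  |> Enum.flat_map(fn {i, j} -> [[i, j], [j, i]] end)
  |> Nx.tensor()

# One weight of 1.0 per index pair (binary adjacency).
updates = Nx.broadcast(1.0, {Nx.axis_size(indices, 0)})

# Dense adjacency with a leading batch axis: {1, 4, 4}.
adjacency =
  Nx.broadcast(0.0, {num_nodes, num_nodes})
  |> Nx.indexed_put(indices, updates)
  |> Nx.new_axis(0)

# Placeholder node features with feature_dim = 3: {1, 4, 3}.
nodes = Nx.iota({1, num_nodes, 3}, type: :f32)
```

For a weighted graph, the same pattern applies with the edge weights in `updates` instead of ones.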
Usage
# Generic message passing step
nodes = Axon.input("nodes", shape: {nil, 10, 64})
adj = Axon.input("adjacency", shape: {nil, 10, 10})
updated = MessagePassing.message_passing_layer(nodes, adj, 64,
  name: "mpnn_1", activation: :relu)
# Graph-level output via pooling
graph_repr = MessagePassing.global_pool(updated, :mean)
Functions
@spec aggregate(Axon.t(), Axon.t(), aggregation()) :: Axon.t()
Aggregate messages from neighbors using the specified method.
This is a standalone aggregation function that operates on pre-computed
messages. For most use cases, message_passing_layer/4 handles aggregation
internally.
Parameters
- messages - Message tensor {batch, num_nodes, num_neighbors, msg_dim}
- adjacency - Adjacency matrix {batch, num_nodes, num_nodes}
- mode - Aggregation mode: :sum, :mean, or :max
Returns
Aggregated messages {batch, num_nodes, msg_dim}.
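The three modes can be sketched on plain Nx tensors as masked reductions over the neighbor axis. This is an illustrative sketch, not this module's implementation: `AggregateSketch` is a hypothetical name, and it assumes that num_neighbors equals num_nodes, that a nonzero adjacency[b, i, j] marks a message from j to i, and that nodes without neighbors should receive zeros.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule AggregateSketch do
  # Hypothetical helper for illustration only.
  # messages:  {batch, num_nodes, num_nodes, msg_dim}
  # adjacency: {batch, num_nodes, num_nodes}
  def aggregate(messages, adjacency, mode) do
    # Edge mask broadcast to the message shape (1 on edges, 0 elsewhere).
    mask =
      adjacency
      |> Nx.not_equal(0)
      |> Nx.new_axis(-1)
      |> Nx.broadcast(Nx.shape(messages))

    case mode do
      :sum ->
        messages |> Nx.multiply(mask) |> Nx.sum(axes: [2])

      :mean ->
        total = messages |> Nx.multiply(mask) |> Nx.sum(axes: [2])
        # Clamp the neighbor count to 1 so isolated nodes yield zeros.
        count = mask |> Nx.sum(axes: [2]) |> Nx.max(1)
        Nx.divide(total, count)

      :max ->
        # Replace non-edges with a large negative value before reducing.
        masked = Nx.select(mask, messages, -1.0e9)
        has_neighbor = Nx.any(mask, axes: [2])
        # Isolated nodes fall back to zeros instead of the sentinel.
        Nx.select(has_neighbor, Nx.reduce_max(masked, axes: [2]), 0.0)
    end
  end
end

# Node 0 receives from nodes 0 and 1; node 1 has no incoming edges.
messages = Nx.tensor([[[[1.0], [2.0]], [[3.0], [4.0]]]])
adjacency = Nx.tensor([[[1, 1], [0, 0]]])

summed = AggregateSketch.aggregate(messages, adjacency, :sum)
```

Note that the :sum branch here treats the adjacency as binary; a weighted variant would multiply by the adjacency values rather than the mask.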
@spec global_pool(Axon.t(), aggregation()) :: Axon.t()
Pool node features to a graph-level representation.
Reduces the node dimension by aggregating all node features in each graph into a single vector. Essential for graph classification tasks.
Parameters
- node_features - Axon node with shape {batch, num_nodes, feature_dim}
- mode - Pooling mode: :sum, :mean, or :max (default: :mean)
Returns
Axon node with shape {batch, feature_dim}.
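On concrete tensors, each pooling mode amounts to a reduction over the node axis (axis 1). The sketch below uses plain Nx on made-up values to show the shape change; it is an assumption about what the pooling computes, not a copy of this function's body.

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Two graphs, three nodes each, two features per node: {2, 3, 2}.
node_features =
  Nx.tensor([
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[0.0, 1.0], [0.0, 1.0], [3.0, 1.0]]
  ])

# Each reduction collapses the node axis, leaving {batch, feature_dim} = {2, 2}.
sum_pool = Nx.sum(node_features, axes: [1])
mean_pool = Nx.mean(node_features, axes: [1])
max_pool = Nx.reduce_max(node_features, axes: [1])
```

A plain reduction like this counts every row along the node axis, so if smaller graphs are padded to a fixed num_nodes, the padding rows influence the :mean and :max results.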
@spec message_passing_layer(Axon.t(), Axon.t(), pos_integer(), keyword()) :: Axon.t()
Generic message passing step.
For each edge (i, j) in the graph, computes a message from the features of nodes i and j, aggregates incoming messages per node, and updates node features via a learned transformation.
The message function concatenates sender and receiver features, then applies a dense layer. The update function concatenates the node's current features with aggregated messages, then applies a dense layer.
Parameters
- nodes - Node features Axon node {batch, num_nodes, feature_dim}
- adjacency - Adjacency matrix Axon node {batch, num_nodes, num_nodes}
- output_dim - Output feature dimension per node
- opts - Options
Options
- :name - Layer name prefix (default: "mpnn")
- :activation - Activation function (default: :relu)
- :aggregation - Message aggregation: :sum, :mean, :max (default: :sum)
- :dropout - Dropout rate (default: 0.0)
Returns
Axon node with shape {batch, num_nodes, output_dim}.
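The computation described above (concatenate, dense, aggregate, concatenate, dense) can be traced on raw tensors. The following is a minimal sketch under stated assumptions rather than this layer's actual code: `MpnnSketch`, `w_msg`, and `w_upd` are hypothetical names, biases and dropout are omitted, the activation is fixed to ReLU, aggregation is :sum, and adjacency[b, i, j] is taken to weight the message from j to i.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule MpnnSketch do
  # Hypothetical helper for illustration only.
  # nodes: {batch, n, f}, adjacency: {batch, n, n}
  # w_msg: {2 * f, d}, w_upd: {f + d, d}
  def step(nodes, adjacency, w_msg, w_upd) do
    {batch, n, f} = Nx.shape(nodes)

    # Pairwise features: entry [b, i, j] holds the concatenation [h_i, h_j].
    h_i = nodes |> Nx.new_axis(2) |> Nx.broadcast({batch, n, n, f})
    h_j = nodes |> Nx.new_axis(1) |> Nx.broadcast({batch, n, n, f})
    pairs = Nx.concatenate([h_i, h_j], axis: 3)

    # Message function: dense layer (no bias) followed by ReLU.
    messages = pairs |> Nx.dot(w_msg) |> Nx.max(0)

    # Sum over senders j, masked and weighted by the adjacency.
    aggregated =
      messages
      |> Nx.multiply(Nx.new_axis(adjacency, -1))
      |> Nx.sum(axes: [2])

    # Update function: concatenate [h_i, M_i], dense layer, ReLU.
    [nodes, aggregated]
    |> Nx.concatenate(axis: 2)
    |> Nx.dot(w_upd)
    |> Nx.max(0)
  end
end

# Two connected nodes with one feature each; all weights set to 1.0.
nodes = Nx.tensor([[[1.0], [2.0]]])
adjacency = Nx.tensor([[[0, 1], [1, 0]]])
w_msg = Nx.tensor([[1.0], [1.0]])
w_upd = Nx.tensor([[1.0], [1.0]])

updated = MpnnSketch.step(nodes, adjacency, w_msg, w_upd)
```

Materializing all pairs costs memory proportional to num_nodes squared times the feature size, which is the practical reason dense message passing suits small graphs.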