Edifice.Graph.MessagePassing (Edifice v0.2.0)


Generic Message Passing Neural Network (MPNN) framework.

Implements the message passing paradigm from "Neural Message Passing for Quantum Chemistry" (Gilmer et al., 2017). This module provides the building blocks for constructing graph neural networks where information propagates along edges.

Architecture

Node Features [batch, num_nodes, feature_dim]
Adjacency     [batch, num_nodes, num_nodes]
      |
      v
+----------------------------+
| For each edge (i,j):       |
|   m_ij = msg_fn(h_i, h_j)  |
+----------------------------+
      |
      v
+----------------------------+
| Aggregate per node:        |
|   M_i = AGG({m_ij : j->i}) |
+----------------------------+
      |
      v
+----------------------------+
| Update node features:      |
|   h_i' = update(h_i, M_i)  |
+----------------------------+
      |
      v
Updated Node Features [batch, num_nodes, feature_dim']
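The three stages in the diagram can be illustrated without Nx or Axon. Below is a hypothetical, unbatched, plain-Elixir sketch (the module MPNNSketch and its argument order are not part of Edifice). It assumes adjacency[i][j] != 0 means node j sends a message to node i, and it takes the message, aggregation and update functions as arguments so each stage stays visible. It needs Elixir 1.12+ for Enum.zip_with.

```elixir
defmodule MPNNSketch do
  # Hypothetical, unbatched sketch of one message passing step (not Edifice's API).
  # Assumption: adjacency[i][j] != 0 means node j sends a message to node i.
  # msg_fn.(h_i, h_j) builds m_ij, agg_fn.(messages) builds M_i,
  # update_fn.(h_i, m_i) builds h_i'.
  def step(nodes, adjacency, msg_fn, agg_fn, update_fn) do
    nodes
    |> Enum.with_index()
    |> Enum.map(fn {h_i, i} ->
      # One message per incoming edge j -> i
      messages =
        adjacency
        |> Enum.at(i)
        |> Enum.with_index()
        |> Enum.filter(fn {a_ij, _j} -> a_ij != 0 end)
        |> Enum.map(fn {_a_ij, j} -> msg_fn.(h_i, Enum.at(nodes, j)) end)

      update_fn.(h_i, agg_fn.(messages))
    end)
  end
end

# Path graph 0 - 1 - 2 with symmetric adjacency. Every node has a neighbor,
# which matters because sum_agg below returns [] for an empty message list.
nodes = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
adjacency = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]

msg_fn = fn _h_i, h_j -> h_j end                                  # message = sender's features
sum_agg = fn messages -> Enum.zip_with(messages, &Enum.sum/1) end  # element-wise sum
add_update = fn h_i, m_i -> Enum.zip_with(h_i, m_i, &(&1 + &2)) end

MPNNSketch.step(nodes, adjacency, msg_fn, sum_agg, add_update)
# => [[1.0, 1.0], [2.0, 2.0], [1.0, 2.0]]
```

Swapping sum_agg for a mean or max reduction changes only the aggregation stage, which is the choice the :aggregation option exposes in message_passing_layer/4.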

Graph Representation

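Graphs are represented as dense adjacency matrices because Nx does not support sparse tensors. A dense adjacency tensor stores batch * num_nodes * num_nodes entries, so memory grows quadratically with graph size; for large graphs, consider splitting them into subgraphs and batching those.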

  • Adjacency: {batch, num_nodes, num_nodes} - binary or weighted
  • Node features: {batch, num_nodes, feature_dim}
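Graphs often arrive as edge lists, so the dense adjacency above usually has to be built first. A hypothetical helper for one graph might look like this (AdjacencySketch.dense/3 and its :undirected option are illustrative, not part of Edifice). It assumes row = receiving node and column = sending node, matching the j->i convention in the diagram; the library's own orientation is not stated here. The resulting nested list can be passed to Nx.tensor/1, and per-graph tensors of equal num_nodes can be combined into the batch dimension with Nx.stack/1.

```elixir
defmodule AdjacencySketch do
  # Hypothetical helper (not part of Edifice): edge list -> dense 0/1 matrix.
  # Assumed convention: adj[dst][src] = 1.0, i.e. row = receiver, column = sender.
  # Pass undirected: true to also add the reverse of every edge.
  def dense(edges, num_nodes, opts \\ []) do
    edges =
      if Keyword.get(opts, :undirected, false),
        do: edges ++ Enum.map(edges, fn {src, dst} -> {dst, src} end),
        else: edges

    edge_set = MapSet.new(edges)
    # Explicit step keeps the range empty when num_nodes is 0 (Elixir 1.12+)
    range = 0..(num_nodes - 1)//1

    for dst <- range do
      for src <- range do
        if MapSet.member?(edge_set, {src, dst}), do: 1.0, else: 0.0
      end
    end
  end
end

AdjacencySketch.dense([{0, 1}, {1, 2}], 3)
# => [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]

AdjacencySketch.dense([{0, 1}, {1, 2}], 3, undirected: true)
# => [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
```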

Usage

alias Edifice.Graph.MessagePassing

# Generic message passing step
nodes = Axon.input("nodes", shape: {nil, 10, 64})
adj = Axon.input("adjacency", shape: {nil, 10, 10})

updated = MessagePassing.message_passing_layer(nodes, adj, 64,
  name: "mpnn_1", activation: :relu)

# Graph-level output via pooling
graph_repr = MessagePassing.global_pool(updated, :mean)

Summary

Functions

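  • aggregate(node_features, adjacency, mode) - Aggregate messages from neighbors using the specified method.
  • global_pool(node_features, mode \\ :mean) - Pool node features to a graph-level representation.
  • message_passing_layer(nodes, adjacency, output_dim, opts \\ []) - Generic message passing step.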

Types

aggregation()

@type aggregation() :: :sum | :mean | :max

Functions

aggregate(node_features, adjacency, mode)

@spec aggregate(Axon.t(), Axon.t(), aggregation()) :: Axon.t()

Aggregate messages from neighbors using the specified method.

This is a standalone aggregation function that operates on pre-computed messages. For most use cases, message_passing_layer/4 handles aggregation internally.

Parameters

  • messages - Message tensor {batch, num_nodes, num_neighbors, msg_dim} (the first argument, named node_features in the signature)
  • adjacency - Adjacency matrix {batch, num_nodes, num_nodes}
  • mode - Aggregation mode: :sum, :mean, or :max

Returns

Aggregated messages {batch, num_nodes, msg_dim}.
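The three modes can be made concrete with a dependency-free sketch for a single graph. It is hypothetical (AggregateSketch is not Edifice's code) and rests on assumptions the text above does not confirm: num_neighbors equals num_nodes, so messages[i][j] is the message from node j to node i and is kept only where adjacency[i][j] is nonzero; edge weights are ignored; and a node with no neighbors receives a zero vector.

```elixir
defmodule AggregateSketch do
  # Hypothetical masked aggregation for ONE graph (no batch axis).
  # Assumptions: num_neighbors == num_nodes, messages[i][j] is the message
  # j -> i, it counts only where adjacency[i][j] != 0 (weights ignored),
  # and isolated nodes get a zero vector.
  def aggregate(messages, adjacency, mode) when mode in [:sum, :mean, :max] do
    Enum.zip_with(messages, adjacency, fn node_msgs, adj_row ->
      kept =
        node_msgs
        |> Enum.zip(adj_row)
        |> Enum.filter(fn {_msg, a} -> a != 0 end)
        |> Enum.map(fn {msg, _a} -> msg end)

      reduce(kept, mode, length(hd(node_msgs)))
    end)
  end

  # Each reduction runs column-wise, i.e. per message feature
  defp reduce([], _mode, dim), do: List.duplicate(0.0, dim)
  defp reduce(msgs, :sum, _dim), do: Enum.zip_with(msgs, &Enum.sum/1)
  defp reduce(msgs, :max, _dim), do: Enum.zip_with(msgs, &Enum.max/1)
  defp reduce(msgs, :mean, _dim), do: Enum.zip_with(msgs, fn col -> Enum.sum(col) / length(col) end)
end

messages = [
  [[1.0, 2.0], [3.0, 4.0]],  # messages arriving at node 0 from nodes 0 and 1
  [[5.0, 6.0], [7.0, 8.0]]   # messages arriving at node 1 from nodes 0 and 1
]
adjacency = [[0, 1], [1, 1]]  # node 0 hears only from node 1

AggregateSketch.aggregate(messages, adjacency, :sum)   # => [[3.0, 4.0], [12.0, 14.0]]
AggregateSketch.aggregate(messages, adjacency, :mean)  # => [[3.0, 4.0], [6.0, 7.0]]
AggregateSketch.aggregate(messages, adjacency, :max)   # => [[3.0, 4.0], [7.0, 8.0]]
```

Masking before reducing matters most for :max and :mean: a plain max over all num_nodes slots could pick up a message from a non-neighbor, and a plain mean would divide by num_nodes instead of the node's neighbor count.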

global_pool(node_features, mode \\ :mean)

@spec global_pool(Axon.t(), aggregation()) :: Axon.t()

Pool node features to a graph-level representation.

Reduces the node dimension by aggregating the features of all nodes in each graph into a single vector. This readout step is needed for graph-level tasks such as graph classification.

Parameters

  • node_features - Axon node with shape {batch, num_nodes, feature_dim}
  • mode - Pooling mode: :sum, :mean, or :max (default: :mean)

Returns

Axon node with shape {batch, feature_dim}.
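On plain nested lists, the shape change {batch, num_nodes, feature_dim} -> {batch, feature_dim} amounts to reducing each feature column over the nodes of one graph. The sketch below is hypothetical (GlobalPoolSketch is not Edifice's implementation); it pools every node row, so any padding nodes would be included in the result.

```elixir
defmodule GlobalPoolSketch do
  # Hypothetical graph-level pooling on nested lists:
  # {batch, num_nodes, feature_dim} -> {batch, feature_dim}.
  # Every node row is pooled; there is no padding mask.
  def global_pool(batch, mode \\ :mean) when mode in [:sum, :mean, :max] do
    Enum.map(batch, fn nodes ->
      # Enum.zip_with/2 walks each feature column across all nodes of one graph
      case mode do
        :sum -> Enum.zip_with(nodes, &Enum.sum/1)
        :max -> Enum.zip_with(nodes, &Enum.max/1)
        :mean -> Enum.zip_with(nodes, fn col -> Enum.sum(col) / length(col) end)
      end
    end)
  end
end

batch = [
  [[1.0, 2.0], [3.0, 6.0]],   # graph 0: two nodes, two features
  [[0.0, -1.0], [4.0, 1.0]]   # graph 1
]

GlobalPoolSketch.global_pool(batch)        # => [[2.0, 4.0], [2.0, 0.0]]
GlobalPoolSketch.global_pool(batch, :sum)  # => [[4.0, 8.0], [4.0, 0.0]]
GlobalPoolSketch.global_pool(batch, :max)  # => [[3.0, 6.0], [4.0, 1.0]]
```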

message_passing_layer(nodes, adjacency, output_dim, opts \\ [])

@spec message_passing_layer(Axon.t(), Axon.t(), pos_integer(), keyword()) :: Axon.t()

Generic message passing step.

For each edge (i, j) in the graph, computes a message from the features of nodes i and j, aggregates incoming messages per node, and updates node features via a learned transformation.

The message function concatenates sender and receiver features, then applies a dense layer. The update function concatenates the node's current features with aggregated messages, then applies a dense layer.
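With hand-picked weights, the two functions described above can be sketched as follows. DenseMPNNSketch and its parameters are illustrative, not Edifice code: in the library the weights are learned dense layers. The concatenation order [h_i ; h_j] and applying ReLU after the update layer (suggested by the :activation default) are assumptions; whether the message layer also has an activation is not stated here.

```elixir
defmodule DenseMPNNSketch do
  # Hypothetical sketch with explicit weights; Edifice learns these with Axon.
  # dense/3 computes W x + b, with W given as a list of rows (one per output unit).
  def dense(x, w, b) do
    Enum.zip_with(w, b, fn row, bias ->
      Enum.sum(Enum.zip_with(row, x, &(&1 * &2))) + bias
    end)
  end

  # Message: concatenate receiver and sender features (order assumed), then dense.
  def message(h_i, h_j, {w_msg, b_msg}), do: dense(h_i ++ h_j, w_msg, b_msg)

  # Update: concatenate current features with aggregated messages, dense,
  # then ReLU (placing the activation here is an assumption).
  def update(h_i, m_i, {w_upd, b_upd}) do
    dense(h_i ++ m_i, w_upd, b_upd)
    |> Enum.map(&max(&1, 0.0))
  end
end

# Two nodes with one feature each, connected in both directions
msg_params = {[[0.5, 0.5]], [0.0]}   # m_ij = 0.5 * h_i + 0.5 * h_j
upd_params = {[[1.0, -1.0]], [0.0]}  # pre-activation h_i - M_i

m_01 = DenseMPNNSketch.message([1.0], [2.0], msg_params)  # => [1.5]
m_10 = DenseMPNNSketch.message([2.0], [1.0], msg_params)  # => [1.5]

# Each node has a single neighbor, so :sum aggregation leaves M_i = m_ij
DenseMPNNSketch.update([1.0], m_01, upd_params)  # => [0.0] (ReLU clips -0.5)
DenseMPNNSketch.update([2.0], m_10, upd_params)  # => [0.5]
```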

Parameters

  • nodes - Node features Axon node {batch, num_nodes, feature_dim}
  • adjacency - Adjacency matrix Axon node {batch, num_nodes, num_nodes}
  • output_dim - Output feature dimension per node
  • opts - Options

Options

  • :name - Layer name prefix (default: "mpnn")
  • :activation - Activation function (default: :relu)
  • :aggregation - Message aggregation: :sum, :mean, or :max (default: :sum)
  • :dropout - Dropout rate (default: 0.0)

Returns

Axon node with shape {batch, num_nodes, output_dim}.