Edifice.Graph.GCN (Edifice v0.2.0)


Graph Convolutional Network (Kipf & Welling, 2017).

Implements spectral graph convolutions using a first-order Chebyshev polynomial approximation. Each GCN layer combines every node's features with those of its neighbors through the normalized adjacency matrix, then applies a learned linear transform.
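For reference, the step from the spectral filter to the layer rule is summarized below, following the derivation in Kipf & Welling (2017). In this block A is the adjacency matrix *before* self-loops are added, so the A and D used elsewhere on this page correspond to \tilde{A} and \tilde{D} in the last two lines.

```latex
% First-order Chebyshev approximation (K = 1, \lambda_{max} \approx 2),
% with L = I_N - D^{-1/2} A D^{-1/2} the normalized graph Laplacian:
g_{\theta'} \star x \approx \theta'_0 x + \theta'_1 (L - I_N) x
                    = \theta'_0 x - \theta'_1 D^{-1/2} A D^{-1/2} x

% Tie the two parameters, \theta = \theta'_0 = -\theta'_1:
g_{\theta} \star x \approx \theta \left( I_N + D^{-1/2} A D^{-1/2} \right) x

% Renormalization trick: \tilde{A} = A + I_N, \tilde{D}_{ii} = \sum_j \tilde{A}_{ij}
I_N + D^{-1/2} A D^{-1/2} \;\rightarrow\; \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}

% Multi-channel features H, learnable weights W, nonlinearity \sigma:
H' = \sigma\left( \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H W \right)
```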

Architecture

Node Features [batch, num_nodes, input_dim]
Adjacency     [batch, num_nodes, num_nodes]
      |
      v
+------------------------------------+
| GCN Layer 1:                       |
|   H' = sigma(D^-1/2 A D^-1/2 H W)  |
+------------------------------------+
      |
      v
+------------------------------------+
| GCN Layer 2:                       |
|   H' = sigma(D^-1/2 A D^-1/2 H W)  |
+------------------------------------+
      |
      v
Node Embeddings [batch, num_nodes, last_hidden_dim]

The symmetric normalization D^{-1/2} A D^{-1/2} scales each edge (i, j) by 1 / sqrt(d_i * d_j), where d_i is the degree of node i, so aggregated feature magnitudes do not grow with node degree.
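As a concrete check of the normalization, the sketch below computes D^{-1/2} (A + I) D^{-1/2} for a three-node path graph using plain Elixir lists. The `GCNNorm` module is a hypothetical illustration, not part of Edifice, which builds Axon graphs rather than operating on lists.

```elixir
defmodule GCNNorm do
  # Hypothetical helper (not part of Edifice): symmetric GCN normalization
  # of a dense adjacency matrix stored as a list of rows.

  # A + I: add a self-loop to every node by incrementing the diagonal.
  def add_self_loops(adj) do
    for {row, i} <- Enum.with_index(adj) do
      for {a, j} <- Enum.with_index(row), do: if(i == j, do: a + 1, else: a)
    end
  end

  # D^{-1/2} (A + I) D^{-1/2}: entry (i, j) becomes a_ij / sqrt(d_i * d_j),
  # where d_i is the row sum (degree) of node i after adding self-loops.
  def normalize(adj) do
    a_hat = add_self_loops(adj)
    inv_sqrt_deg = Enum.map(a_hat, fn row -> 1.0 / :math.sqrt(Enum.sum(row)) end)

    for {row, i} <- Enum.with_index(a_hat) do
      for {a, j} <- Enum.with_index(row) do
        a * Enum.at(inv_sqrt_deg, i) * Enum.at(inv_sqrt_deg, j)
      end
    end
  end
end

# Undirected path graph 0 - 1 - 2, no self-loops yet.
adj = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
IO.inspect(GCNNorm.normalize(adj))
# Entry (0, 0) ≈ 0.5, entry (0, 1) ≈ 0.408 (1/sqrt(6)), entry (1, 1) ≈ 0.333.
```

The middle node, whose degree is 3 after adding its self-loop, gets smaller entries (about 1/3 on its diagonal) than the endpoints (about 1/2), which is the degree damping the normalization is meant to provide. Unlike row normalization D^{-1} A, the result stays symmetric.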

Graph Classification

For graph-level tasks, use build_classifier/1 which adds global pooling and a dense classification head on top of the GCN layers.

Usage

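alias Edifice.Graph.GCN

# Node classification: per-node outputs, shape {batch, num_nodes, 7}
node_model = GCN.build(input_dim: 16, hidden_dims: [64, 32], num_classes: 7)

# Graph classification: one output per graph, shape {batch, 2}
graph_model = GCN.build_classifier(
  input_dim: 16,
  hidden_dims: [64, 64],
  num_classes: 2,
  pool: :mean
)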

References

  • "Semi-Supervised Classification with Graph Convolutional Networks" (Kipf & Welling, ICLR 2017)

Summary

Types

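build_opt()
Options for build/1.

Functions

build(opts \\ [])
Build a Graph Convolutional Network.

build_classifier(opts \\ [])
Build a GCN with global pooling and dense classifier for graph classification.

gcn_layer(nodes, adjacency, output_dim, opts \\ [])
Single Graph Convolutional layer.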

Types

build_opt()

@type build_opt() ::
  {:activation, atom()}
  | {:dropout, float()}
  | {:hidden_dims, [pos_integer()]}
  | {:input_dim, pos_integer()}
  | {:num_classes, pos_integer() | nil}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Graph Convolutional Network.

Stacks multiple GCN layers that propagate features along edges. The final output is per-node embeddings suitable for node classification or as input to a graph-level pooling layer.
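Because each layer mixes a node's features with those of its direct neighbors, stacking k layers lets a node's output depend on nodes up to k hops away. The hypothetical `GCNReach` module below (plain Elixir, not part of Edifice) lists those nodes by taking the nonzero pattern of (A + I)^k.

```elixir
defmodule GCNReach do
  # Hypothetical helper (not part of Edifice): for each node, the list of
  # nodes whose input features can influence it after k stacked GCN
  # layers, i.e. the nonzero pattern of (A + I)^k.
  def receptive_field(adj, k) do
    indexed = Enum.with_index(adj)

    # Boolean A + I: an edge or the diagonal.
    a_hat =
      for {row, i} <- indexed do
        for {a, j} <- Enum.with_index(row), do: a != 0 or i == j
      end

    # Boolean identity: with zero layers a node only sees itself.
    identity = for {_, i} <- indexed, do: for({_, j} <- indexed, do: i == j)

    # Multiply by A + I once per layer, then collect reachable indices.
    List.duplicate(a_hat, k)
    |> Enum.reduce(identity, fn step, acc -> bool_matmul(acc, step) end)
    |> Enum.map(fn row -> for {true, j} <- Enum.with_index(row), do: j end)
  end

  # Boolean matrix product: (a * b)_ij = OR over m of (a_im AND b_mj).
  defp bool_matmul(a, b) do
    cols = b |> Enum.zip() |> Enum.map(&Tuple.to_list/1)

    for row <- a do
      for col <- cols, do: Enum.zip(row, col) |> Enum.any?(fn {x, y} -> x and y end)
    end
  end
end

# Undirected path graph 0 - 1 - 2 - 3.
path = [[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]]
IO.inspect(GCNReach.receptive_field(path, 2))
# → [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3], [1, 2, 3]]
```

With two stacked layers, node 0 on this path sees nodes 0, 1 and 2 but not node 3, which is one reason deeper stacks are used on graphs with long-range structure.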

Options

  • :input_dim - Input feature dimension per node (required)
  • :hidden_dims - List of hidden dimensions for each GCN layer (default: [64, 64])
  • :num_classes - If provided, adds a final classification layer (default: nil)
  • :activation - Activation function (default: :relu)
  • :dropout - Dropout rate between layers (default: 0.0)
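The options above determine the per-layer weight shapes. The sketch below does that bookkeeping under two assumptions not stated on this page: one GCN layer per entry of :hidden_dims, and, when :num_classes is set, a final per-node projection from the last hidden dimension to :num_classes. `GCNShapes` is a hypothetical helper.

```elixir
defmodule GCNShapes do
  # Hypothetical helper (not part of Edifice). Assumptions: one GCN layer
  # per entry of :hidden_dims, plus a final per-node projection to
  # :num_classes when that option is set.
  def layer_dims(opts) do
    input_dim = Keyword.fetch!(opts, :input_dim)
    hidden_dims = Keyword.get(opts, :hidden_dims, [64, 64])

    # Output width of every layer, in order.
    out_dims =
      case Keyword.get(opts, :num_classes) do
        nil -> hidden_dims
        num_classes -> hidden_dims ++ [num_classes]
      end

    # Each layer maps the previous width to its own: {in_dim, out_dim}.
    Enum.zip([input_dim | out_dims], out_dims)
  end
end

IO.inspect(GCNShapes.layer_dims(input_dim: 16, hidden_dims: [64, 32], num_classes: 7))
# → [{16, 64}, {64, 32}, {32, 7}]
```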

Returns

An Axon model with two inputs ("nodes" and "adjacency"). Output shape is {batch, num_nodes, last_hidden_dim} or {batch, num_nodes, num_classes} if :num_classes is set.
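The "adjacency" input is a dense num_nodes x num_nodes matrix per graph. If a graph is stored as an edge list, a hypothetical helper like the one below can build that matrix. It assumes undirected edges and leaves the diagonal at zero, on the assumption that self-loops are added inside each layer (the :add_self_loops default of gcn_layer/4).

```elixir
defmodule GCNInputs do
  # Hypothetical helper (not part of Edifice): builds a dense 0/1
  # adjacency matrix (list of rows) from an undirected edge list.
  # The diagonal stays 0, assuming self-loops are added inside the layers.
  def dense_adjacency(edges, num_nodes) do
    # Store both directions so the matrix comes out symmetric.
    edge_set = edges |> Enum.flat_map(fn {i, j} -> [{i, j}, {j, i}] end) |> MapSet.new()

    for i <- 0..(num_nodes - 1) do
      for j <- 0..(num_nodes - 1), do: if(MapSet.member?(edge_set, {i, j}), do: 1, else: 0)
    end
  end
end

IO.inspect(GCNInputs.dense_adjacency([{0, 1}, {1, 2}], 3))
# → [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
```

To feed the model, per-graph matrices of the same size can be stacked into a list and converted with Nx.tensor/1, which supplies the leading batch axis.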

build_classifier(opts \\ [])

@spec build_classifier(keyword()) :: Axon.t()

Build a GCN with global pooling and dense classifier for graph classification.

Adds a global pooling layer after the GCN layers to produce a single vector per graph, followed by a dense classification head.

Options

All options from build/1, plus:

  • :pool - Global pooling mode: :sum, :mean, :max (default: :mean)
  • :classifier_dims - Hidden dims for classification MLP (default: [64])
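The three :pool modes reduce the node axis of the per-node embeddings to one vector per graph. A list-based sketch for a single unbatched graph follows; `GCNPool` is a hypothetical helper, and it ignores any padding mask a batched implementation would need.

```elixir
defmodule GCNPool do
  # Hypothetical helper (not part of Edifice): reduces per-node embeddings
  # for one graph (a list of num_nodes rows, each hidden_dim wide) to a
  # single hidden_dim vector, mirroring the :sum, :mean and :max modes.
  def pool(node_embeddings, mode) when mode in [:sum, :mean, :max] do
    # Transpose so each list holds one feature across all nodes.
    features = node_embeddings |> Enum.zip() |> Enum.map(&Tuple.to_list/1)

    Enum.map(features, fn values ->
      case mode do
        :sum -> Enum.sum(values)
        :mean -> Enum.sum(values) / length(values)
        :max -> Enum.max(values)
      end
    end)
  end
end

h = [[1.0, 4.0], [3.0, 0.0], [2.0, 2.0]]
IO.inspect(GCNPool.pool(h, :sum))   # → [6.0, 6.0]
IO.inspect(GCNPool.pool(h, :mean))  # → [2.0, 2.0]
IO.inspect(GCNPool.pool(h, :max))   # → [3.0, 4.0]
```

Mean pooling keeps the output scale independent of graph size, sum pooling grows with the number of nodes, and max pooling keeps only the strongest activation per feature.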

Returns

An Axon model outputting {batch, num_classes}.

gcn_layer(nodes, adjacency, output_dim, opts \\ [])

@spec gcn_layer(Axon.t(), Axon.t(), pos_integer(), keyword()) :: Axon.t()

Single Graph Convolutional layer.

Implements the spectral convolution rule:

H' = sigma(D^{-1/2} A D^{-1/2} H W)

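where A is the adjacency matrix (with self-loops added when :add_self_loops is true, the default), D is the diagonal degree matrix of A (D_ii = sum_j A_ij), H is the node feature matrix, W is a learnable weight matrix, and sigma is the activation function set by :activation.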
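A dependency-free sketch of this rule on a single graph is shown below. ReLU stands in for the default :activation, and no bias term is added because this page does not say whether the Edifice layer uses one. `GCNLayerSketch` is hypothetical; matrices are lists of rows.

```elixir
defmodule GCNLayerSketch do
  # Hypothetical, dependency-free sketch of one GCN layer on one graph:
  #   H' = relu(D^{-1/2} (A + I) D^{-1/2} H W)
  # No bias term; matrices are lists of rows.

  def forward(h, adj, w) do
    adj |> normalize() |> matmul(h) |> matmul(w) |> relu()
  end

  # Add self-loops, then scale entry (i, j) by 1 / sqrt(d_i * d_j).
  defp normalize(adj) do
    a_hat =
      for {row, i} <- Enum.with_index(adj) do
        for {a, j} <- Enum.with_index(row), do: if(i == j, do: a + 1, else: a)
      end

    inv = Enum.map(a_hat, fn row -> 1.0 / :math.sqrt(Enum.sum(row)) end)

    for {row, i} <- Enum.with_index(a_hat) do
      for {a, j} <- Enum.with_index(row), do: a * Enum.at(inv, i) * Enum.at(inv, j)
    end
  end

  # Plain matrix product via the transpose of b.
  defp matmul(a, b) do
    cols = b |> Enum.zip() |> Enum.map(&Tuple.to_list/1)
    for row <- a, do: for(col <- cols, do: dot(row, col))
  end

  defp dot(x, y), do: Enum.zip(x, y) |> Enum.map(fn {p, q} -> p * q end) |> Enum.sum()

  # Elementwise ReLU.
  defp relu(m), do: Enum.map(m, fn row -> Enum.map(row, &max(&1, 0.0)) end)
end

# Two connected nodes, 1 input feature, 2 output features.
h = [[1.0], [3.0]]
adj = [[0, 1], [1, 0]]
w = [[1.0, -1.0]]
IO.inspect(GCNLayerSketch.forward(h, adj, w))
# ≈ [[2.0, 0.0], [2.0, 0.0]]
```

Both nodes have degree 2 after self-loops, so every normalized entry is about 0.5 and each node receives the average of the two feature values (2.0) before W and the ReLU are applied.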

Parameters

  • nodes - Node features Axon node {batch, num_nodes, in_dim}
  • adjacency - Adjacency matrix Axon node {batch, num_nodes, num_nodes}
  • output_dim - Output feature dimension

Options

  • :name - Layer name prefix (default: "gcn")
  • :activation - Activation function (default: :relu)
  • :dropout - Dropout rate (default: 0.0)
  • :add_self_loops - Add self-loops to adjacency (default: true)
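To see what :add_self_loops changes, the hypothetical sketch below normalizes the same adjacency with and without self-loops. Without them, a node's own features drop out of its update, and an isolated node has degree 0. The sketch gives such nodes a zero scale instead of dividing by zero; that choice is an assumption, not documented Edifice behavior.

```elixir
defmodule GCNSelfLoops do
  # Hypothetical sketch (not part of Edifice): symmetric normalization
  # with an optional self-loop step, mirroring :add_self_loops.
  # Assumption: zero-degree nodes get scale 0.0 instead of 1/sqrt(0).
  def normalize(adj, add_self_loops) do
    a =
      if add_self_loops do
        for {row, i} <- Enum.with_index(adj) do
          for {x, j} <- Enum.with_index(row), do: if(i == j, do: x + 1, else: x)
        end
      else
        adj
      end

    # Per-node scale 1/sqrt(degree), guarded for isolated nodes.
    inv =
      Enum.map(a, fn row ->
        d = Enum.sum(row)
        if d > 0, do: 1.0 / :math.sqrt(d), else: 0.0
      end)

    for {row, i} <- Enum.with_index(a) do
      for {x, j} <- Enum.with_index(row), do: x * Enum.at(inv, i) * Enum.at(inv, j)
    end
  end
end

# Nodes 0 and 1 are connected; node 2 is isolated.
adj = [[0, 1, 0], [1, 0, 0], [0, 0, 0]]
IO.inspect(GCNSelfLoops.normalize(adj, true))
IO.inspect(GCNSelfLoops.normalize(adj, false))
# → [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```

With self-loops, node 2's diagonal entry is 1.0, so it keeps its own features; without them its row is all zeros and the propagation step contributes nothing to its output.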

Returns

Axon node with shape {batch, num_nodes, output_dim}.