# `Edifice.Graph.MessagePassing`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/graph/message_passing.ex#L1)

Generic Message Passing Neural Network (MPNN) framework.

Implements the message passing paradigm from "Neural Message Passing for
Quantum Chemistry" (Gilmer et al., 2017). This module provides the building
blocks for constructing graph neural networks where information propagates
along edges.

## Architecture

```
Node Features [batch, num_nodes, feature_dim]
Adjacency     [batch, num_nodes, num_nodes]
      |
      v
+----------------------------+
| For each edge (i,j):       |
|   m_ij = msg_fn(h_i, h_j)  |
+----------------------------+
      |
      v
+----------------------------+
| Aggregate per node:        |
|   M_i = AGG({m_ij : j->i}) |
+----------------------------+
      |
      v
+----------------------------+
| Update node features:      |
|   h_i' = update(h_i, M_i)  |
+----------------------------+
      |
      v
Updated Node Features [batch, num_nodes, feature_dim']
```

## Graph Representation

Graphs are represented as dense adjacency matrices since Nx does not support
sparse tensors. For large graphs, consider batching subgraphs.

- Adjacency: `{batch, num_nodes, num_nodes}` - binary or weighted
- Node features: `{batch, num_nodes, feature_dim}`
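For concreteness, a batch containing a single undirected 3-node path graph (edges 0-1 and 1-2) could be encoded as follows (a sketch; the feature values are placeholders):

```elixir
# Single-graph batch: undirected path 0 - 1 - 2.
# Symmetric binary adjacency, then a leading batch axis.
adjacency =
  Nx.tensor([
    [0, 1, 0],
    [1, 0, 1],
    [0, 1, 0]
  ])
  |> Nx.new_axis(0)                     # {1, 3, 3}

# Placeholder node features: {batch, num_nodes, feature_dim}
nodes = Nx.broadcast(0.0, {1, 3, 64})
```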

## Usage

    # Generic message passing step
    nodes = Axon.input("nodes", shape: {nil, 10, 64})
    adj = Axon.input("adjacency", shape: {nil, 10, 10})

    updated = MessagePassing.message_passing_layer(nodes, adj, 64,
      name: "mpnn_1", activation: :relu)

    # Graph-level output via pooling
    graph_repr = MessagePassing.global_pool(updated, :mean)

# `aggregation`

```elixir
@type aggregation() :: :sum | :mean | :max
```

# `aggregate`

```elixir
@spec aggregate(Axon.t(), Axon.t(), aggregation()) :: Axon.t()
```

Aggregate messages from neighbors using the specified method.

This is a standalone aggregation function that operates on pre-computed
messages. For most use cases, `message_passing_layer/4` handles aggregation
internally.

## Parameters

- `messages` - Message tensor `{batch, num_nodes, num_neighbors, msg_dim}`
- `adjacency` - Adjacency matrix `{batch, num_nodes, num_nodes}`
- `mode` - Aggregation mode: `:sum`, `:mean`, or `:max`

## Returns

Aggregated messages `{batch, num_nodes, msg_dim}`.
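For intuition, the three modes can be sketched in plain Nx as below. This is a hypothetical illustration of the masking arithmetic, not the module's actual implementation:

```elixir
# messages: {batch, num_nodes, num_neighbors, msg_dim}
# adj:      {batch, num_nodes, num_nodes}, adj[b][i][j] = 1 when j sends to i
mask = Nx.new_axis(adj, -1)             # {batch, nodes, neighbors, 1}

# :sum zeroes out messages from non-neighbors, then reduces
sum_agg = messages |> Nx.multiply(mask) |> Nx.sum(axes: [2])

# :mean divides by in-degree, clamped to 1 to avoid division by zero
degree = adj |> Nx.sum(axes: [2]) |> Nx.max(1) |> Nx.new_axis(-1)
mean_agg = Nx.divide(sum_agg, degree)

# :max replaces absent neighbors with a large negative value before reducing
max_agg =
  Nx.select(Nx.greater(mask, 0), messages, -1.0e9)
  |> Nx.reduce_max(axes: [2])
```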

# `global_pool`

```elixir
@spec global_pool(Axon.t(), aggregation()) :: Axon.t()
```

Pool node features to a graph-level representation.

Reduces the node dimension by aggregating all node features in each graph
into a single vector. Essential for graph classification tasks.

## Parameters

- `node_features` - Axon node with shape `{batch, num_nodes, feature_dim}`
- `mode` - Pooling mode: `:sum`, `:mean`, or `:max` (default: `:mean`)

## Returns

Axon node with shape `{batch, feature_dim}`.
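If you need the same reduction outside this module, it can be expressed directly with `Axon.nx` (a sketch; `updated` stands for any Axon node with shape `{batch, num_nodes, feature_dim}`):

```elixir
# Mean-pool the node axis; equivalent in spirit to global_pool(updated, :mean)
pooled = Axon.nx(updated, fn t -> Nx.mean(t, axes: [1]) end, name: "global_mean_pool")
```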

# `message_passing_layer`

```elixir
@spec message_passing_layer(Axon.t(), Axon.t(), pos_integer(), keyword()) :: Axon.t()
```

Generic message passing step.

For each edge (i, j) in the graph, computes a message from the features of
nodes i and j, aggregates incoming messages per node, and updates node
features via a learned transformation.

The message function concatenates sender and receiver features, then applies
a dense layer. The update function concatenates the node's current features
with aggregated messages, then applies a dense layer.

## Parameters

- `nodes` - Node features Axon node `{batch, num_nodes, feature_dim}`
- `adjacency` - Adjacency matrix Axon node `{batch, num_nodes, num_nodes}`
- `output_dim` - Output feature dimension per node
- `opts` - Options

## Options

- `:name` - Layer name prefix (default: `"mpnn"`)
- `:activation` - Activation function (default: `:relu`)
- `:aggregation` - Message aggregation: `:sum`, `:mean`, or `:max` (default: `:sum`)
- `:dropout` - Dropout rate (default: `0.0`)

## Returns

Axon node with shape `{batch, num_nodes, output_dim}`.
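Putting the pieces together, a two-layer MPNN for binary graph classification might look like this. A sketch using only the options documented above; the layer sizes and graph size are illustrative:

```elixir
nodes = Axon.input("nodes", shape: {nil, 10, 64})
adj = Axon.input("adjacency", shape: {nil, 10, 10})

model =
  nodes
  # Two rounds of message passing widen the receptive field to 2 hops
  |> MessagePassing.message_passing_layer(adj, 128,
    name: "mpnn_1", aggregation: :mean)
  |> MessagePassing.message_passing_layer(adj, 128,
    name: "mpnn_2", dropout: 0.1)
  # Collapse nodes to one vector per graph, then classify
  |> MessagePassing.global_pool(:mean)
  |> Axon.dense(2, activation: :softmax, name: "classifier")
```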

---

*Consult [api-reference.md](api-reference.md) for the complete listing.*
