# `Edifice.Graph.EGNN`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/graph/egnn.ex#L1)

E(n) Equivariant Graph Neural Network.

<!-- verified: true, date: 2026-02-23 -->

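EGNN processes graphs whose nodes carry 3D (or, more generally, n-D) coordinates
while respecting Euclidean symmetries: rotation, translation, and reflection.
Node features are invariant and coordinates are equivariant under these
transformations. This makes the model well suited to molecular simulations,
protein structure prediction, and physical systems whose laws do not change when
the system is rotated, translated, or reflected.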

## Architecture

```
Node Features [batch, num_nodes, node_dim]
Coordinates   [batch, num_nodes, coord_dim]  (typically 3D positions)
Edge Index    [batch, num_edges, 2]          (source, target pairs)
Edge Features [batch, num_edges, edge_dim]   (optional)
      |
      v
+--------------------------------------+
| EGNN Layer 1:                        |
|   1. Compute squared distances       |
|   2. Edge message: φ_e(h_i, h_j, d²) |
|   3. Coordinate update: equivariant  |
|   4. Feature update: invariant       |
+--------------------------------------+
      |  (repeat num_layers times)
      v
Updated Node Features [batch, num_nodes, out_features]
Updated Coordinates   [batch, num_nodes, coord_dim]
```

## Key Equations

For each layer, given node features h and coordinates x:

1. **Edge message**: `m_ij = φ_e(h_i, h_j, ||x_i - x_j||², a_ij)`
   - Uses the squared distance (an invariant scalar), not raw positions
   - a_ij are the optional edge features
   - φ_e is a small MLP

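2. **Coordinate update**: `x_i' = x_i + Σ_j (x_i - x_j) · φ_x(m_ij)`
   - The relative vector (x_i - x_j) rotates and reflects with the input and is unchanged by translation, so adding it to x_i makes the update translation equivariant
   - φ_x(m_ij) is a scalar computed from the invariant message, so scaling by it preserves equivariance
   - Satorras et al. additionally multiply the sum by a constant C = 1/(M - 1), where M is the number of nodes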

3. **Feature update**: `h_i' = φ_h(h_i, Σ_j m_ij)`
   - The aggregated message is a sum of invariant messages, so the updated features stay invariant
   - φ_h is a small MLP
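The three equations can be sketched for a single graph on plain Elixir lists. This is a hypothetical illustration, not Edifice's implementation: `EGNNSketch.layer/6` is an assumed name, the `phi_e`, `phi_x`, and `phi_h` closures stand in for the learned MLPs, edge features `a_ij` are omitted, edges are `{i, j}` tuples in which node `i` receives the message, and no normalization constant is applied to the coordinate sum.

```elixir
defmodule EGNNSketch do
  # Hypothetical single-graph EGNN layer on plain lists (not Edifice's code).
  # Edges are {i, j} tuples; node i receives the message m_ij.

  defp add(u, v), do: Enum.zip_with(u, v, &+/2)
  defp sub(u, v), do: Enum.zip_with(u, v, &-/2)
  defp scale(u, s), do: Enum.map(u, &(&1 * s))
  defp sq_norm(u), do: u |> Enum.map(&(&1 * &1)) |> Enum.sum()

  # phi_e.(h_i, h_j, d2) -> message list, phi_x.(m_ij) -> scalar,
  # phi_h.(h_i, m_i) -> new feature list; all stand in for learned MLPs.
  # Returns {updated_features, updated_coords}.
  def layer(h, x, edges, phi_e, phi_x, phi_h) do
    # Edge messages (eq. 1) use only invariant inputs: features and d^2.
    per_edge =
      for {i, j} <- edges do
        diff = sub(Enum.at(x, i), Enum.at(x, j))
        m = phi_e.(Enum.at(h, i), Enum.at(h, j), sq_norm(diff))
        # Coordinate step (x_i - x_j) * phi_x(m_ij) for this edge.
        {i, m, scale(diff, phi_x.(m))}
      end

    # Sum messages and coordinate steps over the edges into each node i.
    sums =
      Enum.reduce(per_edge, %{}, fn {i, m, dx}, acc ->
        Map.update(acc, i, {m, dx}, fn {sm, sdx} -> {add(sm, m), add(sdx, dx)} end)
      end)

    # Zero message for nodes that receive no edges.
    zero_m =
      case per_edge do
        [{_, m, _} | _] -> List.duplicate(0.0, length(m))
        [] -> []
      end

    h
    |> Enum.zip(x)
    |> Enum.with_index()
    |> Enum.map(fn {{hi, xi}, i} ->
      {m_i, dx_i} = Map.get(sums, i, {zero_m, List.duplicate(0.0, length(xi))})
      # Feature update (eq. 3) and coordinate update (eq. 2).
      {phi_h.(hi, m_i), add(xi, dx_i)}
    end)
    |> Enum.unzip()
  end
end
```

Because `phi_x` only sees the invariant message, rotating or translating `x` before calling `layer/6` yields the same features and correspondingly rotated or translated coordinates, which is the property the equations are designed to guarantee.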

## Usage

    alias Edifice.Graph.EGNN

    model = EGNN.build(
      in_node_features: 16,
      in_edge_features: 4,
      hidden_dim: 64,
      num_layers: 4,
      out_features: 32
    )

## References

- Satorras et al., "E(n) Equivariant Graph Neural Networks" (ICML 2021)
- https://arxiv.org/abs/2102.09844

# `build_opt`

```elixir
@type build_opt() ::
  {:coord_dim, pos_integer()}
  | {:hidden_dim, pos_integer()}
  | {:in_edge_features, non_neg_integer()}
  | {:in_node_features, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:out_features, pos_integer()}
  | {:update_coords, boolean()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build an EGNN model.

## Options

  - `:in_node_features` - Input node feature dimension (required)
  - `:in_edge_features` - Input edge feature dimension (default: 0)
  - `:hidden_dim` - Hidden dimension (default: 64)
  - `:num_layers` - Number of EGNN layers (default: 4)
  - `:out_features` - Output feature dimension (default: the value of `:hidden_dim`)
  - `:coord_dim` - Coordinate dimension, e.g., 3 for 3D (default: 3)
  - `:update_coords` - Whether to update coordinates (default: true)
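A sketch of how the documented defaults combine, assuming they are resolved with plain `Keyword` lookups. `EGNNSketch.Opts.resolve/1` is hypothetical; the real `build/1` may validate or normalize options differently.

```elixir
defmodule EGNNSketch.Opts do
  # Hypothetical mirror of the documented defaults, not Edifice's code.
  def resolve(opts) do
    # Required option: raises KeyError when missing.
    in_node = Keyword.fetch!(opts, :in_node_features)
    hidden = Keyword.get(opts, :hidden_dim, 64)

    %{
      in_node_features: in_node,
      in_edge_features: Keyword.get(opts, :in_edge_features, 0),
      hidden_dim: hidden,
      num_layers: Keyword.get(opts, :num_layers, 4),
      # Falls back to the resolved hidden dimension.
      out_features: Keyword.get(opts, :out_features, hidden),
      coord_dim: Keyword.get(opts, :coord_dim, 3),
      update_coords: Keyword.get(opts, :update_coords, true)
    }
  end
end
```

Using `Keyword.fetch!/2` for `:in_node_features` is one way to enforce that it is required; `:out_features` is read after `:hidden_dim` so that its default tracks a custom hidden size.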

## Returns

  An Axon model with inputs:
  - "nodes": Node features [batch, num_nodes, in_node_features]
  - "coords": Node coordinates [batch, num_nodes, coord_dim]
  - "edge_index": Edge indices [batch, num_edges, 2]
  - "edge_features": Edge features [batch, num_edges, in_edge_features] (optional)

  The model's output is a container holding `node_features` and `coords`.
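The `"edge_index"` input lists directed (source, target) pairs. For small systems a common choice is a fully connected graph. The hypothetical helpers below build one as nested lists shaped `[batch, num_edges, 2]`; converting the result to an integer tensor (for example with `Nx.tensor/2`) before feeding it to the model is an assumption, not something this page specifies.

```elixir
defmodule EGNNSketch.Edges do
  # Hypothetical helpers, not part of Edifice.

  # Every ordered pair [i, j] with i != j: num_nodes * (num_nodes - 1) edges.
  # The `//1` step keeps the range empty when num_nodes is 0.
  def fully_connected(num_nodes) do
    for i <- 0..(num_nodes - 1)//1, j <- 0..(num_nodes - 1)//1, i != j, do: [i, j]
  end

  # The same edge list for each graph in the batch: [batch, num_edges, 2].
  def batched(num_nodes, batch_size) do
    List.duplicate(fully_connected(num_nodes), batch_size)
  end
end
```

This sketch assumes every graph in the batch has the same number of nodes; graphs of different sizes would need padding or masking, which it does not handle.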

# `egnn_layer`

```elixir
@spec egnn_layer(Axon.t(), Axon.t(), Axon.t(), Axon.t() | nil, keyword()) ::
  {Axon.t(), Axon.t()}
```

Single E(n)-equivariant graph neural network layer.

## Parameters

  - `node_feats` - Node features [batch, num_nodes, hidden_dim]
  - `coords` - Node coordinates [batch, num_nodes, coord_dim]
  - `edge_index` - Edge indices [batch, num_edges, 2]
  - `edge_features` - Optional edge features [batch, num_edges, edge_dim]

## Options

  - `:hidden_dim` - Hidden dimension
  - `:in_edge_features` - Edge feature dimension
  - `:update_coords` - Whether to update coordinates
  - `:name` - Layer name prefix

## Returns

  Tuple of {updated_node_feats, updated_coords}.

# `output_size`

```elixir
@spec output_size(keyword()) :: pos_integer()
```

Returns the output node-feature dimension of an EGNN model built with the given options.

