# `Edifice.Recurrent.TTTE2E`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/recurrent/ttt_e2e.ex#L1)

TTT-E2E: End-to-End Test-Time Training for Long Context.

Implements the TTT-E2E architecture from "End-to-End Test-Time Training
for Long Context" (Stanford, NVIDIA, UC Berkeley, Dec 2025). Unlike the
original TTT layers (which replace attention with self-supervised inner
model updates), TTT-E2E keeps a standard transformer backbone and mutates
~25% of its MLP layers at inference time using end-to-end gradient descent.

## Key Differences from TTT-Linear/TTT-MLP

| Aspect | TTT-Linear/MLP | TTT-E2E |
|--------|---------------|---------|
| Where TTT happens | Custom layer replacing attention | Updates existing MLPs in the last 1/4 of blocks |
| Inner loss | Layer-wise reconstruction | End-to-end next-token prediction |
| Architecture | Custom TTT layer | Standard transformer + dual MLP |
| Training | Standard pretraining | Meta-learning (bilevel optimization) |

## Architecture: Dual-MLP Blocks

In the last 1/4 of transformer blocks, each MLP sublayer is split into:

- **Dynamic MLP**: Updated via SGD at inference (stores document context)
- **Static MLP**: Frozen at inference (preserves pretrained knowledge)

Both MLPs receive the same input; their outputs are summed. This prevents
catastrophic forgetting while allowing the model to adapt to new context.

```
Input [batch, seq_len, embed_dim]
      |
      v
+------------------------------------------------+
|  Frozen Blocks 1..N*3/4                        |
|    LayerNorm -> SlidingWindowAttn -> Residual  |
|    LayerNorm -> MLP -> Residual                |
+------------------------------------------------+
      |
      v
+------------------------------------------------+
|  Mutable Blocks N*3/4+1..N                     |
|    LayerNorm -> SlidingWindowAttn -> Residual  |
|    LayerNorm -> (DynamicMLP + StaticMLP)       |
|    -> Residual                                 |
+------------------------------------------------+
      |
      v
[Layer Norm] -> [Last Timestep]
      |
      v
Output [batch, hidden_size]
```
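
A minimal Axon sketch of the dual-MLP combination follows; the module name `DualMLPSketch`, the helper `dual_mlp/3`, and the layer names are illustrative, not this module's actual internals:

```elixir
# Hedged sketch of a dual-MLP sublayer in Axon; names are illustrative.
defmodule DualMLPSketch do
  def dual_mlp(x, hidden_size, mlp_ratio) do
    mlp = fn input, name ->
      input
      |> Axon.dense(hidden_size * mlp_ratio, name: "#{name}_up")
      |> Axon.gelu()
      |> Axon.dense(hidden_size, name: "#{name}_down")
    end

    # Same input feeds both branches; their outputs are summed. The dynamic
    # branch is updated by SGD at inference, the static branch stays frozen.
    Axon.add(mlp.(x, "dynamic_mlp"), mlp.(x, "static_mlp"))
  end
end
```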

## Inference Protocol

1. Reset the dynamic MLP weights to their pretrained values W0 at the start of each document
2. Process tokens in mini-batches of size b (default: 1024)
3. After each mini-batch: compute the next-token prediction loss, backpropagate to the
   dynamic MLP parameters only, and apply an SGD step
4. The dynamic MLPs accumulate document context as processing proceeds
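
A hedged sketch of this loop, assuming the built model is wrapped with a vocabulary projection so `predict_fn` (from `Axon.build/1`) returns next-token logits, and that parameters are a plain map keyed by layer name. The module `TTTLoop` and helper `update_mutable/4` are illustrative, not part of this library:

```elixir
# Hedged sketch of the per-document TTT loop; names are illustrative.
defmodule TTTLoop do
  def run_document(predict_fn, pretrained_params, minibatches, mutable_prefixes, lr \\ 1.0e-2) do
    # Step 1: every document starts from the pretrained weights W0.
    Enum.reduce(minibatches, pretrained_params, fn {inputs, targets}, params ->
      # Step 3: end-to-end next-token loss, differentiated w.r.t. all params...
      grads =
        Nx.Defn.grad(params, fn p ->
          logits = predict_fn.(p, inputs)

          Axon.Losses.categorical_cross_entropy(targets, logits,
            reduction: :mean,
            sparse: true,
            from_logits: true
          )
        end)

      # ...but the SGD step touches only the dynamic MLP parameters.
      update_mutable(params, grads, mutable_prefixes, lr)
    end)
  end

  defp update_mutable(params, grads, prefixes, lr) do
    Map.new(params, fn {layer, layer_params} ->
      if Enum.any?(prefixes, &String.starts_with?(layer, &1)) do
        {layer,
         Map.new(layer_params, fn {name, w} ->
           {name, Nx.subtract(w, Nx.multiply(lr, grads[layer][name]))}
         end)}
      else
        {layer, layer_params}
      end
    end)
  end
end
```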

## Usage

    model = TTTE2E.build(
      embed_dim: 256,
      hidden_size: 256,
      num_layers: 12,       # Last 3 blocks will have dual MLPs
      num_heads: 4,
      window_size: 60
    )

## References
- Paper: https://arxiv.org/abs/2512.23675
- Code: https://github.com/test-time-training/e2e

# `build_opt`

```elixir
@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:head_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:mlp_ratio, pos_integer()}
  | {:mutable_fraction, float()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:window_size, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a TTT-E2E model.

## Options

**Architecture:**
  - `:embed_dim` - Input embedding dimension (required)
  - `:hidden_size` - Internal hidden dimension (default: 256)
  - `:num_layers` - Total number of transformer blocks (default: 12)
  - `:num_heads` - Number of attention heads (default: 4)
  - `:head_dim` - Dimension per attention head (default: 64)
  - `:mlp_ratio` - MLP expansion ratio (default: 4)

**TTT-specific:**
  - `:mutable_fraction` - Fraction of blocks with dual MLPs (default: 0.25).
    Mutable blocks are placed at the end of the stack.

**General:**
  - `:dropout` - Dropout rate (default: 0.1)
  - `:window_size` - Sliding window attention size (default: 60)
  - `:seq_len` - Fixed sequence length for JIT (default: window_size)

## Returns
  An Axon model that outputs [batch, hidden_size] from the last position.
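
A quick shape check, assuming an Axon version where `init_fn` accepts an empty map for the initial state and a bare tensor suffices for a single-input model (adjust for your Axon version):

```elixir
model =
  TTTE2E.build(
    embed_dim: 256,
    hidden_size: 256,
    num_layers: 12,
    num_heads: 4,
    window_size: 60
  )

{init_fn, predict_fn} = Axon.build(model)

# Input is [batch, seq_len, embed_dim]; seq_len defaults to window_size.
params = init_fn.(Nx.template({1, 60, 256}, :f32), %{})
output = predict_fn.(params, Nx.broadcast(0.0, {1, 60, 256}))
# => shape {1, 256}, i.e. [batch, hidden_size] from the last position
```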

# `layer_pattern`

```elixir
@spec layer_pattern(keyword()) :: [atom()]
```

Get the layer pattern showing which blocks are mutable.

## Example

    iex> TTTE2E.layer_pattern(num_layers: 8, mutable_fraction: 0.25)
    [:frozen, :frozen, :frozen, :frozen, :frozen, :frozen, :mutable, :mutable]

# `mutable_param_prefixes`

```elixir
@spec mutable_param_prefixes(keyword()) :: [String.t()]
```

Get the names of mutable (dynamic MLP) parameters for a built model.

These are the parameters that should be updated via SGD at inference time.
Use this to partition parameters into frozen and mutable sets.

## Options
  - `:num_layers` - Total layers (default: 12)
  - `:mutable_fraction` - Fraction of mutable blocks (default: 0.25)

## Returns
  List of parameter name prefixes for dynamic MLP layers.
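
For example, a hedged sketch of the partition, assuming `params` is the layer-name-keyed map returned by `Axon.build/1` initialization (exact key shapes may vary by Axon version):

```elixir
# Hedged sketch: split parameters into mutable (dynamic MLP) and frozen sets.
prefixes =
  TTTE2E.mutable_param_prefixes(num_layers: 12, mutable_fraction: 0.25)

{mutable, frozen} =
  Map.split_with(params, fn {layer_name, _layer_params} ->
    Enum.any?(prefixes, &String.starts_with?(layer_name, &1))
  end)
```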

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of a TTT-E2E model.

---

*Consult [api-reference.md](api-reference.md) for complete listing*
