Edifice.Recurrent.TTTE2E (Edifice v0.2.0)


TTT-E2E: End-to-End Test-Time Training for Long Context.

Implements the TTT-E2E architecture from "End-to-End Test-Time Training for Long Context" (Stanford, NVIDIA, UC Berkeley, Dec 2025). Unlike the original TTT layers (which replace attention with self-supervised inner model updates), TTT-E2E keeps a standard transformer backbone and mutates ~25% of its MLP layers at inference time using end-to-end gradient descent.

Key Differences from TTT-Linear/TTT-MLP

| Aspect           | TTT-Linear/MLP                   | TTT-E2E                                  |
|------------------|----------------------------------|------------------------------------------|
| Where TTT happens| Custom layer replacing attention | Updates existing MLP in last 1/4 blocks  |
| Inner loss       | Layer-wise reconstruction        | End-to-end next-token prediction         |
| Architecture     | Custom TTT layer                 | Standard transformer + dual MLP          |
| Training         | Standard pretraining             | Meta-learning (bilevel optimization)     |

Architecture: Dual-MLP Blocks

In the last 1/4 of transformer blocks, each MLP sublayer is split into:

  • Dynamic MLP: Updated via SGD at inference (stores document context)
  • Static MLP: Frozen at inference (preserves pretrained knowledge)

Both MLPs receive the same input; their outputs are summed. This prevents catastrophic forgetting while allowing the model to adapt to new context.
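The dual-MLP sum can be sketched as follows (an illustrative NumPy sketch, not the library's implementation; the `mlp` helper, the ReLU activation, and the weight shapes are assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)

def mlp(x, w1, w2):
    # A toy two-layer MLP: expand by mlp_ratio, nonlinearity, project back.
    # ReLU stands in for whatever activation the real model uses.
    return np.maximum(x @ w1, 0.0) @ w2

d, ratio = 8, 4
# Two MLPs with identical shapes: one updated at inference, one frozen.
dyn = (rng.normal(size=(d, d * ratio)), rng.normal(size=(d * ratio, d)))
stat = (rng.normal(size=(d, d * ratio)), rng.normal(size=(d * ratio, d)))

def dual_mlp(x):
    # Both branches see the same input; their outputs are summed.
    return mlp(x, *dyn) + mlp(x, *stat)

x = rng.normal(size=(2, d))  # [batch, embed_dim]
y = dual_mlp(x)              # same shape as the input
```

At inference only `dyn` would receive gradient updates, so the `stat` branch keeps contributing the pretrained function unchanged.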

Input [batch, seq_len, embed_dim]
      |
      v
+------------------------------------------------+
|  Frozen Block 1..N*3/4                         |
|    LayerNorm -> SlidingWindowAttn -> Residual  |
|    LayerNorm -> MLP -> Residual                |
+------------------------------------------------+
      |
      v
+------------------------------------------------+
|  Mutable Block N*3/4+1..N                      |
|    LayerNorm -> SlidingWindowAttn -> Residual  |
|    LayerNorm -> (DynamicMLP + StaticMLP)       |
|    -> Residual                                 |
+------------------------------------------------+
      |
      v
[Layer Norm] -> [Last Timestep]
      |
      v
Output [batch, hidden_size]

Inference Protocol

  1. Reset dynamic MLP weights to W0 at start of each document
  2. Process tokens in mini-batches of size b (default: 1024)
  3. After each mini-batch: compute next-token loss, backprop to dynamic MLP params only, apply SGD step
  4. Dynamic MLPs accumulate context throughout the document
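The four steps above can be sketched as a minimal loop (a toy NumPy sketch under heavy assumptions: the "model" is a single linear next-step predictor with squared error standing in for running the full transformer and backpropagating next-token loss to the dynamic-MLP parameters; the document is random data):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
doc = rng.normal(size=(64, d))   # toy "document" of token vectors
w0 = np.zeros((d, d))            # stand-in for pretrained dynamic weights W0

def loss_and_grad(w, xs, ys):
    # Stand-in for: run the model, compute next-token loss, backprop to the
    # dynamic params ONLY (frozen params get no gradient).
    err = xs @ w - ys
    return float((err ** 2).mean()), 2.0 * xs.T @ err / len(xs)

def process_document(tokens, minibatch=8, lr=0.1):
    w = w0.copy()                            # 1. reset dynamic weights to W0
    losses = []
    for s in range(0, len(tokens) - 1, minibatch):
        xs = tokens[s:s + minibatch]         # 2. next mini-batch of inputs
        ys = tokens[s + 1:s + 1 + len(xs)]   #    next-token targets
        xs = xs[:len(ys)]
        loss, grad = loss_and_grad(w, xs, ys)
        w -= lr * grad                       # 3. one SGD step on dynamic params
        losses.append(loss)
    return w, losses                         # 4. w has accumulated the context

w, losses = process_document(doc)
```

The key property the sketch preserves is that the optimizer state is per-document: `w` starts from `w0` for every new document and only the dynamic subset of parameters ever moves.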

Usage

model = TTTE2E.build(
  embed_dim: 256,
  hidden_size: 256,
  num_layers: 12,       # Last 3 blocks will have dual MLPs
  num_heads: 4,
  window_size: 60
)

References

Summary

Types

Options for build/1.

Functions

Build a TTT-E2E model.

Get the layer pattern showing which blocks are mutable.

Get the names of mutable (dynamic MLP) parameters for a built model.

Get the output size of a TTT-E2E model.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:head_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:mlp_ratio, pos_integer()}
  | {:mutable_fraction, float()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a TTT-E2E model.

Options

Architecture:

  • :embed_dim - Input embedding dimension (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_layers - Total number of transformer blocks (default: 12)
  • :num_heads - Number of attention heads (default: 4)
  • :head_dim - Dimension per attention head (default: 64)
  • :mlp_ratio - MLP expansion ratio (default: 4)

TTT-specific:

  • :mutable_fraction - Fraction of blocks with dual MLPs (default: 0.25). Mutable blocks are placed at the end of the stack.

General:

  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Sliding window attention size (default: 60)
  • :seq_len - Fixed sequence length for JIT (default: window_size)

Returns

An Axon model that outputs [batch, hidden_size] from the last position.

layer_pattern(opts \\ [])

@spec layer_pattern(keyword()) :: [atom()]

Get the layer pattern showing which blocks are mutable.

Example

iex> TTTE2E.layer_pattern(num_layers: 8, mutable_fraction: 0.25)
[:frozen, :frozen, :frozen, :frozen, :frozen, :frozen, :mutable, :mutable]
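The pattern is equivalent to placing the mutable blocks at the end of the stack. A Python sketch of the same computation (the rounding rule is an assumption inferred from the example, not confirmed by the source):

```python
def layer_pattern(num_layers=12, mutable_fraction=0.25):
    # Mutable (dual-MLP) blocks sit at the end of the stack;
    # the first num_layers - n_mutable blocks stay frozen.
    n_mutable = round(num_layers * mutable_fraction)
    return ["frozen"] * (num_layers - n_mutable) + ["mutable"] * n_mutable
```

With the defaults (`num_layers: 12`, `mutable_fraction: 0.25`) this yields 9 frozen blocks followed by 3 mutable ones, matching the comment in the Usage example.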

mutable_param_prefixes(opts \\ [])

@spec mutable_param_prefixes(keyword()) :: [String.t()]

Get the names of mutable (dynamic MLP) parameters for a built model.

These are the parameters that should be updated via SGD at inference time. Use this to partition parameters into frozen and mutable sets.

Options

  • :num_layers - Total layers (default: 12)
  • :mutable_fraction - Fraction of mutable blocks (default: 0.25)

Returns

List of parameter name prefixes for dynamic MLP layers.
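The partition itself is a straightforward prefix match. A hedged Python sketch (the parameter names below are hypothetical; the real prefixes come from `mutable_param_prefixes/1`):

```python
def partition_params(params, mutable_prefixes):
    # Split a flat {name: tensor} map into frozen and mutable subsets,
    # so the inference loop applies SGD to the mutable subset only.
    mutable = {k: v for k, v in params.items()
               if any(k.startswith(p) for p in mutable_prefixes)}
    frozen = {k: v for k, v in params.items() if k not in mutable}
    return frozen, mutable

# Hypothetical parameter names, for illustration only.
params = {
    "block_8.mlp.w1": 1,
    "block_9.dynamic_mlp.w1": 2,
    "block_9.static_mlp.w1": 3,
    "block_9.attn.wq": 4,
}
frozen, mutable = partition_params(params, ["block_9.dynamic_mlp"])
```

Everything outside the returned prefixes, including the static MLPs in the mutable blocks, stays frozen at inference time.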

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a TTT-E2E model.