TTT-E2E: End-to-End Test-Time Training for Long Context.
Implements the TTT-E2E architecture from "End-to-End Test-Time Training for Long Context" (Stanford, NVIDIA, UC Berkeley, Dec 2025). Unlike the original TTT layers (which replace attention with a self-supervised inner model updated layer-wise), TTT-E2E keeps a standard transformer backbone and updates the MLP sublayers in the last ~25% of blocks at inference time using end-to-end gradient descent.
Key Differences from TTT-Linear/TTT-MLP
| Aspect | TTT-Linear/MLP | TTT-E2E |
|---|---|---|
| Where TTT happens | Custom layer replacing attention | Updates existing MLP in last 1/4 blocks |
| Inner loss | Layer-wise reconstruction | End-to-end next-token prediction |
| Architecture | Custom TTT layer | Standard transformer + dual MLP |
| Training | Standard pretraining | Meta-learning (bilevel optimization) |
Architecture: Dual-MLP Blocks
In the last 1/4 of transformer blocks, each MLP sublayer is split into:
- Dynamic MLP: Updated via SGD at inference (stores document context)
- Static MLP: Frozen at inference (preserves pretrained knowledge)
Both MLPs receive the same input; their outputs are summed. This prevents catastrophic forgetting while allowing the model to adapt to new context.
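The summed dual-MLP combine can be sketched with Nx (a minimal sketch; module, function, and parameter names here are illustrative, not the library's internals):

```elixir
defmodule DualMLPSketch do
  import Nx.Defn

  # Both MLPs see the same input; their outputs are summed.
  # The dynamic params change during inference; the static ones do not.
  defn dual_mlp(x, dynamic_params, static_params) do
    mlp(x, dynamic_params) + mlp(x, static_params)
  end

  # A plain 2-layer MLP (illustrative of mlp_ratio-style expansion) with ReLU.
  defnp mlp(x, {w1, b1, w2, b2}) do
    x
    |> Nx.dot(w1)
    |> Nx.add(b1)
    |> Nx.max(0)
    |> Nx.dot(w2)
    |> Nx.add(b2)
  end
end
```

Because the static MLP keeps its pretrained weights, the sum acts as a residual anchor: even after many SGD steps on the dynamic branch, the pretrained function is still present in the output.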
Input [batch, seq_len, embed_dim]
|
v
+----------------------------------------------+
| Frozen Block 1..N*3/4 |
| LayerNorm -> SlidingWindowAttn -> Residual |
| LayerNorm -> MLP -> Residual |
+----------------------------------------------+
|
v
+----------------------------------------------+
| Mutable Block N*3/4+1..N |
| LayerNorm -> SlidingWindowAttn -> Residual |
| LayerNorm -> (DynamicMLP + StaticMLP) |
| -> Residual |
+----------------------------------------------+
|
v
[Layer Norm] -> [Last Timestep]
|
v
Output [batch, hidden_size]

Inference Protocol
- Reset dynamic MLP weights to their pretrained values W0 at the start of each document
- Process tokens in mini-batches of size b (default: 1024)
- After each mini-batch: compute the next-token prediction loss, backpropagate to the dynamic MLP parameters only, and apply one SGD step
- Dynamic MLPs accumulate document context as the sequence is processed
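The protocol above can be sketched as a per-document loop. This is a schematic sketch: `next_token_loss/3` is a hypothetical stand-in for running the model forward and computing cross-entropy against shifted targets, and dynamic params are assumed to be a map of tensors.

```elixir
defmodule TTTLoopSketch do
  import Nx.Defn

  # Hypothetical stand-in: a real loss would run the full forward pass
  # and score each token's prediction against the next token.
  defnp next_token_loss(%{w: w}, _frozen_params, chunk) do
    Nx.sum(w * Nx.mean(chunk))
  end

  # Gradient w.r.t. the dynamic params only; frozen params are constants.
  defn dynamic_grads(dynamic, frozen, chunk) do
    grad(dynamic, fn d -> next_token_loss(d, frozen, chunk) end)
  end

  # Reset to W0 at document start, then one SGD step per mini-batch of b tokens.
  def run_document(dynamic_w0, frozen, token_chunks, lr) do
    Enum.reduce(token_chunks, dynamic_w0, fn chunk, dynamic ->
      grads = dynamic_grads(dynamic, frozen, chunk)

      Map.new(dynamic, fn {name, w} ->
        {name, Nx.subtract(w, Nx.multiply(lr, grads[name]))}
      end)
    end)
  end
end
```

Note that only the dynamic-param map is differentiated; the frozen params flow through the loss as constants, matching the protocol above.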
Usage
model = TTTE2E.build(
embed_dim: 256,
hidden_size: 256,
num_layers: 12, # Last 3 blocks will have dual MLPs
num_heads: 4,
window_size: 60
)

References

- "End-to-End Test-Time Training for Long Context" (Stanford, NVIDIA, UC Berkeley, Dec 2025)
Summary
Functions
Build a TTT-E2E model.
Get the layer pattern showing which blocks are mutable.
Get the names of mutable (dynamic MLP) parameters for a built model.
Get the output size of a TTT-E2E model.
Types
@type build_opt() ::
        {:dropout, float()}
        | {:embed_dim, pos_integer()}
        | {:head_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:mlp_ratio, pos_integer()}
        | {:mutable_fraction, float()}
        | {:num_heads, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:seq_len, pos_integer()}
        | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a TTT-E2E model.
Options
Architecture:
- `:embed_dim` - Input embedding dimension (required)
- `:hidden_size` - Internal hidden dimension (default: 256)
- `:num_layers` - Total number of transformer blocks (default: 12)
- `:num_heads` - Number of attention heads (default: 4)
- `:head_dim` - Dimension per attention head (default: 64)
- `:mlp_ratio` - MLP expansion ratio (default: 4)
TTT-specific:
- `:mutable_fraction` - Fraction of blocks with dual MLPs (default: 0.25). Mutable blocks are placed at the end of the stack.
General:
- `:dropout` - Dropout rate (default: 0.1)
- `:window_size` - Sliding window attention size (default: 60)
- `:seq_len` - Fixed sequence length for JIT compilation (default: `window_size`)
Returns
An Axon model that outputs [batch, hidden_size] from the last position.
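For example, the built model can be compiled and run with the standard Axon workflow (the input template shape is an assumption based on the diagram above; `Axon.build/1` is the standard Axon >= 0.5 API):

```elixir
model = TTTE2E.build(embed_dim: 256, hidden_size: 256, num_layers: 12)

# Axon.build/1 returns an {init_fn, predict_fn} pair.
{init_fn, predict_fn} = Axon.build(model)

# Input is [batch, seq_len, embed_dim]; seq_len defaults to window_size (60).
params = init_fn.(Nx.template({1, 60, 256}, :f32), %{})
output = predict_fn.(params, Nx.iota({1, 60, 256}, type: :f32))
# output shape: {1, 256} -- [batch, hidden_size], per the Returns note above
```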
Get the layer pattern showing which blocks are mutable.
Example
iex> TTTE2E.layer_pattern(num_layers: 8, mutable_fraction: 0.25)
[:frozen, :frozen, :frozen, :frozen, :frozen, :frozen, :mutable, :mutable]
Get the names of mutable (dynamic MLP) parameters for a built model.
These are the parameters that should be updated via SGD at inference time. Use this to partition parameters into frozen and mutable sets.
Options
- `:num_layers` - Total layers (default: 12)
- `:mutable_fraction` - Fraction of mutable blocks (default: 0.25)
Returns
List of parameter name prefixes for dynamic MLP layers.
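One way to consume these prefixes (a sketch: `mutable_parameter_names/1` is an assumed name for this function, and `params` is assumed to be a flat name-to-tensor map keyed by string prefixes):

```elixir
# Hypothetical function name; substitute the actual function from this module.
prefixes = TTTE2E.mutable_parameter_names(num_layers: 12, mutable_fraction: 0.25)

# Partition a flat %{name => tensor} param map into mutable vs frozen sets,
# so SGD at inference touches only the dynamic MLP parameters.
{mutable, frozen} =
  Enum.split_with(params, fn {name, _tensor} ->
    Enum.any?(prefixes, &String.starts_with?(name, &1))
  end)

mutable_params = Map.new(mutable)
frozen_params = Map.new(frozen)
```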
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a TTT-E2E model.