Edifice.Recurrent.TTT (Edifice v0.2.0)

Test-Time Training (TTT) Layers.

Implements TTT layers from "Learning to (Learn at Test Time): RNNs with Expressive Hidden States" (Sun et al., 2024). In TTT, the hidden state is itself a model (a linear layer or small MLP) that is updated via a self-supervised gradient step at each token.

Key Innovations

  • Hidden state IS a model: Instead of a vector, the hidden state is the weight matrix of a small inner model
  • Self-supervised updates: At each step, the inner model does a gradient step on a reconstruction loss
  • Equivalent to linear attention: TTT-Linear is mathematically equivalent to linear attention with the delta rule when the inner model is linear (see the identity below)
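
To see this for the linear inner model, substitute pred_t = W_{t-1} @ k_t into the weight update (ignoring the inner LayerNorm for clarity); this is exactly the delta rule, written in the notation of the equations below:

W_t = W_{t-1} - eta_t * (W_{t-1} @ k_t - v_t) @ k_t^T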

Paper-Faithful Implementation

Follows the official TTT-Linear implementation (ttt-lm-pytorch) with these key stability mechanisms, sketched in code after the list:

  1. W_0 ~ N(0, 0.02): Small initialization keeps early predictions near zero, preventing gradient explosion in the first steps.
  2. eta / head_dim scaling: The inner learning rate is scaled by 1/d (d = inner_size), keeping weight updates small. Without this scaling, eta in [0, 1] would be roughly 64x too large at the default inner_size of 64.
  3. Inner LayerNorm: Learnable LayerNorm on inner model predictions before computing reconstruction error. Prevents prediction magnitudes from drifting.
  4. Output gating: Sigmoid gate on output (like SwiGLU) for smoother gradients.
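
A minimal Nx sketch of mechanisms 1 and 2 (illustrative only; the names here are not the module's internals):

key = Nx.Random.key(0)
inner_size = 64

# 1. W_0 ~ N(0, 0.02): small init keeps early inner predictions near zero
{w0, _key} = Nx.Random.normal(key, 0.0, 0.02, shape: {inner_size, inner_size})

# 2. eta / head_dim: sigmoid gate lands in [0, 1], then shrink by 1/inner_size
eta_logits = Nx.tensor([0.3, -1.2, 0.8])
eta = Nx.divide(Nx.sigmoid(eta_logits), inner_size)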

Equations (TTT-Linear)

# Project inputs
q_t = W_q x_t                          # Query
k_t = W_k x_t                          # Key
v_t = W_v x_t                          # Value (reconstruction target)
eta_t = sigmoid(W_eta x_t) / d         # Learning rate gate (scaled by 1/head_dim)

# Inner model forward + LayerNorm
pred_t = LayerNorm(W_{t-1} @ k_t)

# Self-supervised gradient update
error_t = pred_t - v_t
grad_W = error_t @ k_t^T
W_t = W_{t-1} - eta_t * grad_W

# Gated output using updated model
o_t = W_t @ q_t * sigmoid(gate_t)
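
A self-contained sketch of one step of these equations in Nx (an illustration, not the module's code; the paper's learnable inner LayerNorm is replaced by a parameter-free one for brevity):

defmodule TTTStepSketch do
  import Nx.Defn

  # One token of TTT-Linear. Shapes: w_prev {d, d}; q_t, k_t, v_t {d};
  # eta_t and gate_t are scalars here (per-channel in the real layer).
  defn step(w_prev, q_t, k_t, v_t, eta_t, gate_t) do
    # Inner model forward + LayerNorm
    pred = layer_norm(Nx.dot(w_prev, k_t))

    # Self-supervised gradient of the reconstruction loss w.r.t. W
    error = pred - v_t
    grad_w = Nx.outer(error, k_t)
    w_t = w_prev - eta_t * grad_w

    # Gated output using the updated inner model
    o_t = Nx.dot(w_t, q_t) * Nx.sigmoid(gate_t)
    {w_t, o_t}
  end

  defnp layer_norm(x) do
    mean = Nx.mean(x)
    var = Nx.variance(x)
    (x - mean) / Nx.sqrt(var + 1.0e-6)
  end
end

Scanning step/6 along the time axis, threading w_t forward, yields the per-token outputs o_t.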

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
[Input Projection] -> hidden_size
      |
      v
+--------------------------------------+
|        TTT Layer                     |
|  Project to Q, K, V, eta, gate       |
|  For each timestep:                  |
|    pred = LayerNorm(W @ k)           |
|    error = pred - v                  |
|    W -= (eta/d) * error * k^T        |
|    output = (W @ q) * sigmoid(gate)  |
+--------------------------------------+
      | (repeat num_layers)
      v
[Layer Norm] -> [Last Timestep]
      |
      v
Output [batch, hidden_size]
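
The same flow as a hypothetical Axon graph skeleton (ttt_layer below is an identity placeholder standing in for the per-timestep scan, and the input name is assumed):

ttt_layer = fn x ->
  # placeholder for the real TTT scan over timesteps (see step/6 above)
  x
end

input = Axon.input("frames", shape: {nil, 60, 287})

hidden = Axon.dense(input, 256)

hidden =
  Enum.reduce(1..4, hidden, fn _layer, acc ->
    acc |> ttt_layer.() |> Axon.dropout(rate: 0.1)
  end)

model =
  hidden
  |> Axon.layer_norm()
  |> Axon.nx(fn t ->
    # [batch, seq, hidden] -> [batch, hidden]: keep the last timestep
    last = Nx.axis_size(t, 1) - 1
    t |> Nx.slice_along_axis(last, 1, axis: 1) |> Nx.squeeze(axes: [1])
  end)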

Usage

model = TTT.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 4,
  inner_size: 64,
  dropout: 0.1
)
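
A quick shape check for this configuration (assuming a single-input graph that accepts a bare tensor, and the default window_size of 60; both are assumptions):

{init_fn, predict_fn} = Axon.build(model)

params = init_fn.(Nx.template({1, 60, 287}, :f32), %{})
input = Nx.broadcast(0.0, {1, 60, 287})

output = predict_fn.(params, input)
# expected shape: {1, 256}, i.e. [batch, hidden_size]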

References

  • Sun et al. (2024), "Learning to (Learn at Test Time): RNNs with Expressive Hidden States". arXiv:2407.04620.
  • ttt-lm-pytorch: official PyTorch reference implementation of TTT-Linear.

Summary

Types

  • build_opt() - Options for build/1.

Functions

  • build(opts \\ []) - Build a TTT model for sequence processing.
  • default_dropout() - Default dropout rate.
  • default_hidden_size() - Default hidden dimension.
  • default_inner_size() - Default inner model dimension (key/value size).
  • default_num_layers() - Default number of layers.
  • output_size(opts \\ []) - Get the output size of a TTT model.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:inner_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:output_gate, boolean()}
  | {:seq_len, pos_integer()}
  | {:variant, :linear | :mlp}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a TTT model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :inner_size - Inner model key/value dimension (default: 64)
  • :num_layers - Number of TTT layers (default: 4)
  • :dropout - Dropout rate between layers (default: 0.1)
  • :window_size - Expected sequence length (default: 60)
  • :variant - Inner model variant: :linear (default) or :mlp. The :mlp variant applies SiLU activation to keys and queries before the inner model, making the hidden state a 2-layer MLP instead of a single linear layer (see the example after this list).
  • :output_gate - Apply sigmoid output gate (default: true). Provides smoother gradients by gating the TTT output before the residual.
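
For example, a build using the MLP variant with gating disabled (same required :embed_dim as the usage example above):

model = TTT.build(
  embed_dim: 287,
  variant: :mlp,
  output_gate: false
)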

Returns

An Axon model that processes sequences and outputs the last hidden state.

default_dropout()

@spec default_dropout() :: float()

Default dropout rate

default_hidden_size()

@spec default_hidden_size() :: pos_integer()

Default hidden dimension

default_inner_size()

@spec default_inner_size() :: pos_integer()

Default inner model dimension (key/value size)

default_num_layers()

@spec default_num_layers() :: pos_integer()

Default number of layers

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a TTT model.
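
Given that the model outputs [batch, hidden_size], this presumably tracks the :hidden_size option (an assumption based on the architecture section, not verified against the source):

Edifice.Recurrent.TTT.output_size(hidden_size: 512)
#=> 512 (expected under the assumption above)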