Test-Time Training (TTT) Layers.
Implements TTT layers from "Learning to (Learn at Test Time): RNNs with Expressive Hidden States" (Sun et al., 2024). In TTT, the hidden state is itself a model (a linear layer or small MLP) that is updated via a self-supervised gradient step at each token.
Key Innovations
- Hidden state IS a model: Instead of a vector, the hidden state is the weight matrix of a small inner model
- Self-supervised updates: At each step, the inner model does a gradient step on a reconstruction loss
- Equivalent to linear attention: TTT-Linear is mathematically equivalent to linear attention with the delta rule when the inner model is linear (see the worked update after this list)
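To make the equivalence concrete: with a linear inner model and the reconstruction loss ||W_{t-1} @ k_t - v_t||^2, a single gradient step gives

W_t = W_{t-1} - eta_t * (W_{t-1} @ k_t - v_t) @ k_t^T   # delta-rule update

which is exactly the delta-rule form of linear attention. Vanilla linear attention instead accumulates W_t = W_{t-1} + v_t @ k_t^T, with no prediction-error term.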
Paper-Faithful Implementation
Follows the official TTT-Linear implementation (ttt-lm-pytorch) with these key stability mechanisms:
- W_0 ~ N(0, 0.02): Small initialization keeps early predictions near zero, preventing gradient explosion in the first steps.
- eta / head_dim scaling: The inner learning rate is scaled by 1/d (d = inner_size), keeping each weight update small. Without this, an eta in [0, 1] would be ~64x too large at the default inner_size of 64 (see the sketch after this list).
- Inner LayerNorm: Learnable LayerNorm on inner model predictions before computing reconstruction error. Prevents prediction magnitudes from drifting.
- Output gating: Sigmoid gate on output (like SwiGLU) for smoother gradients.
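A minimal Nx sketch of the first two safeguards (illustrative only; the variable names and the example logit below are not from the module):

key = Nx.Random.key(0)
inner_size = 64

# W_0 ~ N(0, 0.02): the inner model starts out predicting values near zero.
{w0, _key} = Nx.Random.normal(key, 0.0, 0.02, shape: {inner_size, inner_size})

# eta comes out of a sigmoid, so it lies in [0, 1]; dividing by inner_size
# (64 here) keeps each per-token weight update small.
eta = Nx.divide(Nx.sigmoid(Nx.tensor(0.3)), inner_size)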
Equations (TTT-Linear)
# Project inputs
q_t = W_q x_t # Query
k_t = W_k x_t # Key
v_t = W_v x_t # Value (reconstruction target)
eta_t = sigmoid(W_eta x_t) / d # Learning rate gate (scaled by 1/head_dim)
# Inner model forward + LayerNorm
pred_t = LayerNorm(W_{t-1} @ k_t)
# Self-supervised gradient update
error_t = pred_t - v_t
grad_W = error_t @ k_t^T
W_t = W_{t-1} - eta_t * grad_W
# Gated output using updated model
o_t = W_t @ q_t * sigmoid(gate_t)
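As a concrete illustration of the recurrence above, here is a minimal Nx sketch of one inner-model step. It is not the module's actual implementation (the TTTStep name is a placeholder); it assumes q, k, v, and gate are already-projected vectors of size d, eta is a per-token scalar already divided by d, and it uses the simplified gradient shown in the equations (the inner LayerNorm is treated as identity in the backward pass):

defmodule TTTStep do
  # Learnable LayerNorm applied to the inner model's prediction.
  def inner_layer_norm(x, gamma, beta) do
    mean = Nx.mean(x)
    var = Nx.variance(x)

    x
    |> Nx.subtract(mean)
    |> Nx.divide(Nx.sqrt(Nx.add(var, 1.0e-6)))
    |> Nx.multiply(gamma)
    |> Nx.add(beta)
  end

  # One token: predict, take a gradient step on the reconstruction error,
  # then emit the gated output from the updated weights.
  def step(w, q, k, v, eta, gate, gamma, beta) do
    pred = inner_layer_norm(Nx.dot(w, k), gamma, beta)
    error = Nx.subtract(pred, v)
    grad_w = Nx.outer(error, k)
    w_new = Nx.subtract(w, Nx.multiply(eta, grad_w))
    out = Nx.multiply(Nx.dot(w_new, q), Nx.sigmoid(gate))
    {w_new, out}
  end
end

For a full sequence, fold step/8 over the timesteps with W_0 as the initial accumulator (e.g. via Enum.reduce/3), collecting the per-token outputs.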
Architecture

Input [batch, seq_len, embed_dim]
        |
        v
[Input Projection] -> hidden_size
        |
        v
+--------------------------------------+
| TTT Layer                            |
|   Project to Q, K, V, eta, gate      |
|   For each timestep:                 |
|     pred = LayerNorm(W @ k)          |
|     error = pred - v                 |
|     W -= (eta/d) * error * k^T       |
|     output = (W @ q) * sigmoid(gate) |
+--------------------------------------+
        | (repeat num_layers)
        v
[Layer Norm] -> [Last Timestep]
        |
        v
Output [batch, hidden_size]

Usage
model = TTT.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 4,
  inner_size: 64,
  dropout: 0.1
)
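The result is a plain Axon graph, so it can be initialized and run with Axon's standard build/predict flow. A minimal sketch (the dummy input below uses the default window_size of 60; the exact init call can vary between Axon versions, and if a bare tensor is not accepted, pass a map keyed by the model's input name):

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 60, 287}, :f32), %{})

output = predict_fn.(params, Nx.broadcast(0.0, {1, 60, 287}))
# output has shape {1, 256}, i.e. [batch, hidden_size]

References
- Sun et al. (2024). "Learning to (Learn at Test Time): RNNs with Expressive Hidden States."
- Official reference implementation: ttt-lm-pytorch.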
Summary

Functions
- Build a TTT model for sequence processing.
- Default dropout rate
- Default hidden dimension
- Default inner model dimension (key/value size)
- Default number of layers
- Get the output size of a TTT model.
Types
@type build_opt() ::
        {:dropout, float()}
        | {:embed_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:inner_size, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:output_gate, boolean()}
        | {:seq_len, pos_integer()}
        | {:variant, :linear | :mlp}
        | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a TTT model for sequence processing.
Options
- :embed_dim - Size of input embedding per frame (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :inner_size - Inner model key/value dimension (default: 64)
- :num_layers - Number of TTT layers (default: 4)
- :dropout - Dropout rate between layers (default: 0.1)
- :window_size - Expected sequence length (default: 60)
- :variant - Inner model variant: :linear (default) or :mlp. The :mlp variant applies SiLU activation to keys and queries before the inner model, making the hidden state a 2-layer MLP instead of a single linear layer.
- :output_gate - Apply sigmoid output gate (default: true). Provides smoother gradients by gating the TTT output before the residual.
Returns
An Axon model that processes sequences and outputs the last hidden state.
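For example, building the MLP variant with the output gate disabled, using the options above (the option values here are illustrative):

model =
  TTT.build(
    embed_dim: 287,
    hidden_size: 256,
    variant: :mlp,
    output_gate: false
  )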
@spec default_dropout() :: float()
Default dropout rate
@spec default_inner_size() :: pos_integer()
Default inner model dimension (key/value size)
@spec default_num_layers() :: pos_integer()
Default number of layers
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a TTT model.
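A usage sketch; the return value is assumed to mirror the :hidden_size option, matching the [batch, hidden_size] output shape documented above:

# Assumed behavior: returns the configured hidden size.
TTT.output_size(hidden_size: 256)
#=> 256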