# `Edifice.Recurrent.TTT`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/recurrent/ttt.ex#L1)

Test-Time Training (TTT) Layers.

Implements TTT layers from "Learning to (Learn at Test Time): RNNs with
Expressive Hidden States" (Sun et al., 2024). In TTT, the hidden state
is itself a model (a linear layer or small MLP) that is updated via a
self-supervised gradient step at each token.

## Key Innovations

- **Hidden state IS a model**: Instead of a vector, the hidden state is
  the weight matrix of a small inner model
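- **Self-supervised updates**: At each token, the inner model takes one gradient
  step on a reconstruction loss
- **Related to linear attention**: With a linear inner model, zero initialization,
  and batch gradient descent, TTT-Linear reduces to linear attention (Sun et al.,
  2024); with per-token updates and no inner LayerNorm, its update
  `W -= eta * (W k - v) k^T` has the form of the delta rule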
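A minimal, dependency-free sketch of the linear-attention connection, assuming a zero-initialized linear inner model, the loss 0.5 * ||W k - v||^2, no inner LayerNorm or output gate, and batch gradient descent (every gradient evaluated at W_0). `TTTEquivalenceSketch` and its functions are hypothetical illustrations, not part of Edifice; vectors are lists of floats and matrices are lists of rows.

```elixir
defmodule TTTEquivalenceSketch do
  # Hypothetical illustration, not Edifice code.
  # Vectors are lists of floats; matrices are lists of rows.

  defp dot(a, b), do: Enum.zip_with(a, b, &(&1 * &2)) |> Enum.sum()
  defp matvec(m, x), do: Enum.map(m, &dot(&1, x))
  defp outer(u, k), do: Enum.map(u, fn ui -> Enum.map(k, &(ui * &1)) end)
  defp zeros(rows, cols), do: List.duplicate(List.duplicate(0.0, cols), rows)

  # Elementwise a - s * b over two matrices.
  defp sub_scaled(a, b, s) do
    Enum.zip_with(a, b, fn ra, rb -> Enum.zip_with(ra, rb, &(&1 - s * &2)) end)
  end

  # TTT with W_0 = 0, loss 0.5 * ||W k - v||^2 and batch GD:
  # W_t = W_0 - eta * sum_{s<=t} grad_s, every grad_s evaluated at W_0.
  def ttt_batch_gd(qs, ks, vs, eta) do
    w0 = zeros(length(hd(vs)), length(hd(ks)))

    {outputs, _w} =
      Enum.zip([qs, ks, vs])
      |> Enum.map_reduce(w0, fn {q, k, v}, w ->
        # Gradient (W_0 k - v) k^T, taken at W_0 rather than the current W.
        err = Enum.zip_with(matvec(w0, k), v, &(&1 - &2))
        w = sub_scaled(w, outer(err, k), eta)
        {matvec(w, q), w}
      end)

    outputs
  end

  # Unnormalized causal linear attention: o_t = eta * sum_{s<=t} v_s * (k_s . q_t).
  def linear_attention(qs, ks, vs, eta) do
    qs
    |> Enum.with_index()
    |> Enum.map(fn {q, t} ->
      Enum.zip(Enum.take(ks, t + 1), Enum.take(vs, t + 1))
      |> Enum.map(fn {k, v} -> Enum.map(v, &(eta * dot(k, q) * &1)) end)
      |> Enum.reduce(fn row, acc -> Enum.zip_with(row, acc, &(&1 + &2)) end)
    end)
  end
end
```

Under these assumptions W_t = eta * sum of v_s k_s^T over s <= t, so W_t q_t is an unnormalized causal linear-attention read-out. The per-token (online) update and the inner LayerNorm used by the layer itself do not satisfy this identity exactly.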

## Paper-Faithful Implementation

Follows the official TTT-Linear implementation (ttt-lm-pytorch) with these
key stability mechanisms:

1. **W_0 ~ N(0, 0.02)**: Small initialization (standard deviation 0.02) keeps early
   predictions near zero, preventing gradient explosion in the first steps.
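2. **eta / head_dim scaling**: The inner learning rate is scaled by 1/d (d = inner_size,
   the head dimension), keeping weight updates small. Because eta comes from a sigmoid
   in [0, 1], skipping this scaling makes each step about d times larger (~64x at the
   default inner_size of 64).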
3. **Inner LayerNorm**: Learnable LayerNorm on inner model predictions before
   computing reconstruction error. Prevents prediction magnitudes from drifting.
4. **Output gating**: A sigmoid gate on the output (a GLU-style multiplicative gate)
   for smoother gradients.
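The first three mechanisms fit in a few lines of plain Elixir. `TTTStabilitySketch` is a hypothetical module, not Edifice's implementation; it assumes a scalar eta logit per token and a per-vector LayerNorm with learnable `gamma` and `beta` vectors.

```elixir
defmodule TTTStabilitySketch do
  # Hypothetical helpers illustrating mechanisms 1-3; not Edifice's API.

  # Mechanism 1: W_0 entries drawn from a normal with mean 0 and std 0.02.
  # :rand.normal/2 takes (mean, variance), so the variance is 0.02^2.
  # Assumes rows and cols are positive.
  def init_w0(rows, cols) do
    for _ <- 1..rows, do: for(_ <- 1..cols, do: :rand.normal(0.0, 0.02 * 0.02))
  end

  # Mechanism 2: eta_t = sigmoid(logit) / d, so the step size lies in (0, 1/d).
  def scaled_eta(logit, d), do: 1.0 / (1.0 + :math.exp(-logit)) / d

  # Mechanism 3: LayerNorm over one prediction vector with learnable gamma/beta.
  def layer_norm(x, gamma, beta, eps \\ 1.0e-5) do
    n = length(x)
    mean = Enum.sum(x) / n
    var = Enum.reduce(x, 0.0, fn xi, acc -> acc + (xi - mean) * (xi - mean) end) / n
    inv_std = 1.0 / :math.sqrt(var + eps)

    Enum.zip_with([x, gamma, beta], fn [xi, g, b] -> (xi - mean) * inv_std * g + b end)
  end
end
```

With d = 64, even the largest gate value gives a step below 1/64, which is the point of mechanism 2.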

## Equations (TTT-Linear)

```
# Project inputs (d = inner_size)
q_t = W_q x_t                          # Query
k_t = W_k x_t                          # Key
v_t = W_v x_t                          # Value (reconstruction target)
eta_t = sigmoid(W_eta x_t) / d         # Learning rate gate (scaled by 1/head_dim)
gate_t = W_gate x_t                    # Output gate logits

# Inner model forward + LayerNorm
pred_t = LayerNorm(W_{t-1} @ k_t)

# Self-supervised gradient update
error_t = pred_t - v_t
grad_W = error_t @ k_t^T               # Gradient of 0.5*||pred_t - v_t||^2 with LayerNorm's Jacobian omitted
W_t = W_{t-1} - eta_t * grad_W

# Gated output using updated model (when output_gate: true)
o_t = (W_t @ q_t) * sigmoid(gate_t)
```
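A dependency-free sketch of one TTT-Linear step, following the equations line by line. `TTTLinearStepSketch.step/6` is a hypothetical name; it assumes the projections q_t, k_t, v_t, a scalar eta logit, and a vector of gate logits are already computed, uses a LayerNorm without learnable scale or shift, and forms grad_W as error_t k_t^T without the LayerNorm Jacobian.

```elixir
defmodule TTTLinearStepSketch do
  # Hypothetical sketch, not Edifice code. W is a list of rows (d x d).
  @eps 1.0e-5

  defp sigmoid(x), do: 1.0 / (1.0 + :math.exp(-x))
  defp dot(a, b), do: Enum.zip_with(a, b, &(&1 * &2)) |> Enum.sum()
  defp matvec(w, x), do: Enum.map(w, &dot(&1, x))

  # LayerNorm without learnable scale/shift.
  defp layer_norm(x) do
    n = length(x)
    mean = Enum.sum(x) / n
    var = Enum.sum(Enum.map(x, fn xi -> (xi - mean) * (xi - mean) end)) / n
    Enum.map(x, fn xi -> (xi - mean) / :math.sqrt(var + @eps) end)
  end

  # Returns {o_t, W_t} from W_{t-1} and one projected token.
  def step(w, q, k, v, eta_logit, gate_logits) do
    # eta_t = sigmoid(.) / d
    eta = sigmoid(eta_logit) / length(k)
    # pred_t = LayerNorm(W_{t-1} k_t)
    pred = layer_norm(matvec(w, k))
    # error_t = pred_t - v_t
    error = Enum.zip_with(pred, v, &(&1 - &2))

    # W_t = W_{t-1} - eta_t * error_t k_t^T; entry (i, j) subtracts eta * error_i * k_j
    w_t =
      Enum.zip_with(w, error, fn row, e_i ->
        Enum.zip_with(row, k, fn w_ij, k_j -> w_ij - eta * e_i * k_j end)
      end)

    # o_t = (W_t q_t) * sigmoid(gate_t), elementwise
    o_t = Enum.zip_with(matvec(w_t, q), gate_logits, fn x, g -> x * sigmoid(g) end)
    {o_t, w_t}
  end
end
```

The output is read from the updated W_t, so the current token's own key/value pair already influences its output.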

## Architecture

```
Input [batch, seq_len, embed_dim]
      |
      v
[Input Projection] -> hidden_size
      |
      v
+--------------------------------------+
|        TTT Layer                     |
|  Project to Q, K, V, eta, gate       |
|  For each timestep:                  |
|    pred = LayerNorm(W @ k)           |
|    error = pred - v                  |
|    W -= (eta/d) * error * k^T        |
|    output = (W @ q) * sigmoid(gate)  |
+--------------------------------------+
      | (repeat num_layers)
      v
[Layer Norm] -> [Last Timestep]
      |
      v
Output [batch, hidden_size]
```
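The per-timestep loop in the diagram is a fold over the sequence that threads W from token to token, followed by taking the last output. The hypothetical `TTTScanSketch` shows that structure with `Enum.map_reduce/3`; to stay short it drops the inner LayerNorm and output gate and takes eta as already divided by d.

```elixir
defmodule TTTScanSketch do
  # Hypothetical, simplified sketch: no inner LayerNorm, no output gate,
  # eta fixed and already divided by d. W is a list of rows.

  defp dot(a, b), do: Enum.zip_with(a, b, &(&1 * &2)) |> Enum.sum()
  defp matvec(w, x), do: Enum.map(w, &dot(&1, x))

  # tokens: list of {q, k, v}. Returns {outputs, final_w}.
  def scan(tokens, w0, eta) do
    Enum.map_reduce(tokens, w0, fn {q, k, v}, w ->
      error = Enum.zip_with(matvec(w, k), v, &(&1 - &2))

      # W <- W - eta * error k^T, carried into the next timestep
      w =
        Enum.zip_with(w, error, fn row, e_i ->
          Enum.zip_with(row, k, fn w_ij, k_j -> w_ij - eta * e_i * k_j end)
        end)

      {matvec(w, q), w}
    end)
  end

  # Mirrors "[Last Timestep]": keep only the final output.
  def last_output(tokens, w0, eta) do
    {outputs, _w} = scan(tokens, w0, eta)
    List.last(outputs)
  end
end
```

Feeding the same (k, v) pair repeatedly with ||k|| = 1 and eta = 0.5 halves the reconstruction error at every step, so W k converges to v: the hidden state learns the association at test time.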

## Usage

    alias Edifice.Recurrent.TTT

    model = TTT.build(
      embed_dim: 287,
      hidden_size: 256,
      num_layers: 4,
      inner_size: 64,
      dropout: 0.1
    )

## References
- Paper: https://arxiv.org/abs/2407.04620

# `build_opt`

```elixir
@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:inner_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:output_gate, boolean()}
  | {:seq_len, pos_integer()}
  | {:variant, :linear | :mlp}
  | {:window_size, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a TTT model for sequence processing.

## Options
  - `:embed_dim` - Size of input embedding per frame (required)
  - `:hidden_size` - Internal hidden dimension (default: 256)
  - `:inner_size` - Inner model key/value dimension (default: 64)
  - `:num_layers` - Number of TTT layers (default: 4)
  - `:dropout` - Dropout rate between layers (default: 0.1)
  - `:window_size` - Expected sequence length (default: 60)
  - `:variant` - Inner model variant: `:linear` (default) or `:mlp`.
    The `:mlp` variant applies SiLU activation to keys and queries before
    the inner model, making the hidden state a 2-layer MLP instead of
    a single linear layer.
  - `:output_gate` - Apply sigmoid output gate (default: true). Provides
    smoother gradients by gating the TTT output before the residual.
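SiLU, named in the `:variant` option, is silu(x) = x * sigmoid(x). A minimal elementwise sketch of applying it to a key or query vector; `TTTSiluSketch` is hypothetical and only illustrates the activation, not how Edifice wires the `:mlp` variant.

```elixir
defmodule TTTSiluSketch do
  # Hypothetical illustration; not Edifice code.
  # silu(x) = x * sigmoid(x) = x / (1 + exp(-x)); no overflow guard for very negative x.
  def silu(x), do: x / (1.0 + :math.exp(-x))

  # Apply SiLU elementwise to a key or query vector (list of floats).
  def featurize(vec), do: Enum.map(vec, &silu/1)
end
```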

## Returns
  An Axon model that processes sequences and outputs the final-timestep
  representation, with shape `[batch, hidden_size]`.

# `default_dropout`

```elixir
@spec default_dropout() :: float()
```

Default dropout rate (0.1).

# `default_hidden_size`

```elixir
@spec default_hidden_size() :: pos_integer()
```

Default hidden dimension (256).

# `default_inner_size`

```elixir
@spec default_inner_size() :: pos_integer()
```

Default inner model dimension, i.e. the key/value size (64).

# `default_num_layers`

```elixir
@spec default_num_layers() :: pos_integer()
```

Default number of TTT layers (4).

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Returns the output size of a TTT model built with the given options, i.e. the
last dimension of its `[batch, hidden_size]` output.

