# `Edifice.Transformer.MultiTokenPrediction`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/transformer/multi_token_prediction.ex#L1)

Multi-Token Prediction (MTP) — predict multiple future tokens simultaneously.

Wraps a backbone transformer (`DecoderOnly` by default) with multiple
independent prediction heads. Each head projects the backbone's hidden
states to vocabulary logits for a different future position.

## Key Innovation: Parallel Next-Token Heads

Instead of predicting only the next token, MTP attaches N independent
dense layers to the backbone output, each predicting a different future
position. This provides a richer training signal and enables speculative
decoding at inference time.
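Conceptually, the head attachment is a small Axon graph transformation. A minimal sketch (assuming `backbone` is an already-built Axon node emitting `[batch, seq_len, hidden_size]`; the variable names here are illustrative, not the module's internals):

```elixir
# Sketch: attach num_predictions independent dense heads to the backbone's
# hidden states and bundle them into a single container output.
heads =
  for i <- 1..num_predictions, into: %{} do
    {:"pred_#{i}", Axon.dense(backbone, vocab_size, name: "mtp_head_#{i}")}
  end

model = Axon.container(heads)
```

Because the heads share the backbone, the extra cost over a single-head model is just N dense projections.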

## Architecture

```
Input [batch, seq_len, embed_dim]
      |
Backbone (output_mode: :all)
      |
[batch, seq_len, hidden_size]
      |
+-- Head 1: dense(vocab_size) -> pred_1 [batch, seq_len, vocab_size]
+-- Head 2: dense(vocab_size) -> pred_2 [batch, seq_len, vocab_size]
+-- ...
+-- Head N: dense(vocab_size) -> pred_N [batch, seq_len, vocab_size]
      |
Axon.container(%{pred_1: h1, pred_2: h2, ..., pred_N: hN})
```

## Usage

    model = MultiTokenPrediction.build(
      embed_dim: 256,
      vocab_size: 32000,
      num_predictions: 4
    )
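The result is an ordinary Axon graph, so it can be initialized and run with `Axon.build/2`. A hedged sketch (input shapes follow the architecture diagram above; the exact input template may differ):

```elixir
model =
  MultiTokenPrediction.build(
    embed_dim: 256,
    vocab_size: 32000,
    num_predictions: 4
  )

{init_fn, predict_fn} = Axon.build(model)

input = Nx.broadcast(0.0, {2, 16, 256})
params = init_fn.(Nx.to_template(input), %{})

# The output is a map with one logits tensor per head; each should be
# shaped {2, 16, 32000} here.
%{pred_1: _p1, pred_4: _p4} = predict_fn.(params, input)
```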

## References

- "Better & Faster Large Language Models via Multi-token Prediction"
  (Gloeckle et al., 2024) — https://arxiv.org/abs/2404.19737

# `build_opt`

```elixir
@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:vocab_size, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_kv_heads, pos_integer()}
  | {:num_predictions, pos_integer()}
  | {:dropout, float()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a Multi-Token Prediction model.

## Options

  - `:embed_dim` - Input embedding dimension (required)
  - `:vocab_size` - Vocabulary size for each prediction head (required)
  - `:hidden_size` - Backbone hidden dimension (default: 256)
  - `:num_layers` - Number of backbone transformer layers (default: 4)
  - `:num_heads` - Number of attention heads (default: 4)
  - `:num_kv_heads` - Number of key/value heads for GQA (default: 2)
  - `:num_predictions` - Number of future tokens to predict (default: 4)
  - `:dropout` - Dropout rate (default: 0.1)

## Returns

  An `Axon.container` with keys `:pred_1` through `:pred_N`, each
  shaped `[batch, seq_len, vocab_size]`.
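During training, head k is typically supervised against the token k positions ahead, so each head's logits and targets are trimmed by the head's offset. A minimal loss sketch (an assumption about the training setup, not part of this module; `preds` is the container output, `targets` is a `[batch, seq_len]` integer tensor of token ids):

```elixir
# Sketch: mean cross-entropy over heads. Head k at position t predicts
# the token at position t + k, so drop the last k logit positions and
# the first k target positions before comparing.
losses =
  for k <- 1..num_predictions do
    logits = preds[:"pred_#{k}"]
    seq_len = Nx.axis_size(logits, 1)

    logits_k = Nx.slice_along_axis(logits, 0, seq_len - k, axis: 1)
    targets_k = Nx.slice_along_axis(targets, k, seq_len - k, axis: 1)

    Axon.Losses.categorical_cross_entropy(targets_k, logits_k,
      from_logits: true,
      sparse: true,
      reduction: :mean
    )
  end

loss = Nx.mean(Nx.stack(losses))
```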

# `output_size`

```elixir
@spec output_size(keyword()) :: pos_integer()
```

Returns the model's output size, i.e. the backbone's `:hidden_size` (the dimensionality of the hidden states that each prediction head projects from).
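For example, with an explicit `:hidden_size` this should simply echo that value back (a sketch based on the description above):

```elixir
MultiTokenPrediction.output_size(hidden_size: 512)
# => 512
```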

---

*Consult [api-reference.md](api-reference.md) for the complete listing.*
