Multi-Token Prediction (MTP) — predict multiple future tokens simultaneously.
Wraps a backbone transformer (DecoderOnly by default) with multiple independent prediction heads. Each head projects the backbone's hidden states to vocabulary logits for a different future position.
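The sketch below builds per-head training targets from one token sequence to make "a different future position" concrete. It assumes the head behind output key :pred_k is trained on the token k positions ahead, which is the convention in Gloeckle et al. MTPTargets is a hypothetical helper and is not part of this module.

```elixir
defmodule MTPTargets do
  # Hypothetical helper, not part of the documented API. Assumes head k
  # (output key :pred_k) is trained on the token k positions ahead.

  @doc """
  Returns %{pred_1: [...], ..., pred_N: [...]}. Entry i of each list is the
  target for input position i. The last k positions have no token k steps
  ahead, so the list for :pred_k is k elements shorter than `tokens`.
  """
  def build(tokens, num_predictions) when is_list(tokens) and num_predictions >= 1 do
    for k <- 1..num_predictions, into: %{} do
      # Dropping the first k tokens aligns input position i with token i + k.
      {String.to_atom("pred_#{k}"), Enum.drop(tokens, k)}
    end
  end
end

MTPTargets.build([10, 11, 12, 13, 14], 3)
# => %{pred_1: [11, 12, 13, 14], pred_2: [12, 13, 14], pred_3: [13, 14]}
```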
Key Innovation: Parallel Next-Token Heads
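Instead of predicting only the next token, MTP attaches N independent dense layers to the backbone output, and each layer predicts a different future position. Every position is supervised on N future tokens instead of one, which gives the backbone a richer training signal. At inference time, the extra heads can draft tokens for speculative decoding, and the next-token prediction then verifies those drafts.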
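The speculative-decoding use can be illustrated with the greedy verification step of self-speculative decoding. This module only builds the model, so the surrounding decoding loop is not shown, and SpecAccept is hypothetical. The sketch makes two assumptions about its inputs. First, `draft` holds the N tokens the heads proposed for positions t+1..t+N. Second, `verified` holds the next-token head's greedy choices from one forward pass over prefix ++ draft. That pass yields N + 1 predictions, and the last one is for position t+N+1.

```elixir
defmodule SpecAccept do
  # Hypothetical sketch of greedy draft verification (not part of the library).

  @doc """
  Returns the tokens to append: the longest prefix of `draft` that agrees with
  `verified`, then the verifier's own token at the first disagreement, or the
  extra (N + 1)-th prediction when every draft token was accepted.
  """
  def accept(draft, verified), do: do_accept(draft, verified, [])

  # Draft token matches the verifier's greedy choice: accept it and continue.
  defp do_accept([t | ds], [t | vs], acc), do: do_accept(ds, vs, [t | acc])
  # First mismatch: keep the verifier's token instead and stop.
  defp do_accept([_ | _], [v | _], acc), do: Enum.reverse([v | acc])
  # All drafts accepted: the verifier's final prediction comes for free.
  defp do_accept([], [v | _], acc), do: Enum.reverse([v | acc])
  defp do_accept([], [], acc), do: Enum.reverse(acc)
end

SpecAccept.accept([5, 7, 9], [5, 7, 2, 4])
# => [5, 7, 2]
```

Each step appends between one and N + 1 tokens, and it costs a single verification pass.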
Architecture
Input [batch, seq_len, embed_dim]
|
Backbone (output_mode: :all)
|
[batch, seq_len, hidden_size]
|
+-- Head 1: dense(vocab_size) -> pred_1 [batch, seq_len, vocab_size]
+-- Head 2: dense(vocab_size) -> pred_2 [batch, seq_len, vocab_size]
+-- ...
+-- Head N: dense(vocab_size) -> pred_N [batch, seq_len, vocab_size]
|
Axon.container(%{pred_1: h1, pred_2: h2, ..., pred_N: hN})
Usage
model = MultiTokenPrediction.build(
embed_dim: 256,
vocab_size: 32000,
num_predictions: 4
)
References
- "Better & Faster Large Language Models via Multi-token Prediction" (Gloeckle et al., 2024), https://arxiv.org/abs/2404.19737
Types
@type build_opt() ::
        {:embed_dim, pos_integer()}
        | {:vocab_size, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:num_kv_heads, pos_integer()}
        | {:num_predictions, pos_integer()}
        | {:dropout, float()}
Options for build/1.
Functions
Build a Multi-Token Prediction model.
Options
- :embed_dim - Input embedding dimension (required)
- :vocab_size - Vocabulary size for each prediction head (required)
- :hidden_size - Backbone hidden dimension (default: 256)
- :num_layers - Number of backbone transformer layers (default: 4)
- :num_heads - Number of attention heads (default: 4)
- :num_kv_heads - Number of key/value heads for GQA (default: 2)
- :num_predictions - Number of future tokens to predict (default: 4)
- :dropout - Dropout rate (default: 0.1)
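The following is a minimal sketch of how the documented defaults could combine with caller options. It assumes plain keyword-list handling through `Keyword.validate!/2`, which requires Elixir 1.13 or later. MTPOpts is hypothetical, and build/1 may validate its options differently.

```elixir
defmodule MTPOpts do
  # Hypothetical sketch mirroring the documented defaults; the real build/1
  # may handle options differently.
  @defaults [
    hidden_size: 256,
    num_layers: 4,
    num_heads: 4,
    num_kv_heads: 2,
    num_predictions: 4,
    dropout: 0.1
  ]

  # Fills in defaults and rejects unknown keys (ArgumentError). It also
  # requires the two options that have no default (KeyError if either is missing).
  def resolve(opts) do
    opts = Keyword.validate!(opts, [:embed_dim, :vocab_size | @defaults])
    Keyword.fetch!(opts, :embed_dim)
    Keyword.fetch!(opts, :vocab_size)
    opts
  end

  # Mirrors output_size/1 as documented: the backbone's hidden size.
  def output_size(opts), do: Keyword.get(opts, :hidden_size, @defaults[:hidden_size])
end

opts = MTPOpts.resolve(embed_dim: 256, vocab_size: 32000)
Keyword.get(opts, :num_predictions)
# => 4
```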
Returns
An Axon.container with keys :pred_1 through :pred_N, where N is
:num_predictions. Each value is shaped [batch, seq_len, vocab_size].
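A training loss over this container can be formed by combining one cross-entropy term per key. The sketch below takes the mean cross-entropy of each head and adds the heads together without weights, which mirrors the unweighted sum over heads in Gloeckle et al. It uses plain lists instead of Nx tensors so that it runs without dependencies. MTPLoss is hypothetical, and this module does not document its own training loss.

```elixir
defmodule MTPLoss do
  # Hypothetical sketch using plain lists (no Nx): the multi-token loss as an
  # unweighted sum over heads of each head's mean cross-entropy.

  # log p(target) under a numerically stable log-softmax of one logit vector.
  defp log_prob(logits, target) do
    peak = Enum.max(logits)
    log_sum_exp = peak + :math.log(Enum.sum(Enum.map(logits, &:math.exp(&1 - peak))))
    Enum.at(logits, target) - log_sum_exp
  end

  @doc """
  `outputs` maps each key (:pred_1, ..., :pred_N) to one logit list per position.
  `targets` maps the same keys to token ids. A target list may be shorter than
  its outputs, because the last k positions have no token k steps ahead.
  """
  def loss(outputs, targets) do
    outputs
    |> Enum.map(fn {key, logits_per_position} ->
      # Enum.zip/2 stops at the shorter list, so untargeted positions are skipped.
      nlls =
        logits_per_position
        |> Enum.zip(Map.fetch!(targets, key))
        |> Enum.map(fn {logits, target} -> -log_prob(logits, target) end)

      if nlls == [], do: 0.0, else: Enum.sum(nlls) / length(nlls)
    end)
    |> Enum.sum()
  end
end

MTPLoss.loss(%{pred_1: [[0.0, 0.0]], pred_2: [[0.0, 0.0]]}, %{pred_1: [1], pred_2: [0]})
# about 1.386 (2 * ln 2: two heads, each uniform over 2 tokens)
```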
@spec output_size(keyword()) :: pos_integer()
Get the output size of the model: the backbone's :hidden_size (default: 256).