Medusa: Multi-Head Speculative Decoding for 2-3x inference speedup.
Medusa attaches K lightweight "draft heads" to an existing LM backbone. Each head is a small MLP that predicts one future token position from the last hidden state, sharing the base model's vocabulary (embedding table). At inference the K heads generate a tree of candidate continuations, which are verified in a single forward pass via "tree attention". Tokens are accepted greedily along the tree path that agrees with the base model's distribution.
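The greedy accept/reject step described above can be sketched independently of this module's API. The following is a minimal Python/NumPy illustration (the function name `accepted_prefix_len` is hypothetical, not part of this library): walk one candidate path and accept tokens while each matches the base model's argmax at that tree position.

```python
import numpy as np

def accepted_prefix_len(path_tokens, base_logits):
    """Greedy verification sketch: accept tokens along one candidate path
    while they match the base model's argmax at each position; stop at the
    first mismatch. `base_logits` has one row of vocab logits per position."""
    accepted = 0
    for tok, logits in zip(path_tokens, base_logits):
        if tok != int(np.argmax(logits)):
            break
        accepted += 1
    return accepted

# Base model "agrees" with the first two drafted tokens but not the third
logits = np.zeros((3, 10))
logits[0, 3] = 1.0
logits[1, 5] = 1.0
logits[2, 2] = 1.0
assert accepted_prefix_len([3, 5, 7], logits) == 2
```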
Motivation
Standard speculative decoding (e.g. SpeculativeDecoding) requires training
and serving a separate small draft model. Medusa eliminates this overhead
by co-training K heads that are already aligned with the base model's
hidden states. Because all K candidate tokens for all positions are
verified in one batched forward pass, several tokens can be accepted per
base-model call, so the number of sequential forward passes grows
sub-linearly with the number of generated tokens.
Architecture
Base model (frozen or fine-tuned)
|
last hidden state [batch, hidden_dim]
|
+----------- K Medusa heads (parallel) -----------+
| Head k: dense(hidden_dim) -> SiLU -> dense(vocab_size) |
+--------------------------------------------------+
| | |
head_logits_1 head_logits_2 ... head_logits_K
|
build_tree_candidates/2
|
tree of candidate token seqs (top-k per head, combined)
|
tree_decoding_mask/1 ← causal attention mask for tree positions
|
one base-model forward pass (tree attention)
|
accept/reject each candidate path
Usage
model = Medusa.build(
base_hidden_dim: 256,
vocab_size: 32_000,
num_medusa_heads: 4
)
# At inference: generate candidates then verify
{cands, tree} = Medusa.build_tree_candidates(head_logits, top_k: 5)
mask = Medusa.tree_decoding_mask(tree)
References
- Cai et al., "Medusa: Simple Framework for Accelerating LLM Inference with Multiple Decoding Heads" (2024) — https://arxiv.org/abs/2401.10774
Summary
Functions
Build Medusa draft heads as an Axon container.
Generate a tree of candidate token continuations from K head logits.
Build K Medusa head logits from hidden states.
Get output size (vocab_size passed through heads).
Build a tree attention mask for verifying all candidates in one pass.
Types
@type build_opt() :: {:base_hidden_dim, pos_integer()} | {:vocab_size, pos_integer()} | {:num_medusa_heads, pos_integer()} | {:medusa_num_layers, pos_integer()}
Options for build/1.
Functions
Build Medusa draft heads as an Axon container.
Returns an Axon.container map with keys :head_1 through :head_K,
each shaped [batch, vocab_size]. The input is the last hidden state
from the base model.
Options
:base_hidden_dim - Hidden dimension of the base model (required)
:vocab_size - Vocabulary size, shared with base model (required)
:num_medusa_heads - Number of speculative heads K (default: 4)
:medusa_num_layers - Dense layers per head, 1 = simple linear (default: 1)
Returns
An Axon.container map %{head_1: ..., ..., head_K: ...}, each
[batch, vocab_size].
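Each head follows the dense -> SiLU -> dense shape shown in the architecture diagram. As a minimal sketch of that computation in Python/NumPy (independent of the Axon graph; weights here are random placeholders, not a trained model):

```python
import numpy as np

def silu(x):
    # SiLU / swish activation: x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def medusa_head(hidden, w1, b1, w2, b2):
    """One draft head: dense(hidden_dim) -> SiLU -> dense(vocab_size)."""
    return silu(hidden @ w1 + b1) @ w2 + b2

rng = np.random.default_rng(0)
batch, hidden_dim, vocab_size, K = 2, 64, 1_000, 4
hidden = rng.standard_normal((batch, hidden_dim))

# K independent heads all read the same last hidden state
heads = [
    (rng.standard_normal((hidden_dim, hidden_dim)) * 0.02, np.zeros(hidden_dim),
     rng.standard_normal((hidden_dim, vocab_size)) * 0.02, np.zeros(vocab_size))
    for _ in range(K)
]
logits = {f"head_{k + 1}": medusa_head(hidden, *heads[k]) for k in range(K)}
assert logits["head_1"].shape == (batch, vocab_size)
```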
@spec build_tree_candidates( %{required(atom()) => Nx.Tensor.t()}, keyword() ) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Generate a tree of candidate token continuations from K head logits.
Each head independently picks its top_k highest-probability tokens. The
candidates are combined into a flat list of length-K token sequences for
tree verification.
In the full Medusa algorithm these form a Cartesian-product tree; here we
return:
candidates - tensor [num_candidates, K] of token ID sequences, where each row is one path of length K through the tree.
tree_indices - 1-D integer tensor of length num_candidates giving each candidate's index into the flattened token tree (used to build the mask).
The simplest tree structure takes the top-k tokens from head 1, top-k from
head 2, ..., and enumerates all top_k^K combinations (clamped to a
maximum of top_k * K candidates for efficiency).
Parameters
head_logits - Map %{head_1: tensor, ..., head_K: tensor}, each [batch, vocab_size]. Only the first batch element is used.
Options
:top_k - Number of top tokens to consider per head (default: 5)
Returns
{candidates, tree_indices} where candidates is [num_cands, K] and
tree_indices is [num_cands].
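The candidate construction described above (top-k per head, Cartesian product, clamped to top_k * K rows) can be sketched in Python/NumPy. This is an illustrative re-implementation of the documented behavior, not the module's actual code:

```python
import itertools
import numpy as np

def build_tree_candidates(head_logits, top_k=5):
    """Sketch: pick top_k token ids per head, enumerate combinations in
    order, and clamp to at most top_k * K candidate paths."""
    K = len(head_logits)
    # Top-k token ids per head, first batch element only
    topk_ids = [np.argsort(head_logits[f"head_{k + 1}"][0])[::-1][:top_k]
                for k in range(K)]
    max_cands = top_k * K
    rows = list(itertools.islice(itertools.product(*topk_ids), max_cands))
    candidates = np.array(rows, dtype=np.int64)    # [num_cands, K]
    tree_indices = np.arange(candidates.shape[0])  # [num_cands]
    return candidates, tree_indices
```

With K = 2 heads and top_k = 3 there are 3^2 = 9 combinations, clamped to 3 * 2 = 6 candidate rows.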
Build K Medusa head logits from hidden states.
A helper that runs all K heads on hidden_states and returns a list of
logit tensors (each [batch, vocab_size]). Useful when composing Medusa
heads into a larger model graph.
Parameters
hidden_states - Axon node [batch, hidden_dim]
opts - Same as build/1
Returns
Axon.container map with keys :head_1..:head_K.
@spec output_size(keyword()) :: pos_integer()
Get output size (vocab_size passed through heads).
@spec tree_decoding_mask(Nx.Tensor.t()) :: Nx.Tensor.t()
Build a tree attention mask for verifying all candidates in one pass.
Given tree candidates [num_cands, K], each candidate is a length-K path.
The mask ensures that when the base model processes all K token positions
for all num_cands candidates, each position can only attend to its
ancestors in the tree (causal within each path, no cross-path attention).
Returns a boolean mask [num_cands * K, num_cands * K] where entry
[i, j] is true iff position j is an ancestor of position i
(including self).
Parameters
tree_candidates- Token candidate tensor[num_cands, K]
Returns
Boolean mask [num_cands * K, num_cands * K].
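For the flat list of candidate paths this module produces, "ancestors in the tree" reduces to: same path, earlier or equal position. A minimal Python/NumPy sketch of that mask logic (illustrative, not the module's implementation) for candidates flattened row-major:

```python
import numpy as np

def tree_decoding_mask(tree_candidates):
    """Sketch: boolean mask [num_cands * K, num_cands * K] where [i, j] is
    true iff j is on the same path as i and at position <= i's position
    (causal within each path, no cross-path attention)."""
    num_cands, K = tree_candidates.shape
    pos = np.arange(num_cands * K)
    same_path = (pos[:, None] // K) == (pos[None, :] // K)
    ancestor = (pos[None, :] % K) <= (pos[:, None] % K)
    return same_path & ancestor
```

With 2 candidates of length 3, this yields a block-diagonal mask of two 3x3 lower-triangular blocks.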