Edifice.Inference.Medusa (Edifice v0.2.0)

Medusa: Multi-Head Speculative Decoding for a 2-3x inference speedup.

Medusa attaches K lightweight "draft heads" to an existing LM backbone. Each head is a small MLP that predicts the token at one future position from the last hidden state, sharing the base model's vocabulary (embedding table). At inference, the K heads' top predictions are combined into a tree of candidate continuations, which is verified in a single base-model forward pass via "tree attention". Tokens are accepted greedily along the tree path that agrees with the base model's distribution.
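As a rough illustration of the greedy acceptance rule, the plain-Elixir sketch below works on token-ID lists rather than Nx tensors. It assumes each drafted token is compared against the base model's greedy (argmax) token for that position, as produced by the verification pass; the module and function names (MedusaAcceptSketch, accepted_length/2, best_path/2) are hypothetical and not part of Edifice.

```elixir
defmodule MedusaAcceptSketch do
  @moduledoc "Hypothetical sketch of greedy Medusa acceptance on plain lists (not Edifice API)."

  # Length of the longest prefix of `path` whose drafted tokens equal the
  # base model's greedy predictions at the same positions.
  def accepted_length(path, base_preds) do
    path
    |> Enum.zip(base_preds)
    |> Enum.take_while(fn {drafted, verified} -> drafted == verified end)
    |> length()
  end

  # Returns {path_index, accepted_tokens} for the candidate path with the
  # longest accepted prefix. Ties go to the earliest path; raises on an
  # empty candidate list (Enum.max_by/2 behaviour).
  def best_path(candidates, base_preds_per_path) do
    candidates
    |> Enum.zip(base_preds_per_path)
    |> Enum.with_index()
    |> Enum.map(fn {{path, preds}, idx} ->
      {idx, Enum.take(path, accepted_length(path, preds))}
    end)
    |> Enum.max_by(fn {_idx, accepted} -> length(accepted) end)
  end
end
```

For candidates [[5, 9, 2], [5, 7, 1], [3, 7, 1]] with base predictions [[5, 9, 4], [5, 7, 1], [5, 7, 1]], best_path/2 returns {1, [5, 7, 1]}: path 1 agrees with the base model at every position.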

Motivation

Standard speculative decoding (e.g. SpeculativeDecoding) requires a separate small draft model. Medusa eliminates this overhead by co-training K heads on the base model's own hidden states, so the drafts are already aligned with it. Because all candidate paths are verified in one batched forward pass and several drafted tokens can be accepted per pass, generating a sequence takes fewer sequential base-model forward passes than it has tokens.

Architecture

Base model (frozen or fine-tuned)
      |
last hidden state [batch, hidden_dim]
      |
+-------------- K Medusa heads (parallel) ---------------+
| Head k: dense(hidden_dim) -> SiLU -> dense(vocab_size) |
+--------------------------------------------------------+
      |                |             |
head_logits_1   head_logits_2  ...  head_logits_K
      |
build_tree_candidates/2
      |
tree of candidate token seqs (top-k per head, combined)
      |
tree_decoding_mask/1   causal attention mask for tree positions
      |
one base-model forward pass (tree attention)
      |
accept/reject each candidate path
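A single head, read literally from the diagram (dense to hidden_dim, SiLU, dense to vocab_size), can be sketched with plain lists for a batch of one. This is a hypothetical illustration: MedusaHeadSketch is not an Edifice module, it ignores the medusa_num_layers option, and the real heads are Axon dense layers operating on Nx tensors.

```elixir
defmodule MedusaHeadSketch do
  @moduledoc "Hypothetical plain-list sketch of one Medusa head: dense -> SiLU -> dense."

  # SiLU(x) = x * sigmoid(x) = x / (1 + e^-x)
  def silu(x), do: x / (1.0 + :math.exp(-x))

  # Dense layer on one vector: `w` is a list of out_dim rows, each with
  # in_dim weights; `b` is the list of out_dim biases.
  def dense(x, w, b) do
    w
    |> Enum.zip(b)
    |> Enum.map(fn {row, bias} ->
      row
      |> Enum.zip(x)
      |> Enum.reduce(bias, fn {wi, xi}, acc -> acc + wi * xi end)
    end)
  end

  # hidden [hidden_dim] -> dense(hidden_dim) -> SiLU -> dense(vocab_size) logits
  def head(hidden, {w1, b1}, {w2, b2}) do
    hidden
    |> dense(w1, b1)
    |> Enum.map(&silu/1)
    |> dense(w2, b2)
  end
end
```

With an identity first layer, head([1.0, -1.0], ...) applies SiLU element-wise and then projects the result to vocab_size logits with the second layer.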

Usage

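alias Edifice.Inference.Medusa

model = Medusa.build(
  base_hidden_dim: 256,
  vocab_size: 32_000,
  num_medusa_heads: 4
)

# At inference: run model on the base model's last hidden state to get
# head_logits (%{head_1: ..., head_4: ...}), then generate candidates and verify
{cands, _tree_indices} = Medusa.build_tree_candidates(head_logits, top_k: 5)
mask = Medusa.tree_decoding_mask(cands)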

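Summary

Types

build_opt()
Options for build/1.

Functions

build(opts \\ [])
Build Medusa draft heads as an Axon container.

build_tree_candidates(head_logits, opts \\ [])
Generate a tree of candidate token continuations from K head logits.

medusa_heads(hidden_states, opts)
Build K Medusa head logits from hidden states.

output_size(opts)
Get output size (vocab_size passed through heads).

tree_decoding_mask(tree_candidates)
Build a tree attention mask for verifying all candidates in one pass.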

Types

build_opt()

@type build_opt() ::
  {:base_hidden_dim, pos_integer()}
  | {:vocab_size, pos_integer()}
  | {:num_medusa_heads, pos_integer()}
  | {:medusa_num_layers, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build Medusa draft heads as an Axon container.

Returns an Axon.container map with keys :head_1 through :head_K, each shaped [batch, vocab_size]. The input is the last hidden state from the base model.

Options

  • :base_hidden_dim - Hidden dimension of the base model (required)
  • :vocab_size - Vocabulary size, shared with base model (required)
  • :num_medusa_heads - Number of speculative heads K (default: 4)
  • :medusa_num_layers - Dense layers per head, 1 = simple linear (default: 1)

Returns

An Axon.container map %{head_1: ..., ..., head_K: ...}, each [batch, vocab_size].
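Because the container's keys are atoms, code that needs the heads in order should fetch them by index rather than rely on Map.keys/1 or Map.values/1, whose order follows term order (so :head_10 sorts before :head_2) and is not guaranteed in general. A minimal sketch, assuming the map has exactly the keys :head_1 through :head_K; MedusaHeadOrder.ordered/1 is a hypothetical helper, not Edifice API.

```elixir
defmodule MedusaHeadOrder do
  @moduledoc "Hypothetical helper: read Medusa head outputs in head order."

  # Atom keys compare alphabetically, so :head_10 sorts before :head_2.
  # Fetching by index yields head_1, head_2, ..., head_K regardless of K.
  def ordered(head_map) do
    k = map_size(head_map)
    Enum.map(1..k//1, fn i -> Map.fetch!(head_map, :"head_#{i}") end)
  end
end
```

Map.fetch!/2 raises if a key in the expected :head_1..:head_K range is missing, which surfaces a mismatched num_medusa_heads early.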

build_tree_candidates(head_logits, opts \\ [])

@spec build_tree_candidates(
  %{required(atom()) => Nx.Tensor.t()},
  keyword()
) :: {Nx.Tensor.t(), Nx.Tensor.t()}

Generate a tree of candidate token continuations from K head logits.

Each head independently picks its top_k highest-scoring tokens. The candidates are combined into a flat list of length-K token sequences (one token from each head) for tree verification. In the full Medusa algorithm these form a Cartesian-product tree; here we return:

  • candidates — tensor [num_candidates, K] of token ID sequences, where each row is one path of length K through the tree.
  • tree_indices — 1-D integer tensor of length num_candidates giving each candidate's index into the flattened token tree (used to build the mask).

The simplest tree structure takes the top_k tokens from each of heads 1 through K and enumerates all top_k^K combinations, clamped to at most top_k * K candidates for efficiency.
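With the default top_k: 5 and K = 4 heads, the full product has 5^4 = 625 paths, which the clamp reduces to 5 * 4 = 20 candidates. The plain-list sketch below shows the per-head top-k selection, the Cartesian product, and the clamp. Which paths Edifice keeps after clamping is not specified here, so the sketch simply keeps the first top_k * K in lexicographic order (an assumption), and it omits tree_indices. MedusaTreeSketch is a hypothetical module.

```elixir
defmodule MedusaTreeSketch do
  @moduledoc "Hypothetical plain-list sketch of Cartesian-product tree candidates."

  # Indices of the k largest logits, highest first (stable on ties).
  def top_k(logits, k) do
    logits
    |> Enum.with_index()
    |> Enum.sort_by(fn {value, _idx} -> value end, :desc)
    |> Enum.take(k)
    |> Enum.map(fn {_value, idx} -> idx end)
  end

  # head_logits: list of K logit lists (batch element 0 of each head).
  # Enumerates all top_k^K paths, then keeps the first top_k * K of them
  # (this truncation order is an assumption, not Edifice's documented rule).
  def candidates(head_logits, top_k) do
    max_cands = top_k * length(head_logits)

    head_logits
    |> Enum.map(&top_k(&1, top_k))
    |> cartesian()
    |> Enum.take(max_cands)
  end

  # Cartesian product of a list of lists, in lexicographic order.
  defp cartesian([]), do: [[]]

  defp cartesian([choices | rest]) do
    for token <- choices, tail <- cartesian(rest), do: [token | tail]
  end
end
```

For two heads with logits [0.1, 0.9, 0.5] and [0.7, 0.2, 0.3] and top_k = 2, candidates/2 returns [[1, 0], [1, 2], [2, 0], [2, 2]]; with a third head the 8 paths are clamped to 6.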

Parameters

  • head_logits - Map %{head_1: tensor, ..., head_K: tensor}, each [batch, vocab_size]. Only the first batch element is used.

Options

  • :top_k - Number of top tokens to consider per head (default: 5)

Returns

{candidates, tree_indices} where candidates is [num_cands, K] and tree_indices is [num_cands].

medusa_heads(hidden_states, opts)

@spec medusa_heads(
  Axon.t(),
  keyword()
) :: Axon.t()

Build K Medusa head logits from hidden states.

A helper that runs all K heads on hidden_states and returns an Axon.container map of logit tensors (:head_1 through :head_K, each [batch, vocab_size]). Useful when composing Medusa heads into a larger model graph.

Parameters

  • hidden_states - Axon node [batch, hidden_dim]
  • opts - Same as build/1

Returns

Axon.container map with keys :head_1..:head_K.

output_size(opts)

@spec output_size(keyword()) :: pos_integer()

Get the output size, which is vocab_size: every head emits logits over the base model's vocabulary.

tree_decoding_mask(tree_candidates)

@spec tree_decoding_mask(Nx.Tensor.t()) :: Nx.Tensor.t()

Build a tree attention mask for verifying all candidates in one pass.

In tree candidates [num_cands, K], each row is a length-K path. The mask ensures that when the base model processes all K token positions for all num_cands candidates, each position can attend only to its ancestors in the tree (causal within each path, no cross-path attention).

Returns a boolean mask [num_cands * K, num_cands * K] where entry [i, j] is true iff position j is an ancestor of position i (including self).
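The mask rule above (same path and j <= i) gives a block-diagonal matrix of lower-triangular K x K blocks. The sketch below builds it as nested boolean lists, assuming the flattened positions are laid out path-major (position p = path * K + step). MedusaMaskSketch is hypothetical, whereas tree_decoding_mask/1 returns an Nx tensor.

```elixir
defmodule MedusaMaskSketch do
  @moduledoc "Hypothetical plain-list sketch of the per-path causal tree mask."

  # Rows and columns index flattened positions p = path * k + step.
  # Entry [i][j] is true iff i and j lie on the same path and j <= i
  # (j is an ancestor of i, or i itself); no cross-path attention.
  def mask(num_cands, k) do
    n = num_cands * k

    for i <- 0..(n - 1)//1 do
      for j <- 0..(n - 1)//1 do
        div(i, k) == div(j, k) and j <= i
      end
    end
  end
end
```

For 2 candidates of length 2, mask(2, 2) returns [[true, false, false, false], [true, true, false, false], [false, false, true, false], [false, false, true, true]].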

Parameters

  • tree_candidates - Token candidate tensor [num_cands, K]

Returns

Boolean mask [num_cands * K, num_cands * K].