Speculative Decoding — accelerate autoregressive generation with draft+verify.
Coordinates a small "draft" model and a large "verifier" model to speed up inference. The draft model cheaply generates K candidate tokens, which the verifier then checks in a single forward pass. Each accepted token saves one of the verifier's expensive autoregressive steps.
How It Works
- Draft model generates K candidate tokens autoregressively (fast)
- Verifier scores all K tokens in one forward pass (parallel)
- Accept longest prefix where draft and verifier agree
- Continue from first disagreement
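The four steps above can be sketched as a single draft+verify round over plain lists of token IDs. This is an illustration only, not the module's implementation: `SpeculativeStepSketch`, `draft_next`, and `verify_all` are hypothetical names, and continuing from the verifier's own token at the first mismatch is an assumption about what "continue from first disagreement" means.

```elixir
defmodule SpeculativeStepSketch do
  # One draft+verify round over plain lists of token IDs (assumes k >= 1).
  # `draft_next` and `verify_all` are hypothetical callbacks standing in for
  # the two models; they are not part of the documented API.
  def step(context, k, draft_next, verify_all) do
    # 1. Draft K tokens one at a time, feeding each back into the context.
    {draft_tokens, _ctx} =
      Enum.map_reduce(1..k, context, fn _i, ctx ->
        token = draft_next.(ctx)
        {token, ctx ++ [token]}
      end)

    # 2. The verifier returns its own choice for all K positions in one call.
    verifier_tokens = verify_all.(context, draft_tokens)

    # 3. Accept the longest prefix on which both models agree.
    accepted =
      draft_tokens
      |> Enum.zip(verifier_tokens)
      |> Enum.take_while(fn {d, v} -> d == v end)
      |> Enum.map(fn {d, _v} -> d end)

    # 4. Assumption: at the first disagreement, continue from the verifier's token.
    correction =
      verifier_tokens
      |> Enum.drop(length(accepted))
      |> Enum.take(1)

    context ++ accepted ++ correction
  end
end

# Toy stand-ins: the draft always counts up; the verifier counts up to 3, then emits 0.
draft_next = fn ctx -> List.last(ctx) + 1 end

verify_all = fn ctx, draft ->
  [List.last(ctx) | Enum.drop(draft, -1)]
  |> Enum.map(fn prev -> if prev < 3, do: prev + 1, else: 0 end)
end

IO.inspect(SpeculativeStepSketch.step([1], 4, draft_next, verify_all))
```

With these toy callbacks the draft proposes 2, 3, 4, 5 and the verifier answers 2, 3, 0, 0, so two tokens are accepted and the round returns `[1, 2, 3, 0]`.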
Architecture
build/1 returns {draft_model, verifier_model}
Draft Model (small, fast):
[batch, seq_len, embed_dim] → DecoderOnly (few layers) → [batch, hidden]
Verifier Model (large, accurate):
[batch, seq_len, embed_dim] → DecoderOnly (many layers) → [batch, hidden]
Usage
{draft, verifier} = SpeculativeDecoding.build(
  embed_dim: 256,
  draft_model_opts: [hidden_size: 128, num_layers: 2],
  verifier_model_opts: [hidden_size: 512, num_layers: 8]
)
# At inference time:
accepted = SpeculativeDecoding.accept_reject(draft_tokens, verifier_tokens)
References
- "Fast Inference from Transformers via Speculative Decoding" (Leviathan et al., 2023) — https://arxiv.org/abs/2211.17192
Summary
Functions
Accept-reject step: compare draft tokens to verifier tokens.
Build draft and verifier models for speculative decoding.
Get the output size (delegates to verifier model's hidden_size).
Functions
@spec accept_reject(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Accept-reject step: compare draft tokens to verifier tokens.
Finds the longest prefix on which the draft and verifier tokens match and returns its length, i.e. the number of accepted tokens (0 means the first token already mismatched).
Parameters
- draft_tokens - Tensor of draft token IDs, [K] or [batch, K]
- verifier_tokens - Tensor of verifier token IDs, [K] or [batch, K]
Returns
Integer tensor with the number of accepted tokens per batch element.
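One way to compute such a per-row prefix length with Nx is an equality mask followed by a cumulative product, which zeroes out everything after the first mismatch. The sketch below is an assumption about how this could be done, not the library's actual code; `AcceptRejectSketch` is a hypothetical module and the Nx version pin is illustrative.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule AcceptRejectSketch do
  # Hypothetical stand-in for accept_reject/2: counts leading matches
  # along the last axis, so it handles both [K] and [batch, K] inputs.
  def accepted_count(draft_tokens, verifier_tokens) do
    last_axis = Nx.rank(draft_tokens) - 1

    draft_tokens
    # 1 where the tokens agree, 0 where they differ
    |> Nx.equal(verifier_tokens)
    # the running product stays 1 until the first mismatch, then 0 forever
    |> Nx.cumulative_product(axis: last_axis)
    # summing the surviving 1s gives the accepted prefix length
    |> Nx.sum(axes: [last_axis])
  end
end

draft = Nx.tensor([[5, 9, 2, 7], [1, 1, 1, 1]])
verifier = Nx.tensor([[5, 9, 2, 4], [3, 1, 1, 1]])

IO.inspect(AcceptRejectSketch.accepted_count(draft, verifier))
```

The first row matches on its first three tokens, so its count is 3. The second row mismatches at position 0, so its count is 0 even though the later tokens agree.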
Build draft and verifier models for speculative decoding.
Options
- :embed_dim - Input embedding dimension (required)
- :draft_model_type - Draft architecture (default: :decoder_only)
- :draft_model_opts - Options for draft model (default: small config)
- :verifier_model_type - Verifier architecture (default: :decoder_only)
- :verifier_model_opts - Options for verifier model (default: larger config)
Returns
{draft_model, verifier_model} — a 2-tuple of Axon models.
@spec output_size(keyword()) :: pos_integer()
Get the output size (delegates to verifier model's hidden_size).