# `Edifice.Inference.Medusa`

Medusa: Multi-Head Speculative Decoding for 2-3x inference speedup.

Medusa attaches K lightweight "draft heads" to an existing LM backbone.
Each head is a small MLP that predicts the token at one future position from
the last hidden state, sharing the base model's vocabulary (every head outputs
logits over the same `vocab_size` tokens). At inference the K heads generate a
tree of candidate continuations, which are verified in a single forward pass
via "tree attention". Tokens are accepted greedily: the candidate path whose
prefix agrees longest with the base model's own predictions is kept.
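Greedy acceptance can be sketched on plain Elixir lists. This is a hypothetical illustration (`MedusaAcceptSketch` is not part of this module) that assumes the tree-attention pass has already produced, for each candidate path, the base model's argmax token at each of the K positions plus one position past the end. The path with the longest agreeing prefix wins, and the base model's own prediction at the first mismatch is appended as a "bonus" token, since it comes for free from the same forward pass.

```elixir
defmodule MedusaAcceptSketch do
  # Hypothetical helper, not part of Edifice.Inference.Medusa.

  # Number of leading draft tokens that match the base model's predictions.
  # Enum.zip/2 stops at the shorter list, so the extra (K + 1)-th prediction
  # is ignored here.
  def accepted_length(draft, base_preds) do
    draft
    |> Enum.zip(base_preds)
    |> Enum.take_while(fn {d, b} -> d == b end)
    |> length()
  end

  # candidates:          list of paths, each a list of K draft token IDs
  # base_preds_per_path: for each path, K + 1 base-model argmax tokens
  # Returns {best_path_index, emitted_tokens}: the matched prefix of the best
  # path plus the base model's prediction at the first mismatch (or after the
  # whole path). Enum.max_by/2 keeps the first path on ties.
  def accept(candidates, base_preds_per_path) do
    {{draft, preds}, idx} =
      candidates
      |> Enum.zip(base_preds_per_path)
      |> Enum.with_index()
      |> Enum.max_by(fn {{d, p}, _i} -> accepted_length(d, p) end)

    n = accepted_length(draft, preds)
    {idx, Enum.take(draft, n) ++ [Enum.at(preds, n)]}
  end
end
```

With `candidates = [[5, 9, 2], [5, 7, 1]]` and base predictions `[5, 7, 4, 8]` for both paths, the second path wins and `[5, 7, 4]` is emitted: two drafted tokens plus the bonus token. Real Medusa implementations may instead use a sampling-based ("typical") acceptance rule; this sketch covers only the greedy case described above.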

## Motivation

Standard speculative decoding (e.g. `SpeculativeDecoding`) requires a
separate, smaller draft model that must be trained, stored, and run alongside
the base model. Medusa removes that extra model by training K heads directly
on the base model's hidden states, so the drafts are already aligned with it.
Because all candidate continuations are verified in one batched forward pass,
a single decoding step can accept several tokens, which reduces the number of
sequential base-model forward passes needed to generate a sequence.

## Architecture

```
Base model (frozen or fine-tuned)
      |
last hidden state [batch, hidden_dim]
      |
+-------------- K Medusa heads (parallel) ---------------+
| Head k: dense(hidden_dim) -> SiLU -> dense(vocab_size) |
+--------------------------------------------------------+
      |                |             |
head_logits_1   head_logits_2  ...  head_logits_K
      |
build_tree_candidates/2
      |
tree of candidate token seqs (top-k per head, combined)
      |
tree_decoding_mask/1  ← causal attention mask for tree positions
      |
one base-model forward pass (tree attention)
      |
accept/reject each candidate path
```
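The per-head computation in the diagram can be written out on plain lists to make the shapes concrete. This is a toy sketch: the module name and the row-per-output weight layout are assumptions, and the real heads run on `Nx` tensors through `Axon`. It follows the diagram literally; the Medusa paper's heads also add a residual connection, `SiLU(W1 h) + h`, before the vocabulary projection, and whether this module does so is not stated here.

```elixir
defmodule MedusaHeadSketch do
  # SiLU (swish): x * sigmoid(x) = x / (1 + e^-x)
  def silu(x), do: x / (1.0 + :math.exp(-x))

  # Dense layer on lists: `w` holds one row of input weights per output unit.
  def dense(x, w, b) do
    Enum.zip_with(w, b, fn row, bias ->
      row
      |> Enum.zip_with(x, fn wi, xi -> wi * xi end)
      |> Enum.sum()
      |> Kernel.+(bias)
    end)
  end

  # One head as drawn: dense(hidden_dim) -> SiLU -> dense(vocab_size)
  def head(h, %{w1: w1, b1: b1, w2: w2, b2: b2}) do
    h
    |> dense(w1, b1)
    |> Enum.map(&silu/1)
    |> dense(w2, b2)
  end

  # K heads applied to the same hidden state, keyed :head_1..:head_K.
  def heads(h, params_list) do
    params_list
    |> Enum.with_index(1)
    |> Map.new(fn {params, k} -> {:"head_#{k}", head(h, params)} end)
  end
end
```

Every head sees the same hidden state; only the weights differ, which is why all K heads can run in parallel.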

## Usage

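    alias Edifice.Inference.Medusa

    model = Medusa.build(
      base_hidden_dim: 256,
      vocab_size: 32_000,
      num_medusa_heads: 4
    )

    # At inference: `head_logits` is the %{head_1: ..., head_4: ...} map
    # obtained by running `model` on the base model's last hidden state.
    {candidates, _tree_indices} = Medusa.build_tree_candidates(head_logits, top_k: 5)

    # The mask is built from the candidate token tensor [num_cands, K].
    mask = Medusa.tree_decoding_mask(candidates)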

## References

- Cai et al., "Medusa: Simple LLM Inference Acceleration Framework with
  Multiple Decoding Heads" (2024), https://arxiv.org/abs/2401.10774

# `build_opt`

```elixir
@type build_opt() ::
  {:base_hidden_dim, pos_integer()}
  | {:vocab_size, pos_integer()}
  | {:num_medusa_heads, pos_integer()}
  | {:medusa_num_layers, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build Medusa draft heads as an Axon container.

Returns an `Axon.container` map with keys `:head_1` through `:head_K`,
each shaped `[batch, vocab_size]`. The input is the base model's last hidden
state, shaped `[batch, base_hidden_dim]`.

## Options

  - `:base_hidden_dim` - Hidden dimension of the base model (required)
  - `:vocab_size` - Vocabulary size, shared with base model (required)
  - `:num_medusa_heads` - Number of speculative heads K (default: 4)
  - `:medusa_num_layers` - Dense layers per head, 1 = simple linear (default: 1)

## Returns

  An `Axon.container` map `%{head_1: ..., ..., head_K: ...}`, each
  `[batch, vocab_size]`.
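If you consume the returned map yourself, note that atom keys compare as text, so with ten or more heads sorting the keys yields `:head_1, :head_10, :head_2, ...` (and small maps typically enumerate in that same sorted order). A small hypothetical helper, not part of this module, that orders head outputs by their numeric suffix instead:

```elixir
defmodule MedusaHeadOrder do
  # Hypothetical helper: returns head outputs ordered head_1, head_2, ...,
  # head_K by the numeric suffix of each key, not by atom (text) order.
  def ordered(head_map) do
    head_map
    |> Enum.sort_by(fn {key, _value} ->
      "head_" <> n = Atom.to_string(key)
      String.to_integer(n)
    end)
    |> Enum.map(fn {_key, value} -> value end)
  end
end
```

For example, `MedusaHeadOrder.ordered(%{head_1: :a, head_2: :b, head_10: :j})` returns `[:a, :b, :j]`.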

# `build_tree_candidates`

```elixir
@spec build_tree_candidates(
  %{required(atom()) => Nx.Tensor.t()},
  keyword()
) :: {Nx.Tensor.t(), Nx.Tensor.t()}
```

Generate a tree of candidate token continuations from K head logits.

Each head independently picks its top-`top_k` tokens. These are combined
into a flat set of length-K token sequences (one token per head) for tree
verification. In the full Medusa algorithm the candidates form a
Cartesian-product tree; this function returns:

- `candidates`: tensor `[num_candidates, K]` of token ID sequences, where
  each row is one path of length K through the tree.
- `tree_indices`: 1-D integer tensor of length `num_candidates` giving each
  candidate's index into the flattened token tree.

The simplest tree structure takes the top-k tokens from head 1, the top-k
from head 2, and so on, and enumerates all `top_k^K` combinations. Because
that count grows exponentially in K (5^4 = 625 paths for `top_k: 5` and four
heads), the result is clamped to at most `top_k * K` candidates for
efficiency.

## Parameters

  - `head_logits` - Map `%{head_1: tensor, ..., head_K: tensor}`, each
    `[batch, vocab_size]`. Only the first batch element is used.

## Options

  - `:top_k` - Number of top tokens to consider per head (default: 5)

## Returns

  `{candidates, tree_indices}` where `candidates` is `[num_cands, K]` and
  `tree_indices` is `[num_cands]`.
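The candidate construction can be sketched on plain lists. This is a hypothetical illustration (the module name and list-based inputs are assumptions; the real function takes the head-logit map of `Nx` tensors). It assumes the logit lists are already ordered head_1..head_K, and because the exact truncation rule is not documented above, it simply keeps the first `top_k * K` combinations in enumeration order.

```elixir
defmodule MedusaCandidatesSketch do
  # Indices of the `k` largest logits; ties go to the lower index.
  def top_k(logits, k) do
    logits
    |> Enum.with_index()
    |> Enum.sort_by(fn {v, i} -> {-v, i} end)
    |> Enum.take(k)
    |> Enum.map(fn {_v, i} -> i end)
  end

  # Cartesian product: every way of picking one token per head.
  def cartesian([]), do: [[]]

  def cartesian([choices | rest]) do
    tails = cartesian(rest)
    for tok <- choices, path <- tails, do: [tok | path]
  end

  # head_logits: list of K logit lists, ordered head_1..head_K.
  # ASSUMPTION: keeps the first k * K paths in enumeration order; the
  # module's actual truncation rule may differ.
  def build(head_logits, k \\ 5) do
    num_heads = length(head_logits)

    head_logits
    |> Enum.map(&top_k(&1, k))
    |> cartesian()
    |> Enum.take(k * num_heads)
  end
end
```

For two heads with `k = 2`, the full product has 2^2 = 4 paths, which fits under the `k * K = 4` cap: `build([[0.1, 0.9, 0.5], [0.7, 0.2, 0.3]], 2)` returns `[[1, 0], [1, 2], [2, 0], [2, 2]]`. With three heads and `k = 2`, only 6 of the 8 paths survive. Enumeration-order truncation favours head 1's best token; the Medusa paper instead uses a tuned sparse tree to choose which paths to keep.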

# `medusa_heads`

```elixir
@spec medusa_heads(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build K Medusa head logits from hidden states.

A helper that runs all K heads on `hidden_states` and returns their logits
(each `[batch, vocab_size]`) as an `Axon.container` map keyed `:head_1`
through `:head_K`. Useful when composing Medusa heads into a larger model
graph.

## Parameters

  - `hidden_states` - Axon node `[batch, hidden_dim]`
  - `opts` - Same as `build/1`

## Returns

  `Axon.container` map with keys `:head_1`..`:head_K`.

# `output_size`

```elixir
@spec output_size(keyword()) :: pos_integer()
```

Get the output size: `:vocab_size`, since each head emits `[batch, vocab_size]` logits.

# `tree_decoding_mask`

```elixir
@spec tree_decoding_mask(Nx.Tensor.t()) :: Nx.Tensor.t()
```

Build a tree attention mask for verifying all candidates in one pass.

Given tree candidates `[num_cands, K]`, each candidate is a length-K path.
The mask ensures that when the base model processes all K token positions
for all `num_cands` candidates, each position can only attend to its
ancestors in the tree (causal within each path, no cross-path attention).

Returns a boolean mask `[num_cands * K, num_cands * K]` where entry
`[i, j]` is `true` iff position `j` is an ancestor of position `i`
(including self).

## Parameters

  - `tree_candidates` - Token candidate tensor `[num_cands, K]`

## Returns

  Boolean mask `[num_cands * K, num_cands * K]`.
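As described, the mask depends only on the shape `[num_cands, K]`, not on the token values: flattening candidate `c` at depth `a` to index `c * K + a`, it is block-diagonal with one lower-triangular K-by-K block per candidate. A list-based sketch of that layout (hypothetical module name; the real function returns an `Nx` tensor):

```elixir
defmodule MedusaMaskSketch do
  # Position i is candidate div(i, k) at depth rem(i, k). Entry [i][j] is
  # true iff j is in the same candidate and not deeper than i (self included).
  def tree_mask(num_cands, k) do
    n = num_cands * k

    # The explicit //1 step keeps the range empty when n == 0.
    for i <- 0..(n - 1)//1 do
      for j <- 0..(n - 1)//1 do
        div(i, k) == div(j, k) and rem(j, k) <= rem(i, k)
      end
    end
  end
end
```

For two candidates of length 2 this gives rows `[true, false, false, false]`, `[true, true, false, false]`, `[false, false, true, false]`, `[false, false, true, true]`. Because paths are not merged, a prefix shared by several candidates is processed once per candidate; merging shared prefixes into a true tree would shrink the mask but require entries that cross path boundaries.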

