Edifice.Transformer.MultiTokenPrediction (Edifice v0.2.0)

Multi-Token Prediction (MTP) — predict multiple future tokens simultaneously.

Wraps a backbone transformer (DecoderOnly by default) with multiple independent prediction heads. Each head projects the backbone's hidden states to vocabulary logits for a different future position.

Key Innovation: Parallel Next-Token Heads

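Instead of predicting only the next token, MTP attaches N independent dense layers (N = :num_predictions) to the shared backbone output, each predicting a different future position. Every position therefore contributes N training targets instead of one, which gives a denser training signal. At inference time, the extra heads' predictions can serve as draft tokens for speculative decoding.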
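To make the denser training signal concrete, the sketch below lists the targets each head is trained against for a single token sequence. It assumes head k at position t predicts the token at position t + k (the usual MTP convention, not stated on this page). `MTPTargets` is a hypothetical helper and is not part of Edifice.

```elixir
defmodule MTPTargets do
  @moduledoc """
  Hypothetical helper (not part of Edifice): builds the supervision
  targets for each prediction head from one token sequence.
  Assumes head k at position t is trained to predict token t + k.
  """

  @doc """
  Returns %{1 => pairs, ..., n => pairs}, where each pair is
  {position, target_token}. The last k positions have no token k steps
  ahead, so head k gets no target there.
  """
  def build(tokens, num_predictions) when num_predictions >= 1 do
    for k <- 1..num_predictions, into: %{} do
      pairs =
        tokens
        # Dropping k tokens aligns index t with the token at t + k.
        |> Enum.drop(k)
        |> Enum.with_index()
        |> Enum.map(fn {target, t} -> {t, target} end)

      {k, pairs}
    end
  end
end
```

For example, `MTPTargets.build([10, 11, 12, 13, 14], 2)` pairs position 0 with token 11 for head 1 and with token 12 for head 2. Head 2 has one fewer target because the final two positions have nothing two steps ahead.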

Architecture

Input [batch, seq_len, embed_dim]
      |
Backbone (output_mode: :all)
      |
[batch, seq_len, hidden_size]
      |
+-- Head 1: dense(vocab_size) -> pred_1 [batch, seq_len, vocab_size]
+-- Head 2: dense(vocab_size) -> pred_2 [batch, seq_len, vocab_size]
+-- ...
+-- Head N: dense(vocab_size) -> pred_N [batch, seq_len, vocab_size]
      |
Axon.container(%{pred_1: h1, pred_2: h2, ..., pred_N: hN})
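One way the container's heads can feed speculative decoding is to take each head's most likely token at the last position as a draft of the next N tokens, which the model then verifies. The sketch below covers only the drafting step, not verification. It uses plain nested lists shaped [seq_len][vocab_size] for one sequence instead of Nx tensors, and `MTPDraft` is hypothetical. The mapping from head i to the token i steps ahead is an assumption.

```elixir
defmodule MTPDraft do
  @moduledoc """
  Hypothetical sketch (not Edifice API): turns the per-head logits of a
  multi-token prediction container into a greedy draft of future tokens.
  Logits are plain nested lists shaped [seq_len][vocab_size] for a single
  sequence (batch dimension dropped).
  """

  # Keys are built explicitly in numeric order; atom term order would put
  # :pred_10 before :pred_2, so Map.keys/1 is not used here.
  def head_keys(num_predictions) when num_predictions >= 1 do
    for i <- 1..num_predictions, do: String.to_atom("pred_#{i}")
  end

  # Head i's argmax at the last position is the guess for the token
  # i steps after the current one.
  def draft(outputs, num_predictions) do
    for key <- head_keys(num_predictions) do
      outputs
      |> Map.fetch!(key)
      |> List.last()
      |> argmax()
    end
  end

  # Index of the largest logit (first one wins on ties).
  defp argmax(logits) do
    logits
    |> Enum.with_index()
    |> Enum.max_by(fn {value, _index} -> value end)
    |> elem(1)
  end
end
```

With `%{pred_1: [[0.1, 0.9, 0.0], [0.2, 0.1, 0.7]], pred_2: [[0.5, 0.3, 0.2], [0.6, 0.3, 0.1]]}`, `MTPDraft.draft/2` with 2 predictions returns `[2, 0]`: token 2 as the next token and token 0 as the one after it.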

Usage

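alias Edifice.Transformer.MultiTokenPrediction

# Input: [batch, seq_len, 256] embeddings.
# Output: %{pred_1: ..., pred_4: ...}, each [batch, seq_len, 32000].
model = MultiTokenPrediction.build(
  embed_dim: 256,
  vocab_size: 32000,
  num_predictions: 4
)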
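The built model only emits logits, so training needs them folded into one loss. The following minimal sketch makes two assumptions. First, head k is scored against the token sequence shifted by k. Second, the per-head cross-entropies are averaged, although some MTP formulations sum them instead. `MTPLoss` is hypothetical. It works on plain lists for one sequence, where real code would use Nx tensors over the batch.

```elixir
defmodule MTPLoss do
  @moduledoc """
  Hypothetical sketch (not Edifice API) of a multi-token training loss:
  the mean over heads of each head's average cross-entropy against
  targets shifted by k positions. Single sequence, plain lists.
  """

  # Assumes the sequence is longer than num_predictions, so every head
  # has at least one target.
  def loss(outputs, tokens, num_predictions)
      when num_predictions >= 1 and length(tokens) > num_predictions do
    per_head =
      for k <- 1..num_predictions do
        logits = Map.fetch!(outputs, String.to_atom("pred_#{k}"))
        # Head k at position t is scored against the token at t + k.
        targets = Enum.drop(tokens, k)

        # Enum.zip_with/3 stops at the shorter list, so the final k
        # positions (no target for head k) are skipped.
        losses = Enum.zip_with(logits, targets, &cross_entropy/2)

        Enum.sum(losses) / length(losses)
      end

    Enum.sum(per_head) / num_predictions
  end

  # -log(softmax(row)[target]); subtracting the row max keeps exp/1 stable.
  defp cross_entropy(row, target) do
    row_max = Enum.max(row)
    log_sum_exp = row_max + :math.log(Enum.sum(Enum.map(row, &:math.exp(&1 - row_max))))
    log_sum_exp - Enum.at(row, target)
  end
end
```

Uniform logits give a loss of log(vocab_size) for every head, and confident correct logits drive it toward zero, which is a quick sanity check when wiring a real loss.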

References

Summary

Types

build_opt() - Options for build/1.

Functions

build(opts \\ []) - Build a Multi-Token Prediction model.

output_size(opts \\ []) - Get the output size of the model (hidden_size of backbone).

Types

build_opt()

@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:vocab_size, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_kv_heads, pos_integer()}
  | {:num_predictions, pos_integer()}
  | {:dropout, float()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Multi-Token Prediction model.

Options

  • :embed_dim - Input embedding dimension (required)
  • :vocab_size - Vocabulary size for each prediction head (required)
  • :hidden_size - Backbone hidden dimension (default: 256)
  • :num_layers - Number of backbone transformer layers (default: 4)
  • :num_heads - Number of attention heads (default: 4)
  • :num_kv_heads - Number of key/value heads for grouped-query attention (GQA) (default: 2)
  • :num_predictions - Number of future tokens to predict (default: 4)
  • :dropout - Dropout rate (default: 0.1)
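The documented defaults can be mirrored with the standard library's `Keyword.validate!/2` (Elixir 1.13+). This is a hypothetical sketch of option resolution, not Edifice's actual implementation. In particular, whether the library rejects unknown keys this way is an assumption.

```elixir
defmodule MTPOpts do
  @moduledoc """
  Hypothetical sketch (not Edifice's implementation) that resolves the
  documented build/1 options: applies the listed defaults, rejects
  unknown keys, and enforces the two required options.
  """

  @defaults [
    hidden_size: 256,
    num_layers: 4,
    num_heads: 4,
    num_kv_heads: 2,
    num_predictions: 4,
    dropout: 0.1
  ]

  def resolve(opts) do
    # Raises ArgumentError on unknown keys; bare atoms are allowed keys
    # without a default, {key, value} pairs supply defaults.
    opts = Keyword.validate!(opts, [:embed_dim, :vocab_size | @defaults])

    # Required options have no default, so fetch! raises KeyError if absent.
    _ = Keyword.fetch!(opts, :embed_dim)
    _ = Keyword.fetch!(opts, :vocab_size)

    opts
  end
end
```

For example, `MTPOpts.resolve(embed_dim: 128, vocab_size: 1000)` fills in `hidden_size: 256` and `num_predictions: 4`, while omitting `:embed_dim` raises a `KeyError`.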

Returns

An Axon model whose output is an Axon.container map with keys :pred_1 through :pred_N (N = :num_predictions), each holding logits shaped [batch, seq_len, vocab_size].

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

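Get the output size of the model, i.e. the backbone's hidden_size (the width of the hidden states fed into each prediction head), not the vocab_size of the logits.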
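Because `output_size/1` reports `hidden_size`, that value is also the input width of every prediction head. The heads' size therefore grows with both it and `:vocab_size`. As rough arithmetic, assume each head is a dense layer with a bias, which is Axon's default for dense layers. Under that assumption the heads alone hold N × (hidden_size × vocab_size + vocab_size) parameters. The `MTPHeadParams` helper is hypothetical.

```elixir
defmodule MTPHeadParams do
  # Hypothetical helper: parameter count of the N prediction heads only
  # (backbone excluded), assuming each head is dense(hidden -> vocab) with bias.
  def count(hidden_size, vocab_size, num_predictions) do
    weights = hidden_size * vocab_size
    biases = vocab_size
    num_predictions * (weights + biases)
  end
end
```

With hidden_size at its default of 256 and the Usage example's vocab_size: 32000 and num_predictions: 4, this gives 4 × (256 × 32000 + 32000) = 32,896,000 parameters in the heads alone.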