Edifice.Transformer.DecoderOnly (Edifice v0.2.0)

GPT-style decoder-only transformer with GQA + RoPE + SwiGLU + RMSNorm.

Combines modern LLM techniques into a single decoder-only transformer:

  • Grouped Query Attention (GQA) for efficient KV cache
  • Rotary Position Embeddings (RoPE) for position encoding
  • SwiGLU gated feed-forward network
  • RMSNorm for faster normalization
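
For intuition, RMSNorm and the SwiGLU FFN reduce to a few lines of Nx. This is a minimal sketch of the underlying math, not this module's implementation; the scale vector g and the projections w_gate, w_up, and w_down are illustrative names:

# RMSNorm: rescale by the reciprocal root-mean-square of the features
# (no mean-centering or bias, which is what makes it cheaper than LayerNorm)
rms_norm = fn x, g, eps ->
  ms = Nx.mean(Nx.pow(x, 2), axes: [-1], keep_axes: true)
  x |> Nx.multiply(Nx.rsqrt(Nx.add(ms, eps))) |> Nx.multiply(g)
end

# SwiGLU: a SiLU-gated branch multiplied elementwise with a linear branch,
# then projected back to the model dimension
swiglu_ffn = fn x, w_gate, w_up, w_down ->
  z = Nx.dot(x, w_gate)
  gate = Nx.multiply(z, Nx.sigmoid(z))          # SiLU(x * w_gate)
  hidden = Nx.multiply(gate, Nx.dot(x, w_up))   # gated hidden state
  Nx.dot(hidden, w_down)                        # project back to model width
end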

Attention Variants

The :attention_type option allows switching between attention mechanisms:

  • :gqa (default) — Grouped Query Attention with RoPE
  • :lightning — Lightning Attention (hybrid linear/softmax block attention)
  • :dual_chunk — Dual Chunk Attention (intra-chunk + inter-chunk for long contexts)

Architecture

Input [batch, seq_len, embed_dim]
      |
Input projection to hidden_size
      |
+------------------------------------+
|   Decoder Block (x num_layers)     |
|                                    |
|   RMSNorm -> Attention             |
|     (GQA / Lightning / DualChunk)  |
|     + RoPE on Q and K (GQA only)   |
|   -> Residual                      |
|   RMSNorm -> SwiGLU FFN            |
|   -> Residual                      |
+------------------------------------+
      |
Final LayerNorm
      |
Last timestep -> [batch, hidden_size]
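
In pseudo-Elixir, each decoder block is two pre-norm residual updates; rms_norm, attention, and swiglu_ffn below stand in for the actual layers:

# One decoder block (pre-norm): normalize, transform, add back the residual
h = Nx.add(h, attention.(rms_norm.(h)))    # attention sub-block
h = Nx.add(h, swiglu_ffn.(rms_norm.(h)))   # feed-forward sub-block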

Usage

# Default GQA attention
model = DecoderOnly.build(
  embed_dim: 287,
  hidden_size: 256,
  num_heads: 8,
  num_kv_heads: 2,
  num_layers: 6
)

# Lightning Attention for subquadratic complexity
model = DecoderOnly.build(
  embed_dim: 287,
  hidden_size: 256,
  num_heads: 8,
  num_layers: 6,
  attention_type: :lightning,
  block_size: 64
)

# Dual Chunk Attention for long contexts
model = DecoderOnly.build(
  embed_dim: 287,
  hidden_size: 256,
  num_heads: 8,
  num_layers: 6,
  attention_type: :dual_chunk,
  chunk_size: 64
)
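
Once built, the model is an ordinary Axon graph. The following is a hedged sketch of compiling and running it with the standard Axon.build/1 API; the input shape is illustrative and matches the default window_size of 60:

model = DecoderOnly.build(embed_dim: 287, hidden_size: 256, num_heads: 8, num_kv_heads: 2)

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 60, 287}, :f32), %{})

input = Nx.broadcast(0.0, {1, 60, 287})
predict_fn.(params, input)
#=> tensor of shape {1, 256}, i.e. [batch, hidden_size] from the last timestep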

References

  • GPT-2/3 decoder-only architecture (Radford et al., 2019; Brown et al., 2020)
  • LLaMA architecture combining GQA + RoPE + SwiGLU + RMSNorm (Touvron et al., 2023)
  • Lightning Attention-2 (Qin et al., 2024)
  • DeepSeek/Qwen2.5 Dual Chunk Attention (2024)

Summary

Types

build_opt()

Options for build/1.

Functions

build(opts \\ [])

Build a GPT-style decoder-only transformer model.

output_size(opts \\ [])

Get the output size of a decoder-only model.

Types

build_opt()

@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_kv_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:attention_type, :gqa | :lightning | :dual_chunk}
  | {:block_size, pos_integer()}
  | {:chunk_size, pos_integer()}
  | {:use_rope, boolean()}
  | {:interleave_rope, boolean()}
  | {:yarn, boolean()}
  | {:yarn_scale, number()}
  | {:yarn_original_max_position, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a GPT-style decoder-only transformer model.

Options

  • :embed_dim - Size of input embedding per timestep (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_heads - Number of query heads (default: 8)
  • :num_kv_heads - Number of key/value heads for GQA (default: 2)
  • :num_layers - Number of decoder blocks (default: 4)
  • :attention_type - Attention mechanism to use (default: :gqa)
    • :gqa — Grouped Query Attention with optional RoPE
    • :lightning — Lightning Attention (hybrid linear/softmax block attention)
    • :dual_chunk — Dual Chunk Attention (intra + inter-chunk for long contexts)
  • :block_size - Block size for Lightning Attention (default: 64)
  • :chunk_size - Chunk size for Dual Chunk Attention (default: 64)
  • :use_rope - Apply Rotary Position Embeddings, GQA only (default: true)
  • :interleave_rope - When true, even-indexed layers (0,2,4...) use RoPE and odd-indexed layers (1,3,5...) use NoPE (content-only attention). Overrides :use_rope on a per-layer basis. This is the iRoPE pattern used by Llama 4. (default: false)
  • :yarn - Enable YaRN context extension for longer sequences (default: false)
  • :yarn_scale - YaRN scaling factor, e.g., 8 extends 2048 to 16384 (default: 8)
  • :yarn_original_max_position - Original trained context length (default: 2048)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Expected sequence length for JIT optimization (default: 60)
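
As a hedged illustration of the position-encoding options, the build below combines iRoPE interleaving with YaRN context extension; the option values are examples, not recommendations:

# Alternate RoPE/NoPE layers (iRoPE) and stretch a 2048-token trained
# context toward 16384 with YaRN scale 8
model = DecoderOnly.build(
  embed_dim: 287,
  hidden_size: 256,
  num_heads: 8,
  num_kv_heads: 2,
  num_layers: 6,
  interleave_rope: true,
  yarn: true,
  yarn_scale: 8,
  yarn_original_max_position: 2048
)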

Returns

An Axon model that outputs [batch, hidden_size] from the last position.

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output size of a decoder-only model.
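
Presumably this returns the configured :hidden_size, since build/1 outputs [batch, hidden_size]; a hedged usage sketch:

DecoderOnly.output_size(hidden_size: 512)
#=> 512

DecoderOnly.output_size()
#=> 256  (the default hidden size)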