Edifice.Attention.GQA (Edifice v0.2.0)

GQA: Grouped Query Attention.

Grouped Query Attention is an interpolation between Multi-Head Attention (MHA) and Multi-Query Attention (MQA). Groups of query heads share key/value heads, reducing KV cache size while maintaining most of MHA's quality.

Key Innovation: KV Head Sharing

Instead of one KV head per query head (MHA) or a single KV head shared by all query heads (MQA), GQA uses G key/value heads, each shared by a group of H/G query heads:

MHA:  Q1-K1-V1  Q2-K2-V2  Q3-K3-V3  Q4-K4-V4   (4 KV heads)
GQA:  Q1-K1-V1  Q2-K1-V1  Q3-K2-V2  Q4-K2-V2   (2 KV heads)
MQA:  Q1-K1-V1  Q2-K1-V1  Q3-K1-V1  Q4-K1-V1   (1 KV head)
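
Sketched below is a minimal Nx illustration of the sharing scheme (the module and function names are illustrative, not Edifice internals): the G key/value heads are expanded to match the H query heads before attention, assuming tensors laid out as {batch, heads, seq_len, head_dim}.

defmodule RepeatKVSketch do
  # Expand num_kv_heads K/V heads to num_heads by repeating each KV head
  # once per query head in its group (group_size = num_heads / num_kv_heads).
  def repeat_kv(kv, num_heads) do
    {batch, num_kv_heads, seq_len, head_dim} = Nx.shape(kv)
    group_size = div(num_heads, num_kv_heads)

    kv
    |> Nx.new_axis(2)                                                      # {b, G, 1, s, d}
    |> Nx.broadcast({batch, num_kv_heads, group_size, seq_len, head_dim})  # {b, G, H/G, s, d}
    |> Nx.reshape({batch, num_heads, seq_len, head_dim})                   # {b, H, s, d}
  end
end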

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+-------------------------------------+
|       GQA Transformer Block          |
|                                      |
|  LayerNorm -> GQA Attention          |
|    Q: num_heads projections          |
|    K: num_kv_heads projections       |
|    V: num_kv_heads projections       |
|    K,V repeated for Q head groups    |
|    -> scaled dot-product attention   |
|    -> output projection              |
|  -> Residual                         |
|  LayerNorm -> FFN -> Residual        |
+-------------------------------------+
      | (repeat for num_layers)
      v
[batch, hidden_size]
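
As a rough Axon sketch of one block from the diagram above (the helper module name, the 4x FFN expansion, and the exact wiring inside Edifice are assumptions):

defmodule GQABlockSketch do
  def block(x, hidden_size, attn_opts) do
    # pre-norm attention with a residual connection
    attn =
      x
      |> Axon.layer_norm()
      |> Edifice.Attention.GQA.build_gqa_attention(attn_opts)

    x = Axon.add(x, attn)

    # pre-norm feed-forward network with a residual connection
    ffn =
      x
      |> Axon.layer_norm()
      |> Axon.dense(hidden_size * 4)
      |> Axon.activation(:gelu)
      |> Axon.dense(hidden_size)

    Axon.add(x, ffn)
  end
end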

Complexity

Variant            KV Cache   Quality
MHA  (G = H)       O(H*d)     Best
GQA  (1 < G < H)   O(G*d)     Near-MHA
MQA  (G = 1)       O(d)       Slightly lower
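
Here H is the number of query heads, G the number of KV heads, and d the per-head dimension. A back-of-the-envelope example using the configuration from the Usage section below (the f32 cache and the factor of 2 for storing both K and V are assumptions of this sketch):

hidden_size = 256
num_heads = 8
head_dim = div(hidden_size, num_heads)   # 32
bytes_per_element = 4                    # f32

for {label, kv_heads} <- [{"MHA", 8}, {"GQA", 2}, {"MQA", 1}] do
  cache = 2 * kv_heads * head_dim * bytes_per_element
  IO.puts("#{label}: #{cache} bytes per token per layer")
end
# MHA: 2048, GQA: 512, MQA: 256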

Usage

alias Edifice.Attention.GQA

model = GQA.build(
  embed_dim: 287,     # input feature size per timestep
  hidden_size: 256,   # internal model dimension
  num_heads: 8,       # query heads
  num_kv_heads: 2,    # shared key/value heads (4 query heads per KV head)
  num_layers: 4
)
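
A hedged follow-up showing how the returned model might be initialized and run with the standard Axon workflow, assuming the model exposes a single input (the input name is not documented here):

{init_fn, predict_fn} = Axon.build(model)

# template matching [batch, seq_len, embed_dim] from the example above
params = init_fn.(Nx.template({2, 60, 287}, :f32), %{})

input = Nx.broadcast(0.0, {2, 60, 287})
output = predict_fn.(params, input)      # {2, 256} => [batch, hidden_size]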

References

  • Paper: "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (Ainslie et al., 2023)
  • Used in: LLaMA 2 70B, Mistral 7B, Gemma

Summary

Types

build_opt()
  Options for build/1.

Functions

build(opts \\ [])
  Build a GQA transformer model for sequence processing.

build_gqa_attention(input, opts)
  Build the Grouped Query Attention layer.

output_size(opts \\ [])
  Get the output size of a GQA model.

param_count(opts)
  Calculate approximate parameter count for a GQA model.

Recommended default configuration for sequence processing.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:rope, boolean()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a GQA transformer model for sequence processing.

Options

  • :embed_dim - Size of input embedding per timestep (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_heads - Number of query heads (default: 8)
  • :num_kv_heads - Number of key/value heads (default: 2)
  • :num_layers - Number of transformer blocks (default: 4)
  • :dropout - Dropout rate (default: 0.1)
  • :rope - Apply Rotary Position Embedding to Q and K (default: false)
  • :window_size - Expected sequence length for JIT optimization (default: 60)

Returns

An Axon model that outputs [batch, hidden_size] from the last position.
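
An illustrative configuration exercising the options above (the values themselves are arbitrary):

model =
  Edifice.Attention.GQA.build(
    embed_dim: 128,
    hidden_size: 512,
    num_heads: 16,
    num_kv_heads: 4,    # 4 query heads per KV head
    num_layers: 6,
    dropout: 0.0,
    rope: true,         # rotary position embeddings on Q and K
    window_size: 128    # expected sequence length for JIT optimization
  )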

build_gqa_attention(input, opts)

@spec build_gqa_attention(
  Axon.t(),
  keyword()
) :: Axon.t()

Build the Grouped Query Attention layer.

Projects Q into num_heads groups and K/V into num_kv_heads groups, repeats K/V to match Q head count, then applies scaled dot-product attention.

Options

  • :hidden_size - Hidden dimension
  • :num_heads - Number of query heads
  • :num_kv_heads - Number of key/value heads
  • :rope - Apply standard RoPE (default: false)
  • :yarn - Apply YaRN context extension (default: false)
  • :yarn_scale - YaRN scale factor (default: 8)
  • :yarn_original_max_position - Original context length (default: 2048)
  • :name - Layer name prefix
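
A hedged example wiring the attention layer into a custom Axon graph with YaRN enabled; the input name and shape are illustrative:

input = Axon.input("sequence", shape: {nil, nil, 256})

attn =
  Edifice.Attention.GQA.build_gqa_attention(input,
    hidden_size: 256,
    num_heads: 8,
    num_kv_heads: 2,
    yarn: true,
    yarn_scale: 8,
    yarn_original_max_position: 2048,
    name: "gqa_yarn"
  )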

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a GQA model.
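
A minimal call, assuming the result mirrors :hidden_size (consistent with build/1 returning [batch, hidden_size]):

Edifice.Attention.GQA.output_size(hidden_size: 256)
#=> 256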

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for a GQA model.
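
For example, a rough capacity check using the configuration from the Usage section (the exact count depends on the library's internal formula, so no result is shown):

Edifice.Attention.GQA.param_count(
  embed_dim: 287,
  hidden_size: 256,
  num_heads: 8,
  num_kv_heads: 2,
  num_layers: 4
)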