GQA: Grouped Query Attention.
Grouped Query Attention is an interpolation between Multi-Head Attention (MHA) and Multi-Query Attention (MQA). Groups of query heads share key/value heads, reducing KV cache size while maintaining most of MHA's quality.
Key Innovation: KV Head Sharing
Instead of one KV head per query head (MHA) or one KV head total (MQA), GQA uses G groups where each group of Q heads shares one KV head:
MHA: Q1-K1-V1 Q2-K2-V2 Q3-K3-V3 Q4-K4-V4 (4 KV heads)
GQA: Q1-K1-V1 Q2-K1-V1 Q3-K2-V2 Q4-K2-V2 (2 KV heads)
MQA: Q1-K1-V1 Q2-K1-V1 Q3-K1-V1 Q4-K1-V1 (1 KV head)
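As a concrete sketch (not this module's code), the query-to-KV-head assignment is just integer division by the group size; `kv_head_for` below is a hypothetical helper that reproduces the GQA row above:

```elixir
# Hypothetical helper (not part of this module): which KV head serves a given
# query head. With num_heads = 4 and num_kv_heads = 2 it reproduces the GQA
# row above.
kv_head_for = fn q_head, num_heads, num_kv_heads ->
  div(q_head, div(num_heads, num_kv_heads))
end

for q <- 0..3, do: kv_head_for.(q, 4, 2)
#=> [0, 0, 1, 1]
```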
Architecture

Input [batch, seq_len, embed_dim]
|
v
+-------------------------------------+
| GQA Transformer Block |
| |
| LayerNorm -> GQA Attention |
| Q: num_heads projections |
| K: num_kv_heads projections |
| V: num_kv_heads projections |
| K,V repeated for Q head groups |
| -> scaled dot-product attention |
| -> output projection |
| -> Residual |
| LayerNorm -> FFN -> Residual |
+-------------------------------------+
| (repeat for num_layers)
v
[batch, hidden_size]
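The attention core of the block above can be sketched directly in Nx. `GQASketch` below is illustrative only: it assumes inputs are already projected per head (q as {batch, num_heads, seq, head_dim}; k and v as {batch, num_kv_heads, seq, head_dim}) and omits masking, dropout, RoPE, and the output projection, so it is not the module's actual implementation.

```elixir
defmodule GQASketch do
  # Minimal sketch of the grouped-attention core. Shapes and names are
  # illustrative, not the library's internals.

  def grouped_attention(q, k, v) do
    {_batch, num_heads, _seq, head_dim} = Nx.shape(q)
    {_batch, num_kv_heads, _seq, _head_dim} = Nx.shape(k)
    group = div(num_heads, num_kv_heads)

    # Repeat each KV head `group` times so it lines up with its query-head group.
    k = repeat_kv(k, group)
    v = repeat_kv(v, group)

    # Scaled dot-product attention per head: scores are {batch, heads, seq, seq}.
    scores =
      q
      |> Nx.dot([3], [0, 1], k, [3], [0, 1])
      |> Nx.divide(Nx.sqrt(head_dim))

    weights = softmax(scores)

    # Weighted sum over values: {batch, heads, seq, head_dim}.
    Nx.dot(weights, [3], [0, 1], v, [2], [0, 1])
  end

  defp repeat_kv(t, group) do
    {batch, kv_heads, seq, head_dim} = Nx.shape(t)

    t
    |> Nx.new_axis(2)                  # {batch, kv_heads, 1, seq, head_dim}
    |> Nx.tile([1, 1, group, 1, 1])    # {batch, kv_heads, group, seq, head_dim}
    |> Nx.reshape({batch, kv_heads * group, seq, head_dim})
  end

  defp softmax(scores) do
    # Softmax over the last (key) axis of the rank-4 score tensor.
    max = Nx.reduce_max(scores, axes: [3], keep_axes: true)
    exp = Nx.exp(Nx.subtract(scores, max))
    Nx.divide(exp, Nx.sum(exp, axes: [3], keep_axes: true))
  end
end
```

Calling `GQASketch.grouped_attention/3` with q of shape {1, 8, 5, 32} and k/v of shape {1, 2, 5, 32} returns a {1, 8, 5, 32} tensor, matching the shared-KV layout shown earlier.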
Complexity

| Variant | KV Cache (per token, per layer) | Quality |
|---|---|---|
| MHA (G=H) | O(H * d_head) | Best |
| GQA (1<G<H) | O(G * d_head) | Near-MHA |
| MQA (G=1) | O(d_head) | Slightly lower |
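To make the table concrete, a back-of-the-envelope per-token cache size with the defaults documented below (hidden_size 256, num_heads 8, num_kv_heads 2, num_layers 4) and f32 values:

```elixir
# Rough KV-cache size per token with the documented defaults and f32 values.
head_dim = div(256, 8)                 # 32
bytes = 4                              # f32
mha = 2 * 8 * head_dim * 4 * bytes     # K+V * H * d_head * layers -> 8192 B/token
gqa = 2 * 2 * head_dim * 4 * bytes     # K+V * G * d_head * layers -> 2048 B/token
div(mha, gqa)                          # 4, i.e. num_heads / num_kv_heads
```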
Usage
model = GQA.build(
embed_dim: 287,
hidden_size: 256,
num_heads: 8,
num_kv_heads: 2,
num_layers: 4
)
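Since the return value is documented below as an Axon model, it should run like any other Axon graph; the following is a sketch under that assumption (the template shape follows the example above and the :window_size default of 60):

```elixir
# Sketch: running the model above with Axon (assumes GQA.build/1 returns a
# plain Axon struct with a single input).
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 60, 287}, :f32), %{})

output = predict_fn.(params, Nx.broadcast(0.0, {1, 60, 287}))
Nx.shape(output)
#=> {1, 256}
```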
References

- Paper: "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (Ainslie et al., 2023)
- Used in: LLaMA 2 70B, Mistral 7B, Gemma
Types
@type build_opt() :: {:dropout, float()} | {:hidden_size, pos_integer()} | {:num_layers, pos_integer()} | {:rope, boolean()}
Options for build/1.
Functions
Build a GQA transformer model for sequence processing.
Options
* :embed_dim - Size of input embedding per timestep (required)
* :hidden_size - Internal hidden dimension (default: 256)
* :num_heads - Number of query heads (default: 8)
* :num_kv_heads - Number of key/value heads (default: 2)
* :num_layers - Number of transformer blocks (default: 4)
* :dropout - Dropout rate (default: 0.1)
* :rope - Apply Rotary Position Embedding to Q and K (default: false)
* :window_size - Expected sequence length for JIT optimization (default: 60)
Returns
An Axon model that outputs [batch, hidden_size] from the last position.
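For illustration, a call that also exercises the optional flags might look like this (values are arbitrary, not recommendations):

```elixir
# Illustrative call exercising the optional settings; values are arbitrary.
model =
  GQA.build(
    embed_dim: 287,
    num_heads: 8,
    num_kv_heads: 2,
    num_layers: 4,
    dropout: 0.1,
    rope: true,
    window_size: 120
  )
```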
Build the Grouped Query Attention layer.
Projects the input into num_heads query heads and num_kv_heads key/value heads, repeats each K/V head across its group of query heads to match the Q head count, then applies scaled dot-product attention.
Options
* :hidden_size - Hidden dimension
* :num_heads - Number of query heads
* :num_kv_heads - Number of key/value heads
* :rope - Apply standard RoPE (default: false)
* :yarn - Apply YaRN context extension (default: false)
* :yarn_scale - YaRN scale factor (default: 8)
* :yarn_original_max_position - Original context length (default: 2048)
* :name - Layer name prefix
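For orientation: RoPE rotates pairs of Q/K dimensions by position-dependent angles with frequencies theta_i = base^(-2i / head_dim), and context-extension schemes rescale those frequencies. The sketch below shows only the plain frequencies and the simplest position-interpolation style rescaling by :yarn_scale; YaRN proper additionally blends interpolation per frequency band and rescales attention logits, so this is not the layer's implementation.

```elixir
# Rotary frequencies theta_i = base^(-2i / head_dim), plus naive rescaling by
# the :yarn_scale default of 8. Orientation only, not the layer's code.
head_dim = 32
base = 10_000.0

inv_freq =
  Nx.iota({div(head_dim, 2)})
  |> Nx.multiply(-2.0 * :math.log(base) / head_dim)
  |> Nx.exp()

scaled_inv_freq = Nx.divide(inv_freq, 8)
```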
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a GQA model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for a GQA model.
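What exactly param_count/1 includes is not specified here; as a rough check, the attention projections of one block with the defaults work out as follows (biases and the FFN are ignored, so the real count will be larger):

```elixir
# Rough attention-projection parameters for ONE block with the defaults
# (hidden_size 256, num_heads 8, num_kv_heads 2). Not necessarily what
# param_count/1 returns.
h = 256
head_dim = div(h, 8)

q_proj = h * 8 * head_dim         # 65_536
kv_proj = 2 * h * 2 * head_dim    # 32_768 (an MHA-style K/V would be 131_072)
o_proj = 8 * head_dim * h         # 65_536
q_proj + kv_proj + o_proj         # 163_840 per block
```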
@spec recommended_defaults() :: keyword()
Recommended default configuration for sequence processing.
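Assuming recommended_defaults/0 returns a keyword list accepted by build/1, it could be combined with the required :embed_dim like this (illustrative only):

```elixir
# Assumes the defaults are a keyword list compatible with build/1.
GQA.recommended_defaults()
|> Keyword.put(:embed_dim, 287)
|> GQA.build()
```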