# `Edifice.Attention.GQA`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/attention/gqa.ex#L1)

GQA: Grouped Query Attention.

Grouped Query Attention is an interpolation between Multi-Head Attention (MHA)
and Multi-Query Attention (MQA). Groups of query heads share key/value heads,
reducing KV cache size while maintaining most of MHA's quality.

## Key Innovation: KV Head Sharing

Instead of one KV head per query head (MHA) or one KV head total (MQA),
GQA uses G groups where each group of Q heads shares one KV head:

```
MHA:  Q1-K1-V1  Q2-K2-V2  Q3-K3-V3  Q4-K4-V4   (4 KV heads)
GQA:  Q1-K1-V1  Q2-K1-V1  Q3-K2-V2  Q4-K2-V2   (2 KV heads)
MQA:  Q1-K1-V1  Q2-K1-V1  Q3-K1-V1  Q4-K1-V1   (1 KV head)
```
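
A minimal Nx sketch of that sharing step, with concrete shapes for illustration (the variable names and shapes are assumptions, not the module's internals): each KV head is broadcast `group_size` times so it lines up with its group of query heads.

```elixir
batch = 1
seq_len = 60
head_dim = 32
num_heads = 8
num_kv_heads = 2
group_size = div(num_heads, num_kv_heads)

# One K tensor per KV head: {batch, num_kv_heads, seq_len, head_dim}
k = Nx.iota({batch, num_kv_heads, seq_len, head_dim}, type: :f32)

# Repeat each KV head group_size times -> K1 K1 K1 K1 K2 K2 K2 K2,
# matching the query-head grouping rather than interleaving K1 K2 K1 K2 ...
k_repeated =
  k
  |> Nx.new_axis(2)
  |> Nx.broadcast({batch, num_kv_heads, group_size, seq_len, head_dim})
  |> Nx.reshape({batch, num_heads, seq_len, head_dim})
```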

## Architecture

```
Input [batch, seq_len, embed_dim]
      |
      v
+-------------------------------------+
|       GQA Transformer Block          |
|                                      |
|  LayerNorm -> GQA Attention          |
|    Q: num_heads projections          |
|    K: num_kv_heads projections       |
|    V: num_kv_heads projections       |
|    K,V repeated for Q head groups    |
|    -> scaled dot-product attention   |
|    -> output projection              |
|  -> Residual                         |
|  LayerNorm -> FFN -> Residual        |
+-------------------------------------+
      | (repeat for num_layers)
      v
[batch, hidden_size]
```
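
A hypothetical Axon sketch of one such block, built around the documented `build_gqa_attention/2` (the FFN expansion factor, activation, and layer names are assumptions; the module's actual construction may differ):

```elixir
defp gqa_block(x, opts) do
  hidden_size = opts[:hidden_size]
  name = opts[:name]

  # Pre-norm attention sub-layer with residual connection
  attn =
    x
    |> Axon.layer_norm(name: "#{name}_attn_norm")
    |> build_gqa_attention(opts)

  x = Axon.add(x, attn)

  # Pre-norm feed-forward sub-layer with residual connection
  ffn =
    x
    |> Axon.layer_norm(name: "#{name}_ffn_norm")
    |> Axon.dense(hidden_size * 4, activation: :gelu, name: "#{name}_ffn_up")
    |> Axon.dense(hidden_size, name: "#{name}_ffn_down")

  Axon.add(x, ffn)
end
```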

## Complexity

Here `H` is the number of query heads, `G` the number of KV head groups, and
`d` the per-head dimension.

| Variant | KV cache (per token, per layer) | Quality |
|---------|----------|---------|
| MHA (G = H) | O(H*d) | Best |
| GQA (1 < G < H) | O(G*d) | Near-MHA |
| MQA (G = 1) | O(d) | Slightly lower |
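
For instance, with `num_heads: 8`, `num_kv_heads: 2`, a head dimension of 32, and a 60-step window, a back-of-the-envelope count of cached f32 values per layer looks like this (counting both K and V; the numbers are illustrative, not produced by the module):

```elixir
seq_len = 60
head_dim = 32
num_heads = 8
num_kv_heads = 2

# Cached values per layer = 2 (K and V) * heads * seq_len * head_dim
mha = 2 * num_heads * seq_len * head_dim      # 30_720 floats
gqa = 2 * num_kv_heads * seq_len * head_dim   #  7_680 floats (4x smaller)
mqa = 2 * 1 * seq_len * head_dim              #  3_840 floats (8x smaller)
```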

## Usage

    model = GQA.build(
      embed_dim: 287,
      hidden_size: 256,
      num_heads: 8,
      num_kv_heads: 2,
      num_layers: 4
    )
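
Running the built model follows the usual Axon workflow; a sketch (the batch size and zero-filled input are illustrative only):

```elixir
{init_fn, predict_fn} = Axon.build(model)

# Initialize parameters from an input template: {batch, seq_len, embed_dim}
params = init_fn.(Nx.template({1, 60, 287}, :f32), %{})

# Single-input models accept a bare tensor; multi-input models take a map
# keyed by input name.
input = Nx.broadcast(0.0, {1, 60, 287})
output = predict_fn.(params, input)
# => shape {1, 256}, i.e. [batch, hidden_size] from the last position
```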

## References
- Paper: "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"
  (Ainslie et al., 2023)
- Used in: LLaMA 2 70B, Mistral 7B, Gemma

# `build_opt`

```elixir
@type build_opt() ::
  {:dropout, float()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:rope, boolean()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a GQA transformer model for sequence processing.

## Options

  - `:embed_dim` - Size of input embedding per timestep (required)
  - `:hidden_size` - Internal hidden dimension (default: 256)
  - `:num_heads` - Number of query heads (default: 8)
  - `:num_kv_heads` - Number of key/value heads (default: 2)
  - `:num_layers` - Number of transformer blocks (default: 4)
  - `:dropout` - Dropout rate (default: 0.1)
  - `:rope` - Apply Rotary Position Embedding to Q and K (default: false)
  - `:window_size` - Expected sequence length for JIT optimization (default: 60)

## Returns

  An Axon model that outputs [batch, hidden_size] from the last position.

# `build_gqa_attention`

```elixir
@spec build_gqa_attention(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build the Grouped Query Attention layer.

Projects the input to `num_heads` query heads and `num_kv_heads` key/value heads,
repeats K/V to match the query head count, then applies scaled dot-product attention.

## Options

  - `:hidden_size` - Hidden dimension
  - `:num_heads` - Number of query heads
  - `:num_kv_heads` - Number of key/value heads
  - `:rope` - Apply standard RoPE (default: false)
  - `:yarn` - Apply YaRN context extension (default: false)
  - `:yarn_scale` - YaRN scale factor (default: 8)
  - `:yarn_original_max_position` - Original context length (default: 2048)
  - `:name` - Layer name prefix
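
A hypothetical standalone call, wiring the layer onto an Axon input (the input name and layer prefix are illustrative; inside the module this is invoked once per transformer block):

```elixir
input = Axon.input("sequence", shape: {nil, 60, 256})

attn =
  GQA.build_gqa_attention(input,
    hidden_size: 256,
    num_heads: 8,
    num_kv_heads: 2,
    rope: true,
    name: "block_0_attn"
  )
```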

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of a GQA model.

# `param_count`

```elixir
@spec param_count(keyword()) :: non_neg_integer()
```

Calculate approximate parameter count for a GQA model.
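
As a rough sketch of where the savings come from, assuming bias-free projections and `head_dim = hidden_size / num_heads` (the function's exact accounting may differ):

```elixir
hidden_size = 256
num_heads = 8
num_kv_heads = 2
head_dim = div(hidden_size, num_heads)                    # 32

q_proj = hidden_size * hidden_size                        # 65_536
kv_proj = 2 * hidden_size * (num_kv_heads * head_dim)     # 32_768 (131_072 under MHA)
out_proj = hidden_size * hidden_size                      # 65_536

attn_params_per_layer = q_proj + kv_proj + out_proj       # 163_840
```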

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Recommended default configuration for sequence processing.
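
A typical pattern is to merge these defaults with the one required option (a sketch; the exact keys returned are whatever `recommended_defaults/0` provides):

```elixir
opts = Keyword.merge(GQA.recommended_defaults(), embed_dim: 287)
model = GQA.build(opts)
```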

---

*Consult [api-reference.md](api-reference.md) for the complete listing*
