# `Edifice.Attention.MultiHead`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/attention/multi_head.ex#L1)

Temporal attention mechanisms for sequence processing.

Provides two main architectures:

## Sliding Window Attention

Efficient attention that only looks at the last K timesteps:

```
Step t-5   Step t-4   Step t-3   Step t-2   Step t-1   Step t
   |         |         |         |         |         |
   +---------+---------+---------+---------+---------+
                       Attend to window
                             |
                             v
                      Current Output
```

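The cost is O(K^2) in the window size K rather than O(N^2) in the full sequence length N. Per-step compute therefore stays bounded, which makes this practical for real-time use.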
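A caller running the model in real time has to keep feeding it the last K frames. The sketch below shows one way to maintain that buffer in plain Elixir. `FrameWindow` is a hypothetical helper, not part of Edifice. It assumes the model consumes exactly the most recent `window_size` frames as its input sequence.

```elixir
# Hypothetical helper (not part of Edifice): keeps the most recent
# `size` frames, oldest first, so a sliding-window model always receives
# a fixed-length input sequence.
defmodule FrameWindow do
  defstruct size: 60, frames: :queue.new(), count: 0

  def new(size) when is_integer(size) and size > 0, do: %__MODULE__{size: size}

  # Append a frame; once the window is full, drop the oldest frame.
  def push(%__MODULE__{size: size, frames: q, count: n} = window, frame) do
    q = :queue.in(frame, q)

    if n < size do
      %{window | frames: q, count: n + 1}
    else
      {{:value, _oldest}, q} = :queue.out(q)
      %{window | frames: q, count: size}
    end
  end

  # Frames from oldest (step t-K+1) to newest (step t).
  def to_list(%__MODULE__{frames: q}), do: :queue.to_list(q)
end

window = Enum.reduce(1..8, FrameWindow.new(6), &FrameWindow.push(&2, &1))
IO.inspect(FrameWindow.to_list(window))
# → [3, 4, 5, 6, 7, 8]
```

An Erlang `:queue` supports amortized O(1) insertion at one end and removal at the other, so updating the window costs the same at every step.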

## Hybrid LSTM + Attention

LSTM compresses temporal information, then attention finds long-range patterns:

```
Frames -> LSTM -> [h1, h2, ..., hN] -> Self-Attention -> Output
```

Best of both worlds:
- LSTM captures local sequential patterns
- Attention finds sparse long-range dependencies

## Why Attention Helps Temporal Processing

1. **Direct timestep access**: "What happened exactly 12 steps ago?"
2. **Learned relevance**: Model decides which past timesteps matter
3. **Parallel training**: Unlike LSTM, attention can process all timesteps simultaneously
4. **Interpretable**: Attention weights show what the model focuses on

## Usage

    alias Edifice.Attention.MultiHead

    # Sliding window model
    model = MultiHead.build_sliding_window(
      embed_dim: 1024,
      window_size: 60,
      num_heads: 4,
      head_dim: 64
    )

    # Hybrid LSTM + Attention
    model = MultiHead.build_hybrid(
      embed_dim: 1024,
      lstm_hidden: 256,
      num_heads: 4,
      head_dim: 64
    )

# `build_opt`

```elixir
@type build_opt() :: {:num_heads, pos_integer()}
```

Options for `build/1`.

# `add_positional_encoding`

```elixir
@spec add_positional_encoding(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Add sinusoidal positional encoding to input.

Uses sine functions to build the encoding, which keeps it compatible with Axon's JIT compilation.
Each position gets a distinct encoding because every embedding dimension follows a sine wave
at a different frequency.
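The following plain-Elixir sketch shows what a sine-based encoding looks like on nested lists. The frequency schedule `1 / 10000^(i / dim)` is an assumption adapted from the original Transformer, which also interleaves cosine on odd dimensions. Edifice's exact formula may differ, and `SinePositional` is a hypothetical module.

```elixir
# Hypothetical sketch of a sine-only positional encoding on nested lists.
# The frequency schedule 1 / 10000^(i / dim) is an assumption adapted
# from the original Transformer; Edifice's exact formula may differ.
defmodule SinePositional do
  # Returns a seq_len x dim list of lists; row `pos` encodes position `pos`.
  def encoding(seq_len, dim) when seq_len > 0 and dim > 0 do
    for pos <- 0..(seq_len - 1) do
      for i <- 0..(dim - 1) do
        # Each embedding dimension oscillates at its own frequency.
        freq = 1.0 / :math.pow(10_000, i / dim)
        :math.sin(pos * freq)
      end
    end
  end

  # Adds the encoding element-wise to a seq_len x dim input.
  def add(input) do
    pe = encoding(length(input), length(hd(input)))
    Enum.zip_with(input, pe, fn row, enc -> Enum.zip_with(row, enc, &+/2) end)
  end
end

pe = SinePositional.encoding(4, 8)
# Row 0 is all zeros because sin(0) = 0; later rows differ per position.
```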

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a multi-head attention model.

This is the standard entry point used by `Edifice.build(:attention, opts)`.
Delegates to `build_sliding_window/1` which provides efficient attention
that only attends to recent timesteps.

For a hybrid LSTM + attention model, use `build_hybrid/1` directly.

## Options
  - `:embed_dim` - Input embedding size (required)
  - `:hidden_size` - Total hidden dimension; overrides num_heads * head_dim (optional)
  - `:window_size` - Attention window / sequence length (default: 60)
  - `:num_heads` - Attention heads (default: 4)
  - `:head_dim` - Dimension per head (default: 64)
  - `:num_layers` - Number of attention layers (default: 2)
  - `:ffn_dim` - Feed-forward dimension (default: 256)
  - `:dropout` - Dropout rate (default: 0.1)

## Returns
  Model that outputs [batch, hidden_size] from the last position.
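The following hypothetical helper shows how the options interact. It fills in the documented defaults and, unless `:hidden_size` is set explicitly, derives it as `num_heads * head_dim` (4 x 64 = 256 by default). `BuildOpts.resolve/1` is illustrative only; Edifice's own option handling is not shown in this reference.

```elixir
# Hypothetical helper (not Edifice's code) that applies the documented
# defaults for build/1 and derives :hidden_size when it is not given.
defmodule BuildOpts do
  @defaults [window_size: 60, num_heads: 4, head_dim: 64, num_layers: 2, ffn_dim: 256, dropout: 0.1]

  def resolve(opts) do
    # :embed_dim has no default, so fail early when it is missing.
    _embed_dim = Keyword.fetch!(opts, :embed_dim)
    opts = Keyword.merge(@defaults, opts)
    # An explicit :hidden_size wins over num_heads * head_dim.
    Keyword.put_new(opts, :hidden_size, opts[:num_heads] * opts[:head_dim])
  end
end

BuildOpts.resolve(embed_dim: 1024)[:hidden_size]
# → 256
```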

# `build_hybrid`

```elixir
@spec build_hybrid(keyword()) :: Axon.t()
```

Build a hybrid LSTM + Attention model.

LSTM captures local sequential patterns, attention finds long-range dependencies.

## Architecture
```
Frames -> LSTM -> Hidden States -> Self-Attention -> Output
```

## Options
  - `:embed_dim` - Input embedding size (required)
  - `:lstm_hidden` - LSTM hidden size (default: 256)
  - `:lstm_layers` - Number of LSTM layers (default: 1)
  - `:num_heads` - Attention heads (default: 4)
  - `:head_dim` - Dimension per head (default: 64)
  - `:dropout` - Dropout rate (default: 0.1)

## Returns
  Model that outputs [batch, hidden_size] combining LSTM and attention.

# `build_hybrid_mlp`

```elixir
@spec build_hybrid_mlp(keyword()) :: Axon.t()
```

Build a hybrid model with additional MLP layers on top.

Good for policy/value heads that need more non-linearity.

# `build_sliding_window`

```elixir
@spec build_sliding_window(keyword()) :: Axon.t()
```

Build a complete sliding window attention model.

Efficient for real-time inference because it only attends to recent timesteps.

## Options
  - `:embed_dim` - Input embedding size (required)
  - `:window_size` - Attention window (default: 60)
  - `:num_heads` - Attention heads (default: 4)
  - `:head_dim` - Dimension per head (default: 64)
  - `:num_layers` - Number of attention layers (default: 2)
  - `:ffn_dim` - Feed-forward dimension (default: 256)
  - `:dropout` - Dropout rate (default: 0.1)

## Returns
  Model that outputs [batch, hidden_size] from the last position.

# `causal_mask`

```elixir
@spec causal_mask(non_neg_integer()) :: Nx.Tensor.t()
```

Create a causal (autoregressive) attention mask.

Each position can only attend to itself and previous positions.
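Here is a minimal sketch of the same mask as nested booleans, where `true` means the query row may attend to that key column. The boolean convention and the `MaskSketch` module are assumptions. `causal_mask/1` returns an `Nx.Tensor` whose dtype and value encoding are not documented here.

```elixir
# Hypothetical sketch: a causal mask as nested booleans, where true
# means "query row i may attend to key column j".
defmodule MaskSketch do
  def causal(seq_len) when seq_len > 0 do
    for i <- 0..(seq_len - 1) do
      # Position i sees itself and every earlier position.
      for j <- 0..(seq_len - 1), do: j <= i
    end
  end
end

MaskSketch.causal(3)
# → [[true, false, false], [true, true, false], [true, true, true]]
```

In Nx, the same lower-triangular pattern can be produced by broadcasting `Nx.iota({n, 1})` against `Nx.iota({1, n})` with `Nx.greater_equal/2`.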

# `chunked_attention`

```elixir
@spec chunked_attention(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) ::
  Nx.Tensor.t()
```

Chunked attention for reduced peak memory usage.

Processes query in chunks, computing attention for each chunk against all keys.
This reduces peak memory from O(seq^2) to O(seq x chunk_size) while producing
identical results to standard attention.
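A small plain-Elixir sketch can check the equivalence claim for a single batch element and a single head, using nested lists instead of tensors and no mask. `ChunkedSketch` is hypothetical. It illustrates why query chunking is exact: each query row's softmax depends only on that row's scores, so splitting the rows into blocks changes peak memory but not the result.

```elixir
# Hypothetical plain-Elixir sketch (one batch element, one head, no mask)
# showing that chunking the queries does not change attention outputs.
defmodule ChunkedSketch do
  defp dot(a, b), do: a |> Enum.zip_with(b, &*/2) |> Enum.sum()

  defp softmax(xs) do
    m = Enum.max(xs)
    exps = Enum.map(xs, &:math.exp(&1 - m))
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end

  # Weighted sum of value rows: sum_j w_j * v_j.
  defp weighted_sum(weights, values) do
    weights
    |> Enum.zip_with(values, fn w, v -> Enum.map(v, &(&1 * w)) end)
    |> Enum.zip_with(&Enum.sum/1)
  end

  # Standard attention: builds the full seq_q x seq_k score matrix at once.
  def attention(queries, keys, values) do
    scale = :math.sqrt(length(hd(keys)))
    scores = for q <- queries, do: for(k <- keys, do: dot(q, k) / scale)
    Enum.map(scores, &weighted_sum(softmax(&1), values))
  end

  # Chunked attention: each block of chunk_size queries is scored against
  # all keys, so one chunk_size x seq_k score block is built at a time.
  def chunked(queries, keys, values, chunk_size \\ 32) do
    queries
    |> Enum.chunk_every(chunk_size)
    |> Enum.flat_map(&attention(&1, keys, values))
  end
end

q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5], [2.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
ChunkedSketch.chunked(q, k, v, 2) == ChunkedSketch.attention(q, k, v)
# → true
```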

## Parameters
  - `query` - Query tensor [batch, seq_q, dim]
  - `key` - Key tensor [batch, seq_k, dim]
  - `value` - Value tensor [batch, seq_k, dim]
  - `opts` - Options:
    - `:chunk_size` - Size of query chunks (default: 32)
    - `:mask` - Attention mask (will be chunked automatically)

## Returns
  Attention output [batch, seq_q, dim], identical to the result of `scaled_dot_product_attention`

## Memory Comparison
  For seq_len=128, batch=32, dim=256:
  - Standard: 32 x 128 x 128 x 4 bytes = 2MB peak for scores
  - Chunked (chunk=32): 32 x 32 x 128 x 4 bytes = 512KB peak (4x reduction)

# `memory_efficient_attention`

```elixir
@spec memory_efficient_attention(
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  keyword()
) ::
  Nx.Tensor.t()
```

Memory-efficient attention using online softmax normalization.

Achieves true O(n) memory by processing key/value in chunks with online
softmax, never materializing the full attention matrix. This is based on
the algorithm from "Self-attention Does Not Need O(n^2) Memory" (Rabe & Staats, 2021).

## Algorithm

Instead of computing the full attention matrix, we:
1. Process K/V in chunks
2. For each chunk, compute partial attention scores
3. Use online softmax to combine results: track running max and sum
4. Update output with proper normalization
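The following plain-Elixir sketch applies the four steps above to a single query vector. `OnlineSoftmaxSketch` is hypothetical and omits batching, heads, and causal masking. It carries only a running max `m`, a running softmax denominator `l`, and an unnormalized output `acc`. Whenever a new chunk raises the max, it rescales the previous `l` and `acc` by `exp(m_old - m_new)`.

```elixir
# Hypothetical plain-Elixir sketch of online-softmax attention for one
# query vector (no batching, heads, or causal masking). Keys and values
# are consumed chunk by chunk; only m (running max), l (running softmax
# denominator) and acc (unnormalized output) are carried between chunks.
defmodule OnlineSoftmaxSketch do
  defp dot(a, b), do: a |> Enum.zip_with(b, &*/2) |> Enum.sum()

  def attend(query, keys, values, chunk_size \\ 32) do
    scale = :math.sqrt(length(query))
    init = {nil, 0.0, List.duplicate(0.0, length(hd(values)))}

    {_m, l, acc} =
      keys
      |> Enum.zip(values)
      |> Enum.chunk_every(chunk_size)
      |> Enum.reduce(init, fn chunk, {m, l, acc} ->
        # Step 2: partial scores for this chunk only.
        scores = Enum.map(chunk, fn {k, _v} -> dot(query, k) / scale end)
        chunk_max = Enum.max(scores)
        m_new = if m, do: max(m, chunk_max), else: chunk_max
        # Step 3: rescale earlier sums to the new max (nothing to rescale at first).
        correction = if m, do: :math.exp(m - m_new), else: 0.0
        ps = Enum.map(scores, &:math.exp(&1 - m_new))

        # Step 4: fold this chunk's weighted values into the rescaled output.
        acc =
          ps
          |> Enum.zip(chunk)
          |> Enum.reduce(Enum.map(acc, &(&1 * correction)), fn {p, {_k, v}}, out ->
            Enum.zip_with(out, v, fn o, vi -> o + p * vi end)
          end)

        {m_new, l * correction + Enum.sum(ps), acc}
      end)

    # Final normalization by the accumulated softmax denominator.
    Enum.map(acc, &(&1 / l))
  end
end
```

Because the rescaling is exact in real arithmetic, the result matches a direct softmax up to floating-point rounding. This is consistent with the small numerical differences listed in this function's notes.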

## Parameters
  - `query` - Query tensor [batch, seq_q, dim]
  - `key` - Key tensor [batch, seq_k, dim]
  - `value` - Value tensor [batch, seq_k, dim]
  - `opts` - Options:
    - `:chunk_size` - Size of K/V chunks (default: 32)
    - `:causal` - Use causal masking (default: false)

## Returns
  Attention output [batch, seq_q, dim]

## Memory Comparison
  For seq_len=128, batch=32, dim=256:
  - Standard: 32 x 128 x 128 x 4 bytes = 2MB (full attention matrix)
  - Memory-efficient: 32 x 128 x 32 x 4 bytes = 512KB (one chunk at a time)

## Notes
  - Slightly slower than standard attention due to online softmax overhead
  - Output may have minor numerical differences (< 1e-5) due to different
    summation order in softmax
  - Causal masking is applied per-chunk

# `multi_head_attention`

```elixir
@spec multi_head_attention(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build multi-head attention with configurable heads and head dimension.

Computes `hidden_size = num_heads * head_dim` and delegates to `self_attention/2`
with proper multi-head reshaping.

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output dimension for a model configuration.

# `qk_layer_norm`

```elixir
@spec qk_layer_norm(Nx.Tensor.t()) :: Nx.Tensor.t()
```

Apply layer normalization to Q or K tensors.

QK LayerNorm normalizes Q and K across the feature dimension, which keeps
attention logits from growing excessively and helps stabilize training in deep networks.
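Here is a plain-Elixir sketch of the normalization applied to one Q or K vector. It assumes no learnable scale or shift and an epsilon of `1.0e-6`; this reference does not say whether `qk_layer_norm/1` uses either. `QKNormSketch` is hypothetical. After normalization each vector has mean 0 and a squared norm slightly below its length d. By Cauchy-Schwarz, the raw Q·K dot product therefore stays below d no matter how large the activations grew.

```elixir
# Hypothetical sketch of layer normalization over the feature dimension
# of a single Q or K vector. Assumes no learnable scale/shift and an
# epsilon of 1.0e-6; Edifice's qk_layer_norm/1 may differ on both.
defmodule QKNormSketch do
  def layer_norm(vector, eps \\ 1.0e-6) do
    n = length(vector)
    mean = Enum.sum(vector) / n
    # Population variance across the features of this one vector.
    var = Enum.sum(Enum.map(vector, &:math.pow(&1 - mean, 2))) / n
    Enum.map(vector, &((&1 - mean) / :math.sqrt(var + eps)))
  end
end

q = QKNormSketch.layer_norm([100.0, 200.0, 300.0, 400.0])
k = QKNormSketch.layer_norm([1.0, 3.0, 2.0, 4.0])
# |q . k| stays below the vector length (4), whatever the input scale.
```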

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Recommended default configuration for temporal sequence processing.

# `scaled_dot_product_attention`

```elixir
@spec scaled_dot_product_attention(
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  keyword()
) ::
  Nx.Tensor.t()
```

Scaled dot-product attention.

Computes `softmax(Q K^T / sqrt(d_k)) V`, where `d_k` is the dimension of each key vector and the final product is a matrix multiplication.

## Parameters
  - `query` - Query tensor [batch, seq_q, dim]
  - `key` - Key tensor [batch, seq_k, dim]
  - `value` - Value tensor [batch, seq_k, dim]
  - `opts` - Options, including `:mask` for causal or window masking

## Returns
  Attention output [batch, seq_q, dim]
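The following is a single-head, single-batch sketch of the formula on nested lists. It accepts an optional boolean `:mask` (true = may attend). Masked logits are replaced by a large negative constant so their softmax weight underflows to zero. That convention and the `SDPASketch` module are assumptions for illustration, not Edifice's tensor code.

```elixir
# Hypothetical single-head, single-batch sketch of
# softmax(Q K^T / sqrt(d_k)) V on nested lists. The optional :mask is a
# seq_q x seq_k list of booleans (true = may attend); masked logits are
# replaced by a large negative number so their weights underflow to 0.
defmodule SDPASketch do
  @masked_logit -1.0e9

  defp dot(a, b), do: a |> Enum.zip_with(b, &*/2) |> Enum.sum()

  defp softmax(xs) do
    m = Enum.max(xs)
    exps = Enum.map(xs, &:math.exp(&1 - m))
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end

  def attention(queries, keys, values, opts \\ []) do
    d_k = length(hd(keys))
    all_allowed = List.duplicate(true, length(keys))
    mask = Keyword.get(opts, :mask, List.duplicate(all_allowed, length(queries)))

    Enum.zip_with(queries, mask, fn q, row_mask ->
      logits =
        Enum.zip_with(keys, row_mask, fn k, allowed ->
          if allowed, do: dot(q, k) / :math.sqrt(d_k), else: @masked_logit
        end)

      # Weighted sum of value rows using the softmax weights.
      softmax(logits)
      |> Enum.zip_with(values, fn w, v -> Enum.map(v, &(&1 * w)) end)
      |> Enum.zip_with(&Enum.sum/1)
    end)
  end
end

causal = [[true, false, false], [true, true, false], [true, true, true]]
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = SDPASketch.attention(q, k, v, mask: causal)
```

With the causal mask, the first query can only see the first key, so `hd(out)` is exactly the first value row, `[1.0, 2.0]`.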

# `self_attention`

```elixir
@spec self_attention(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a multi-head self-attention Axon layer.

Properly reshapes Q, K, V to `[batch, num_heads, seq, head_dim]` so each
head computes its own independent attention pattern, then reshapes back.

## Options
  - `:hidden_size` - Hidden dimension = num_heads * head_dim (default: 256)
  - `:num_heads` - Number of attention heads (default: 1)
  - `:dropout` - Dropout rate (default: 0.1)
  - `:causal` - Use causal masking (default: true)
  - `:qk_layernorm` - Normalize Q and K before attention (stabilizes training, default: false)
  - `:rope` - Apply Rotary Position Embedding to Q and K (default: false)
  - `:chunked` - Use chunked attention for lower memory (default: false)
  - `:memory_efficient` - Use memory-efficient attention with online softmax for true O(n) memory (default: false)
  - `:chunk_size` - Chunk size for chunked/memory-efficient attention (default: 32)
  - `:name` - Layer name prefix
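The head reshape described for this layer can be sketched for one batch element on nested lists. `HeadSplitSketch` is hypothetical. It assumes head h owns the contiguous feature slice `h * head_dim` through `(h + 1) * head_dim - 1`. That is the layout produced by reshaping `[seq, hidden]` to `[seq, num_heads, head_dim]` and then transposing.

```elixir
# Hypothetical sketch of splitting one batch element's [seq, hidden]
# activations into [num_heads, seq, head_dim] and merging them back.
# Assumes head h owns the contiguous feature slice starting at h * head_dim.
defmodule HeadSplitSketch do
  # [seq, hidden] -> [num_heads, seq, head_dim]
  def split(x, num_heads) when num_heads > 0 do
    head_dim = div(length(hd(x)), num_heads)

    for h <- 0..(num_heads - 1) do
      Enum.map(x, &Enum.slice(&1, h * head_dim, head_dim))
    end
  end

  # [num_heads, seq, head_dim] -> [seq, hidden]: concatenate the head
  # slices for each position, in head order.
  def merge(heads), do: Enum.zip_with(heads, &Enum.concat/1)
end

x = [[1, 2, 3, 4], [5, 6, 7, 8]]
heads = HeadSplitSketch.split(x, 2)
# → [[[1, 2], [5, 6]], [[3, 4], [7, 8]]]
```

`merge/1` is the exact inverse of `split/2`, so the per-head outputs can be concatenated back to `[seq, hidden_size]`.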

# `sliding_window_attention`

```elixir
@spec sliding_window_attention(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a sliding window attention layer.

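More efficient than full attention over the whole history. Attention is restricted to a window of K timesteps, so the cost is O(K^2) rather than O(N^2) in the full sequence length N.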

## Options
  - `:window_size` - Attention window size (default: 60)
  - `:hidden_size` - Hidden dimension (default: 256)
  - `:mask` - Pre-computed attention mask (recommended for efficient compilation)
  - `:qk_layernorm` - Normalize Q and K before attention (stabilizes training, default: false)
  - `:chunked` - Use chunked attention for lower memory (default: false)
  - `:memory_efficient` - Use memory-efficient attention with online softmax for true O(n) memory (default: false)
  - `:chunk_size` - Chunk size for chunked/memory-efficient attention (default: 32)
  - `:name` - Layer name prefix

# `window_mask`

```elixir
@spec window_mask(non_neg_integer(), non_neg_integer()) :: Nx.Tensor.t()
```

Create a sliding window attention mask.

Each position can only attend to positions within the window.
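Here is a plain-Elixir sketch of a window mask as nested booleans (true = may attend). It assumes the window is causal, so query i sees keys `i - window_size + 1` through `i`. This reference does not document the argument order or value encoding of Edifice's `window_mask/2`, and `WindowMaskSketch` is hypothetical.

```elixir
# Hypothetical sketch of a causal sliding-window mask as nested booleans
# (true = may attend): query i sees keys i - window_size + 1 .. i.
defmodule WindowMaskSketch do
  def window(seq_len, window_size) when seq_len > 0 and window_size > 0 do
    for i <- 0..(seq_len - 1) do
      for j <- 0..(seq_len - 1), do: j <= i and i - j < window_size
    end
  end
end

WindowMaskSketch.window(4, 2)
# → [[true, false, false, false], [true, true, false, false],
#    [false, true, true, false], [false, false, true, true]]
```

Each row has at most `window_size` allowed entries. With `window_size >= seq_len`, the mask reduces to the causal mask.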

---

*Consult [api-reference.md](api-reference.md) for complete listing*
