Edifice.Attention.MultiHead (Edifice v0.2.0)


Temporal attention mechanisms for sequence processing.

Provides two main architectures:

Option C: Sliding Window Attention

Efficient attention that only looks at the last K timesteps:

Step t-5   Step t-4   Step t-3   Step t-2   Step t-1   Step t
   |         |         |         |         |         |
   +---------+---------+---------+---------+---------+
                       Attend to window
                             |
                             v
                      Current Output

Cost is O(K^2) in the window size K rather than O(N^2) in the full history length N, which makes it practical for real-time use.

Option B: Hybrid LSTM + Attention

LSTM compresses temporal information, then attention finds long-range patterns:

Frames -> LSTM -> [h1, h2, ..., hN] -> Self-Attention -> Output

Best of both worlds:

  • LSTM captures local sequential patterns
  • Attention finds sparse long-range dependencies

Why Attention Helps Temporal Processing

  1. Direct timestep access: "What happened exactly 12 steps ago?"
  2. Learned relevance: Model decides which past timesteps matter
  3. Parallel training: Unlike LSTM, attention can process all timesteps simultaneously
  4. Interpretable: Attention weights show what the model focuses on

Usage

alias Edifice.Attention.MultiHead

# Sliding window model
model = MultiHead.build_sliding_window(
  embed_dim: 1024,
  window_size: 60,
  num_heads: 4,
  head_dim: 64
)

# Hybrid LSTM + Attention
model = MultiHead.build_hybrid(
  embed_dim: 1024,
  lstm_hidden: 256,
  num_heads: 4,
  head_dim: 64
)

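Summary

Types

  • build_opt() - Options for build/1.

Functions

  • add_positional_encoding(input, opts \\ []) - Add sinusoidal positional encoding to input.
  • build(opts \\ []) - Build a multi-head attention model.
  • build_hybrid(opts \\ []) - Build a hybrid LSTM + Attention model.
  • build_hybrid_mlp(opts \\ []) - Build hybrid model with additional MLP layers on top.
  • build_sliding_window(opts \\ []) - Build a complete sliding window attention model.
  • causal_mask(seq_len) - Create a causal (autoregressive) attention mask.
  • chunked_attention(query, key, value, opts \\ []) - Chunked attention for reduced peak memory usage.
  • memory_efficient_attention(query, key, value, opts \\ []) - Memory-efficient attention using online softmax normalization.
  • multi_head_attention(input, opts \\ []) - Build multi-head attention with configurable heads and head dimension.
  • output_size(opts \\ []) - Get the output dimension for a model configuration.
  • qk_layer_norm(tensor) - Apply layer normalization to Q or K tensors.
  • Recommended default configuration for temporal sequence processing.
  • scaled_dot_product_attention(query, key, value, opts \\ []) - Scaled dot-product attention.
  • self_attention(input, opts \\ []) - Build a multi-head self-attention Axon layer.
  • sliding_window_attention(input, opts \\ []) - Build a sliding window attention layer.
  • window_mask(seq_len, window_size) - Create a sliding window attention mask.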

Types

build_opt()

@type build_opt() :: {:num_heads, pos_integer()}

Options for build/1.

Functions

add_positional_encoding(input, opts \\ [])

@spec add_positional_encoding(
  Axon.t(),
  keyword()
) :: Axon.t()

Add sinusoidal positional encoding to input.

Builds the encoding from sine functions, which is compatible with Axon's JIT compilation. Each position receives a unique encoding made of sine waves at different frequencies across the embedding dimensions.
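The following sketch illustrates the idea in plain Elixir. It builds a sine-only table and adds it to one batch element's [seq_len, embed_dim] input. The module name and the per-dimension frequency schedule 1 / 10000^(i / dim) are assumptions for illustration. The real function operates on Axon graphs and Nx tensors.

```elixir
defmodule PositionalEncodingSketch do
  # Hypothetical sketch, not Edifice's implementation. The per-dimension
  # frequency 1 / 10000^(i / dim) is an assumed schedule.

  # seq_len x dim table of sine values; row `pos` encodes position `pos`.
  def table(seq_len, dim) do
    for pos <- 0..(seq_len - 1)//1 do
      for i <- 0..(dim - 1)//1 do
        freq = 1.0 / :math.pow(10_000.0, i / dim)
        :math.sin(pos * freq)
      end
    end
  end

  # Adds the table element-wise to a seq_len x dim input (one batch element).
  def add(input) do
    encoding = table(length(input), length(hd(input)))

    Enum.zip_with(input, encoding, fn row, enc ->
      Enum.zip_with(row, enc, &+/2)
    end)
  end
end
```

Position 0 encodes to all zeros in this sketch. Positions further along rotate through the sine waves at rates that differ per dimension, so no two nearby positions share an encoding.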

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a multi-head attention model.

This is the standard entry point used by Edifice.build(:attention, opts). It delegates to build_sliding_window/1, which provides efficient attention that only attends to recent timesteps.

For a hybrid LSTM + attention model, use build_hybrid/1 directly.

Options

  • :embed_dim - Input embedding size (required)
  • :hidden_size - Total hidden dimension; overrides num_heads * head_dim (optional)
  • :window_size - Attention window / sequence length (default: 60)
  • :num_heads - Attention heads (default: 4)
  • :head_dim - Dimension per head (default: 64)
  • :num_layers - Number of attention layers (default: 2)
  • :ffn_dim - Feed-forward dimension (default: 256)
  • :dropout - Dropout rate (default: 0.1)

Returns

Model that outputs [batch, hidden_size] from last position.

build_hybrid(opts \\ [])

@spec build_hybrid(keyword()) :: Axon.t()

Build a hybrid LSTM + Attention model.

LSTM captures local sequential patterns, attention finds long-range dependencies.

Architecture

Frames -> LSTM -> Hidden States -> Self-Attention -> Output

Options

  • :embed_dim - Input embedding size (required)
  • :lstm_hidden - LSTM hidden size (default: 256)
  • :lstm_layers - Number of LSTM layers (default: 1)
  • :num_heads - Attention heads (default: 4)
  • :head_dim - Dimension per head (default: 64)
  • :dropout - Dropout rate (default: 0.1)

Returns

Model that outputs [batch, hidden_size] combining LSTM and attention.

build_hybrid_mlp(opts \\ [])

@spec build_hybrid_mlp(keyword()) :: Axon.t()

Build hybrid model with additional MLP layers on top.

Good for policy/value heads that need more non-linearity.

build_sliding_window(opts \\ [])

@spec build_sliding_window(keyword()) :: Axon.t()

Build a complete sliding window attention model.

Efficient for real-time inference - only attends to recent timesteps.

Options

  • :embed_dim - Input embedding size (required)
  • :window_size - Attention window (default: 60)
  • :num_heads - Attention heads (default: 4)
  • :head_dim - Dimension per head (default: 64)
  • :num_layers - Number of attention layers (default: 2)
  • :ffn_dim - Feed-forward dimension (default: 256)
  • :dropout - Dropout rate (default: 0.1)

Returns

Model that outputs [batch, hidden_size] from last position.

causal_mask(seq_len)

@spec causal_mask(non_neg_integer()) :: Nx.Tensor.t()

Create a causal (autoregressive) attention mask.

Each position can only attend to itself and previous positions.
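The mask pattern can be shown with plain lists. This sketch uses a hypothetical module name and the convention 1 = may attend, 0 = masked, which produces a lower-triangular matrix. The documented function returns an Nx tensor, and its exact encoding (booleans, 0/1, or additive -inf) is not stated here.

```elixir
defmodule CausalMaskSketch do
  # Hypothetical sketch: 1 = may attend, 0 = masked.
  # Row i is the query position, column j the key position.
  def causal_mask(seq_len) do
    for i <- 0..(seq_len - 1)//1 do
      for j <- 0..(seq_len - 1)//1 do
        # A query may only see keys at or before its own position.
        if j <= i, do: 1, else: 0
      end
    end
  end
end
```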

chunked_attention(query, key, value, opts \\ [])

@spec chunked_attention(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) ::
  Nx.Tensor.t()

Chunked attention for reduced peak memory usage.

Processes query in chunks, computing attention for each chunk against all keys. This reduces peak memory from O(seq^2) to O(seq x chunk_size) while producing identical results to standard attention.

Parameters

  • query - Query tensor [batch, seq_q, dim]
  • key - Key tensor [batch, seq_k, dim]
  • value - Value tensor [batch, seq_k, dim]
  • opts - Options:
    • :chunk_size - Size of query chunks (default: 32)
    • :mask - Attention mask (will be chunked automatically)

Returns

Attention output [batch, seq_q, dim] - identical to scaled_dot_product_attention

Memory Comparison

For seq_len=128, batch=32, dim=256:

  • Standard: 32 x 128 x 128 x 4 bytes = 2 MiB peak for scores
  • Chunked (chunk=32): 32 x 32 x 128 x 4 bytes = 512 KiB peak (4x reduction)
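Here is a minimal sketch of the chunking idea. It works on plain Elixir lists for a single batch element and omits the :mask option. The module and function names are hypothetical. Each query row goes through exactly the same arithmetic whether or not it is chunked, so the chunked and standard outputs are equal.

```elixir
defmodule ChunkedAttentionSketch do
  # Plain-list sketch for one batch element:
  # query is seq_q x dim, key and value are seq_k x dim.

  defp dot(a, b), do: Enum.zip_with(a, b, &*/2) |> Enum.sum()

  defp softmax(xs) do
    # Subtract the max before exponentiating for numerical stability.
    m = Enum.max(xs)
    exps = Enum.map(xs, &:math.exp(&1 - m))
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end

  # Standard attention: every query row is scored against every key.
  def attention(query, key, value) do
    scale = :math.sqrt(length(hd(key)))

    Enum.map(query, fn q ->
      weights = key |> Enum.map(&(dot(q, &1) / scale)) |> softmax()

      weights
      |> Enum.zip_with(value, fn w, v -> Enum.map(v, &(&1 * w)) end)
      |> Enum.zip_with(&Enum.sum/1)
    end)
  end

  # Chunked variant: split the queries into blocks of chunk_size rows, run
  # standard attention per block against all keys, then concatenate. In a
  # tensor implementation only one chunk_size x seq_k score block is live
  # at a time; the per-row arithmetic is unchanged, so results match exactly.
  def chunked_attention(query, key, value, chunk_size \\ 32) do
    query
    |> Enum.chunk_every(chunk_size)
    |> Enum.flat_map(&attention(&1, key, value))
  end
end
```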

memory_efficient_attention(query, key, value, opts \\ [])

@spec memory_efficient_attention(
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  keyword()
) ::
  Nx.Tensor.t()

Memory-efficient attention using online softmax normalization.

Achieves true O(n) memory by processing key/value in chunks with online softmax, never materializing the full attention matrix. This is based on the algorithm from "Self-attention Does Not Need O(n^2) Memory" (Rabe & Staats, 2021).

Algorithm

Instead of computing the full attention matrix, we:

  1. Process K/V in chunks
  2. For each chunk, compute partial attention scores
  3. Use online softmax to combine results: track running max and sum
  4. Update output with proper normalization
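The four steps can be sketched for a single batch element with plain lists. The module is hypothetical and the sketch applies no causal masking. Each query keeps three pieces of state:

  • a running max m
  • a running softmax denominator l
  • an un-normalized accumulator

When a new chunk raises the max, the old l and accumulator are rescaled by exp(m_old - m_new) before the chunk's contributions are added. A standard reference implementation is included for comparison.

```elixir
defmodule OnlineSoftmaxSketch do
  # Plain-list sketch for one batch element (no causal masking):
  # query is seq_q x dim, key and value are seq_k x dim.

  defp dot(a, b), do: Enum.zip_with(a, b, &*/2) |> Enum.sum()

  def attention(query, key, value, chunk_size \\ 32) do
    scale = :math.sqrt(length(hd(key)))
    dim_v = length(hd(value))
    # Step 1: split K/V into aligned chunks.
    chunks = Enum.zip(Enum.chunk_every(key, chunk_size), Enum.chunk_every(value, chunk_size))

    Enum.map(query, fn q ->
      # State: running max score, running softmax denominator, and the
      # un-normalized weighted sum of values seen so far.
      init = {nil, 0.0, List.duplicate(0.0, dim_v)}

      {_m, l, acc} =
        Enum.reduce(chunks, init, fn {k_chunk, v_chunk}, {m, l, acc} ->
          # Step 2: partial scores for this chunk only.
          scores = Enum.map(k_chunk, &(dot(q, &1) / scale))
          chunk_max = Enum.max(scores)
          m_new = if is_nil(m), do: chunk_max, else: max(m, chunk_max)

          # Step 3: rescale the earlier sum and accumulator to the new max.
          correction = if is_nil(m), do: 0.0, else: :math.exp(m - m_new)
          exps = Enum.map(scores, &:math.exp(&1 - m_new))
          l_new = l * correction + Enum.sum(exps)

          acc_new =
            exps
            |> Enum.zip_with(v_chunk, fn e, v -> Enum.map(v, &(&1 * e)) end)
            |> Enum.reduce(Enum.map(acc, &(&1 * correction)), fn row, sum ->
              Enum.zip_with(row, sum, &+/2)
            end)

          {m_new, l_new, acc_new}
        end)

      # Step 4: normalize once all chunks have been seen.
      Enum.map(acc, &(&1 / l))
    end)
  end

  # Reference: standard attention that materializes all scores at once.
  def reference(query, key, value) do
    scale = :math.sqrt(length(hd(key)))

    Enum.map(query, fn q ->
      scores = Enum.map(key, &(dot(q, &1) / scale))
      m = Enum.max(scores)
      exps = Enum.map(scores, &:math.exp(&1 - m))
      total = Enum.sum(exps)

      exps
      |> Enum.zip_with(value, fn e, v -> Enum.map(v, &(&1 * e / total)) end)
      |> Enum.zip_with(&Enum.sum/1)
    end)
  end
end
```

The two functions agree up to floating-point rounding, consistent with the small numerical differences described in this section.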

Parameters

  • query - Query tensor [batch, seq_q, dim]
  • key - Key tensor [batch, seq_k, dim]
  • value - Value tensor [batch, seq_k, dim]
  • opts - Options:
    • :chunk_size - Size of K/V chunks (default: 32)
    • :causal - Use causal masking (default: false)

Returns

Attention output [batch, seq_q, dim]

Memory Comparison

For seq_len=128, batch=32, dim=256:

  • Standard: 32 x 128 x 128 x 4 bytes = 2 MiB (full attention matrix)
  • Memory-efficient: 32 x 128 x 32 x 4 bytes = 512 KiB (one chunk at a time)

Notes

  • Slightly slower than standard attention due to online softmax overhead
  • Output may have minor numerical differences (< 1e-5) due to different summation order in softmax
  • Causal masking is applied per-chunk

multi_head_attention(input, opts \\ [])

@spec multi_head_attention(
  Axon.t(),
  keyword()
) :: Axon.t()

Build multi-head attention with configurable heads and head dimension.

Computes hidden_size = num_heads * head_dim and delegates to self_attention/2 with proper multi-head reshaping.

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output dimension for a model configuration.
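This hypothetical sketch assumes output_size/1 follows the option rules documented for build/1. That means it returns :hidden_size if given, and otherwise num_heads * head_dim with defaults 4 and 64. Other model variants, such as the hybrid model, may compute their output size differently.

```elixir
defmodule OutputSizeSketch do
  # Hypothetical sketch mirroring the option rules documented for build/1:
  # defaults num_heads: 4, head_dim: 64; :hidden_size overrides the product.
  def output_size(opts \\ []) do
    num_heads = Keyword.get(opts, :num_heads, 4)
    head_dim = Keyword.get(opts, :head_dim, 64)
    Keyword.get(opts, :hidden_size, num_heads * head_dim)
  end
end
```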

qk_layer_norm(tensor)

@spec qk_layer_norm(Nx.Tensor.t()) :: Nx.Tensor.t()

Apply layer normalization to Q or K tensors.

QK LayerNorm normalizes across the feature dimension to prevent attention score explosion in deep networks.
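Here is a sketch of per-vector normalization over the feature dimension, using plain lists where each row is one Q or K vector. The module name, the epsilon of 1e-5, and the absence of a learned scale and bias are assumptions.

```elixir
defmodule QKLayerNormSketch do
  # Hypothetical sketch: each row is normalized to zero mean and unit
  # variance over its features. Epsilon and the lack of learned gain/bias
  # are assumptions, not taken from Edifice.
  @eps 1.0e-5

  def layer_norm(rows) do
    Enum.map(rows, fn row ->
      n = length(row)
      mean = Enum.sum(row) / n
      var = Enum.reduce(row, 0.0, fn x, acc -> acc + (x - mean) * (x - mean) end) / n
      Enum.map(row, &((&1 - mean) / :math.sqrt(var + @eps)))
    end)
  end
end
```

Normalized Q and K vectors have bounded norms, so their dot products (the attention logits) cannot grow without limit as activations grow in deep stacks.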

scaled_dot_product_attention(query, key, value, opts \\ [])

@spec scaled_dot_product_attention(
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  keyword()
) ::
  Nx.Tensor.t()

Scaled dot-product attention.

Computes: softmax(QK^T / sqrt(d_k)) * V

Parameters

  • query - Query tensor [batch, seq_q, dim]
  • key - Key tensor [batch, seq_k, dim]
  • value - Value tensor [batch, seq_k, dim]
  • opts - Options including :mask for causal/window masking

Returns

Attention output [batch, seq_q, dim]
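The formula maps directly onto a plain-list sketch for one batch element. The module name and the mask convention (1 = may attend, 0 = blocked) are assumptions. Blocked keys receive zero weight, which is equivalent to adding -infinity to their scores before the softmax.

```elixir
defmodule SDPASketch do
  # Plain-list sketch of softmax(Q K^T / sqrt(d_k)) V for one batch element.
  # Mask convention (1 = may attend, 0 = blocked) is an assumption; each
  # mask row must allow at least one key.

  defp dot(a, b), do: Enum.zip_with(a, b, &*/2) |> Enum.sum()

  def attention(query, key, value, mask \\ nil) do
    scale = :math.sqrt(length(hd(key)))
    # No mask means every query may attend to every key.
    mask = mask || Enum.map(query, fn _ -> Enum.map(key, fn _ -> 1 end) end)

    Enum.zip_with(query, mask, fn q, mask_row ->
      scores = Enum.map(key, &(dot(q, &1) / scale))

      # Max over allowed scores only, for a numerically stable softmax.
      max_allowed =
        scores
        |> Enum.zip(mask_row)
        |> Enum.filter(fn {_s, allowed} -> allowed == 1 end)
        |> Enum.map(fn {s, _allowed} -> s end)
        |> Enum.max()

      # Blocked keys get weight 0, as if their score were -infinity.
      exps =
        Enum.zip_with(scores, mask_row, fn s, allowed ->
          if allowed == 1, do: :math.exp(s - max_allowed), else: 0.0
        end)

      total = Enum.sum(exps)

      exps
      |> Enum.zip_with(value, fn e, v -> Enum.map(v, &(&1 * e / total)) end)
      |> Enum.zip_with(&Enum.sum/1)
    end)
  end
end
```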

self_attention(input, opts \\ [])

@spec self_attention(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a multi-head self-attention Axon layer.

Properly reshapes Q, K, V to [batch, num_heads, seq, head_dim] so each head computes its own independent attention pattern, then reshapes back.
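The head split and merge can be sketched on plain lists for one batch element. The sketch assumes the usual contiguous layout, where head h owns features h * head_dim through (h + 1) * head_dim - 1. The module name is hypothetical. Between the two steps, attention would run independently on each head's [seq, head_dim] slice.

```elixir
defmodule HeadReshapeSketch do
  # Plain-list sketch for one batch element; x is seq x (num_heads * head_dim).
  # Assumes head h owns a contiguous slice of each feature row.

  # [seq, num_heads * head_dim] -> [num_heads, seq, head_dim]
  def split_heads(x, num_heads) do
    head_dim = div(length(hd(x)), num_heads)

    for h <- 0..(num_heads - 1)//1 do
      Enum.map(x, &Enum.slice(&1, h * head_dim, head_dim))
    end
  end

  # [num_heads, seq, head_dim] -> [seq, num_heads * head_dim]
  def merge_heads(heads) do
    # For each sequence position, concatenate that row from every head.
    Enum.zip_with(heads, &Enum.concat/1)
  end
end
```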

Options

  • :hidden_size - Hidden dimension = num_heads * head_dim (default: 256)
  • :num_heads - Number of attention heads (default: 1)
  • :dropout - Dropout rate (default: 0.1)
  • :causal - Use causal masking (default: true)
  • :qk_layernorm - Normalize Q and K before attention (stabilizes training, default: false)
  • :rope - Apply Rotary Position Embedding to Q and K (default: false)
  • :chunked - Use chunked attention for lower memory (default: false)
  • :memory_efficient - Use memory-efficient attention with online softmax for true O(n) memory (default: false)
  • :chunk_size - Chunk size for chunked/memory-efficient attention (default: 32)
  • :name - Layer name prefix

sliding_window_attention(input, opts \\ [])

@spec sliding_window_attention(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a sliding window attention layer.

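More efficient than full attention: each position attends to at most K = window_size positions instead of all N, so the attention pattern scales as O(N x K) rather than O(N^2).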

Options

  • :window_size - Attention window size (default: 60)
  • :hidden_size - Hidden dimension (default: 256)
  • :mask - Pre-computed attention mask (recommended for efficient compilation)
  • :qk_layernorm - Normalize Q and K before attention (stabilizes training, default: false)
  • :chunked - Use chunked attention for lower memory (default: false)
  • :memory_efficient - Use memory-efficient attention with online softmax for true O(n) memory (default: false)
  • :chunk_size - Chunk size for chunked/memory-efficient attention (default: 32)
  • :name - Layer name prefix

window_mask(seq_len, window_size)

@spec window_mask(non_neg_integer(), non_neg_integer()) :: Nx.Tensor.t()

Create a sliding window attention mask.

Each position can only attend to positions within the window.
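Here is a plain-list sketch of the mask. It makes two assumptions, neither confirmed by this page (the documented function returns an Nx tensor):

  • The window looks backward and includes the current position, so position i may attend to j when i - window_size < j <= i.
  • The mask uses 1 = attend, 0 = masked.

```elixir
defmodule WindowMaskSketch do
  # Hypothetical sketch: 1 = may attend, 0 = masked. Assumes a backward-
  # looking (causal) window that includes the current position.
  def window_mask(seq_len, window_size) do
    for i <- 0..(seq_len - 1)//1 do
      for j <- 0..(seq_len - 1)//1 do
        if j <= i and i - j < window_size, do: 1, else: 0
      end
    end
  end
end
```

With window_size >= seq_len this reduces to the lower-triangular causal mask.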