Temporal attention mechanisms for sequence processing.
Provides two main architectures:
Option A: Sliding Window Attention
Efficient attention that only looks at the last K timesteps:
    Step t-5  Step t-4  Step t-3  Step t-2  Step t-1  Step t
        |         |         |         |         |       |
        +---------+---------+---------+---------+-------+
                      Attend to window
                              |
                              v
                       Current Output

O(K^2) complexity instead of O(N^2) - practical for real-time.
Option B: Hybrid LSTM + Attention
LSTM compresses temporal information, then attention finds long-range patterns:
    Frames -> LSTM -> [h1, h2, ..., hN] -> Self-Attention -> Output

Best of both worlds:
- LSTM captures local sequential patterns
- Attention finds sparse long-range dependencies
Why Attention Helps Temporal Processing
- Direct timestep access: "What happened exactly 12 steps ago?"
- Learned relevance: Model decides which past timesteps matter
- Parallel training: Unlike LSTM, attention can process all timesteps simultaneously
- Interpretable: Attention weights show what the model focuses on
Usage
    # Sliding window model
    model = MultiHead.build_sliding_window(
      embed_dim: 1024,
      window_size: 60,
      num_heads: 4,
      head_dim: 64
    )

    # Hybrid LSTM + Attention
    model = MultiHead.build_hybrid(
      embed_dim: 1024,
      lstm_hidden: 256,
      num_heads: 4,
      head_dim: 64
    )
Summary
Functions
Add sinusoidal positional encoding to input.
Build a multi-head attention model.
Build a hybrid LSTM + Attention model.
Build hybrid model with additional MLP layers on top.
Build a complete sliding window attention model.
Create a causal (autoregressive) attention mask.
Chunked attention for reduced peak memory usage.
Memory-efficient attention using online softmax normalization.
Build multi-head attention with configurable heads and head dimension.
Get the output dimension for a model configuration.
Apply layer normalization to Q or K tensors.
Recommended default configuration for temporal sequence processing.
Scaled dot-product attention.
Build a multi-head self-attention Axon layer.
Build a sliding window attention layer.
Create a sliding window attention mask.
Types
@type build_opt() :: {:num_heads, pos_integer()}
Options for build/1.
Functions
Add sinusoidal positional encoding to input.
Uses only sine functions for position encoding, which keeps it compatible with Axon's JIT compilation. Each position gets a unique encoding based on sine waves at different frequencies across the embedding dimensions.
Build a multi-head attention model.
This is the standard entry point used by Edifice.build(:attention, opts).
Delegates to build_sliding_window/1, which provides efficient attention
that only attends to recent timesteps.
For a hybrid LSTM + attention model, use build_hybrid/1 directly.
Options
- :embed_dim - Input embedding size (required)
- :hidden_size - Total hidden dimension; overrides num_heads * head_dim (optional)
- :window_size - Attention window / sequence length (default: 60)
- :num_heads - Attention heads (default: 4)
- :head_dim - Dimension per head (default: 64)
- :num_layers - Number of attention layers (default: 2)
- :ffn_dim - Feed-forward dimension (default: 256)
- :dropout - Dropout rate (default: 0.1)
Returns
Model that outputs [batch, hidden_size] from last position.
Build a hybrid LSTM + Attention model.
LSTM captures local sequential patterns, attention finds long-range dependencies.
Architecture
    Frames -> LSTM -> Hidden States -> Self-Attention -> Output

Options
- :embed_dim - Input embedding size (required)
- :lstm_hidden - LSTM hidden size (default: 256)
- :lstm_layers - Number of LSTM layers (default: 1)
- :num_heads - Attention heads (default: 4)
- :head_dim - Dimension per head (default: 64)
- :dropout - Dropout rate (default: 0.1)
Returns
Model that outputs [batch, hidden_size] combining LSTM and attention.
Build hybrid model with additional MLP layers on top.
Good for policy/value heads that need more non-linearity.
Build a complete sliding window attention model.
Efficient for real-time inference - only attends to recent timesteps.
Options
- :embed_dim - Input embedding size (required)
- :window_size - Attention window (default: 60)
- :num_heads - Attention heads (default: 4)
- :head_dim - Dimension per head (default: 64)
- :num_layers - Number of attention layers (default: 2)
- :ffn_dim - Feed-forward dimension (default: 256)
- :dropout - Dropout rate (default: 0.1)
Returns
Model that outputs [batch, hidden_size] from last position.
@spec causal_mask(non_neg_integer()) :: Nx.Tensor.t()
Create a causal (autoregressive) attention mask.
Each position can only attend to itself and previous positions.
@spec chunked_attention(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Chunked attention for reduced peak memory usage.
Processes the query in chunks, computing attention for each chunk against all keys. This reduces peak memory from O(seq^2) to O(seq x chunk_size) while producing identical results to standard attention. A sketch of the chunking idea follows the memory comparison below.
Parameters
- query - Query tensor [batch, seq_q, dim]
- key - Key tensor [batch, seq_k, dim]
- value - Value tensor [batch, seq_k, dim]
- opts - Options:
  - :chunk_size - Size of query chunks (default: 32)
  - :mask - Attention mask (will be chunked automatically)
Returns
Attention output [batch, seq_q, dim] - identical to scaled_dot_product_attention
Memory Comparison
For seq_len=128, batch=32, dim=256:
- Standard: 32 x 128 x 128 x 4 bytes = 2MB peak for scores
- Chunked (chunk=32): 32 x 32 x 128 x 4 bytes = 512KB peak (4x reduction)
@spec memory_efficient_attention( Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword() ) :: Nx.Tensor.t()
Memory-efficient attention using online softmax normalization.
Achieves true O(n) memory by processing key/value in chunks with online softmax, never materializing the full attention matrix. This is based on the algorithm from "Self-attention Does Not Need O(n^2) Memory" (Rabe & Staats, 2021).
Algorithm
Instead of computing the full attention matrix, we:
- Process K/V in chunks
- For each chunk, compute partial attention scores
- Use online softmax to combine results: track running max and sum
- Update output with proper normalization (see the sketch after the notes below)
Parameters
- query - Query tensor [batch, seq_q, dim]
- key - Key tensor [batch, seq_k, dim]
- value - Value tensor [batch, seq_k, dim]
- opts - Options:
  - :chunk_size - Size of K/V chunks (default: 32)
  - :causal - Use causal masking (default: false)
Returns
Attention output [batch, seq_q, dim]
Memory Comparison
For seq_len=128, batch=32, dim=256:
- Standard: 32 x 128 x 128 x 4 bytes = 2MB (full attention matrix)
- Memory-efficient: 32 x 128 x 32 x 4 bytes = 512KB (one chunk at a time)
Notes
- Slightly slower than standard attention due to online softmax overhead
- Output may have minor numerical differences (< 1e-5) due to different summation order in softmax
- Causal masking is applied per-chunk
Build multi-head attention with configurable heads and head dimension.
Computes hidden_size = num_heads * head_dim and delegates to self_attention/2
with proper multi-head reshaping.
@spec output_size(keyword()) :: non_neg_integer()
Get the output dimension for a model configuration.
@spec qk_layer_norm(Nx.Tensor.t()) :: Nx.Tensor.t()
Apply layer normalization to Q or K tensors.
QK LayerNorm normalizes across the feature dimension to prevent attention score explosion in deep networks.
@spec recommended_defaults() :: keyword()
Recommended default configuration for temporal sequence processing.
@spec scaled_dot_product_attention( Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword() ) :: Nx.Tensor.t()
Scaled dot-product attention.
Computes: softmax(QK^T / sqrt(d_k)) * V
Parameters
query- Query tensor [batch, seq_q, dim]key- Key tensor [batch, seq_k, dim]value- Value tensor [batch, seq_k, dim]opts- Options including :mask for causal/window masking
Returns
Attention output [batch, seq_q, dim]
Build a multi-head self-attention Axon layer.
Properly reshapes Q, K, V to [batch, num_heads, seq, head_dim] so each
head computes its own independent attention pattern, then reshapes back
(see the sketch after the options below).
Options
- :hidden_size - Hidden dimension = num_heads * head_dim (default: 256)
- :num_heads - Number of attention heads (default: 1)
- :dropout - Dropout rate (default: 0.1)
- :causal - Use causal masking (default: true)
- :qk_layernorm - Normalize Q and K before attention (stabilizes training, default: false)
- :rope - Apply Rotary Position Embedding to Q and K (default: false)
- :chunked - Use chunked attention for lower memory (default: false)
- :memory_efficient - Use memory-efficient attention with online softmax for true O(n) memory (default: false)
- :chunk_size - Chunk size for chunked/memory-efficient attention (default: 32)
- :name - Layer name prefix
Build a sliding window attention layer.
More efficient than full attention - O(K^2) for a window of size K instead of O(N^2) over the full sequence.
Options
- :window_size - Attention window size (default: 60)
- :hidden_size - Hidden dimension (default: 256)
- :mask - Pre-computed attention mask (recommended for efficient compilation)
- :qk_layernorm - Normalize Q and K before attention (stabilizes training, default: false)
- :chunked - Use chunked attention for lower memory (default: false)
- :memory_efficient - Use memory-efficient attention with online softmax for true O(n) memory (default: false)
- :chunk_size - Chunk size for chunked/memory-efficient attention (default: 32)
- :name - Layer name prefix
@spec window_mask(non_neg_integer(), non_neg_integer()) :: Nx.Tensor.t()
Create a sliding window attention mask.
Each position can only attend to positions within the window.