Infini-Attention: local windowed attention + compressive memory.
Extends standard multi-head attention with a compressive memory system that enables effectively unbounded context length. Each layer maintains a memory matrix and a normalization term that accumulate key-value information from past segments, so earlier context stays retrievable after the local window has moved on.
Key Innovation
For each segment of the input:
- Standard local attention within the segment (captures fine-grained patterns)
- Memory retrieval: sigma(Q) @ M / (sigma(Q) @ z), where sigma = ELU + 1 and z is a running normalization vector
- Memory update: M += sigma(K)^T @ V, z += sum(sigma(K))
- A learnable gate blends the local and memory outputs (see the sketch below)
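In tensor terms, the three memory steps above are just a handful of Nx ops. The following is a minimal sketch for a single head, assuming q, k, v of shape {segment_size, d}, memory m of shape {d, d}, and normalization term z of shape {d}; the module and function names (InfiniMemorySketch, sigma/1, retrieve/3, update/4, blend/3) are illustrative, not part of this module's API.

defmodule InfiniMemorySketch do
  import Nx.Defn

  # sigma(x) = ELU(x) + 1: x + 1 for x > 0, exp(x) otherwise.
  # Strictly positive, so the normalization term z only ever grows.
  defn sigma(x) do
    Nx.select(Nx.greater(x, 0), x + 1, Nx.exp(x))
  end

  # Memory retrieval: sigma(Q) @ M / (sigma(Q) @ z)
  defn retrieve(q, m, z) do
    sq = sigma(q)                                   # {segment_size, d}
    Nx.dot(sq, m) / Nx.new_axis(Nx.dot(sq, z), -1)  # normalize per query row
  end

  # Memory update: M += sigma(K)^T @ V and z += sum(sigma(K))
  defn update(k, v, m, z) do
    sk = sigma(k)
    {m + Nx.dot(Nx.transpose(sk), v), z + Nx.sum(sk, axes: [0])}
  end

  # Learnable gate: sigmoid(beta) mixes memory and local outputs.
  defn blend(a_mem, a_local, beta) do
    g = Nx.sigmoid(beta)
    g * a_mem + (1 - g) * a_local
  end
end

The ELU + 1 choice keeps sigma(Q) and sigma(K) entrywise positive, so the denominator sigma(Q) @ z stays positive and memory retrieval behaves like a normalized weighted average of stored values.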
Architecture
Input [batch, seq_len, embed_dim]
|
Input projection to hidden_size
|
+----------------------------------------------+
| Infini-Attention Block (x num_layers)        |
|                                              |
|  LayerNorm -> Infini-Attention               |
|    Split into segments of segment_size       |
|    Per segment:                              |
|      Local multi-head attention              |
|      Memory retrieval + update               |
|      Gated blend of local + memory           |
|  -> Residual                                 |
|  LayerNorm -> FFN -> Residual                |
+----------------------------------------------+
|
Final LayerNorm
|
Last timestep -> [batch, hidden_size]
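To make the per-segment flow in the diagram concrete, here is a hedged driver sketch that threads the memory state across segments while softmax attention stays local. It reuses the illustrative InfiniMemorySketch module above; the causal mask and learned projections are omitted (q = k = v = the raw segment), and the zero/epsilon initial state is a simplification, not necessarily what build/1 does.

segment_size = 32
d = 64

input = Nx.divide(Nx.iota({128, d}, type: :f32), 10_000)  # {seq_len, d}; seq_len divisible by segment_size
m0 = Nx.broadcast(0.0, {d, d})   # empty associative memory
z0 = Nx.broadcast(1.0e-6, {d})   # tiny epsilon so the first retrieval is well-defined
beta = Nx.tensor(0.0)            # untrained gate: an even 50/50 blend

# Plain softmax attention within one segment (single head; numerically
# naive softmax is fine for these small scores).
local = fn seg ->
  scores = Nx.divide(Nx.dot(seg, Nx.transpose(seg)), :math.sqrt(d))
  weights = Nx.divide(Nx.exp(scores), Nx.sum(Nx.exp(scores), axes: [1], keep_axes: true))
  Nx.dot(weights, seg)
end

{outs, _m, _z} =
  input
  |> Nx.to_batched(segment_size)
  |> Enum.reduce({[], m0, z0}, fn seg, {acc, m, z} ->
    a_mem = InfiniMemorySketch.retrieve(seg, m, z)        # read old memory first...
    out = InfiniMemorySketch.blend(a_mem, local.(seg), beta)
    {m, z} = InfiniMemorySketch.update(seg, seg, m, z)    # ...then fold this segment in
    {[out | acc], m, z}
  end)

output = outs |> Enum.reverse() |> Nx.concatenate()       # {seq_len, d}

Usage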
model = InfiniAttention.build(
embed_dim: 287,
hidden_size: 256,
num_heads: 4,
segment_size: 32,
num_layers: 4
)
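A quick end-to-end run, as a sketch assuming Axon ~> 0.6's build/init/predict API and a single-input model; the template follows the [batch, seq_len, embed_dim] input contract, with the batch size and fill value chosen arbitrarily:

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({8, 60, 287}, :f32), %{})
output = predict_fn.(params, Nx.broadcast(0.5, {8, 60, 287}))
Nx.shape(output)
#=> {8, 256}

References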
- "Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention" (Munkhdalai et al., 2024)
Summary
Functions
Build an Infini-Attention model.
Build the Infini-Attention layer with segmented local attention and compressive memory.
Get the output size of an Infini-Attention model.
Types
@type build_opt() ::
        {:embed_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:segment_size, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:dropout, float()}
        | {:window_size, pos_integer()}
Options for build/1.
Functions
Build an Infini-Attention model.
Options
:embed_dim - Size of input embedding per timestep (required)
:hidden_size - Internal hidden dimension (default: 256)
:num_heads - Number of attention heads (default: 4)
:segment_size - Size of each local attention segment (default: 32)
:num_layers - Number of transformer blocks (default: 4)
:dropout - Dropout rate (default: 0.1)
:window_size - Expected sequence length for JIT optimization (default: 60)
Returns
An Axon model that outputs [batch, hidden_size] from the last position.
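A hedged shape check of that contract via Axon.get_output_shape/2, using non-default options; passing a bare template assumes the model exposes a single input:

model = InfiniAttention.build(embed_dim: 32, hidden_size: 64, num_heads: 2, window_size: 40)
Axon.get_output_shape(model, Nx.template({2, 40, 32}, :f32))
#=> {2, 64}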
Build the Infini-Attention layer with segmented local attention and compressive memory.
@spec output_size(keyword()) :: pos_integer()
Get the output size of an Infini-Attention model.
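Given the [batch, hidden_size] contract of build/1 above, the natural reading is that this returns the configured :hidden_size; an illustrative call under that assumption:

InfiniAttention.output_size(hidden_size: 128)
#=> 128  (assumed: matches the model's output width)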