Flash Linear Attention — chunked linear attention with feature maps.
Combines the efficiency of linear attention with block-wise computation for better hardware utilization, following the LightningAttention pattern but using explicit feature maps on Q and K.
Key Innovation: Feature-Mapped Chunked Attention
Unlike LightningAttention, which uses raw Q, K, V, FlashLinearAttention applies a fixed feature map (ELU+1, ReLU+eps, or identity) to Q and K before computing attention. This yields a true linear attention kernel while keeping the chunked computation pattern for efficiency.
- Intra-chunk: quadratic attention over phi(Q), phi(K), V within each chunk (causally masked)
- Inter-chunk: linear recurrence via a cumulative state, S_c = S_{c-1} + phi(K_c)^T @ V_c, so each chunk's queries attend to all earlier chunks through S_{c-1} (see the sketch after this list)
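A minimal Nx sketch of the two terms, assuming q, k, and v have already been feature-mapped and reshaped to [batch, heads, chunks, chunk_size, head_dim]; the module and function names are illustrative, not this library's internals:

defmodule ChunkedLinearAttention do
  # Sketch only: q, k, v are assumed phi-mapped, shape [b, h, c, n, d].
  def forward(q, k, v) do
    {_b, _h, _c, n, _d} = Nx.shape(q)

    # Intra-chunk: quadratic attention inside each chunk, with the causal
    # constraint applied as a lower-triangular 0/1 mask on the scores.
    scores = Nx.dot(q, [4], [0, 1, 2], k, [4], [0, 1, 2])
    mask = Nx.greater_equal(Nx.iota({n, n}, axis: 0), Nx.iota({n, n}, axis: 1))
    intra = Nx.dot(Nx.multiply(scores, mask), [4], [0, 1, 2], v, [3], [0, 1, 2])

    # Inter-chunk: per-chunk summaries phi(K_c)^T V_c, then an exclusive
    # prefix sum so chunk c sees only S_{c-1}, the state of earlier chunks.
    kv = Nx.dot(k, [3], [0, 1, 2], v, [3], [0, 1, 2])
    state = Nx.subtract(Nx.cumulative_sum(kv, axis: 2), kv)
    inter = Nx.dot(q, [4], [0, 1, 2], state, [3], [0, 1, 2])

    Nx.add(intra, inter)
  end
end

Subtracting kv from the inclusive cumulative sum makes the prefix exclusive, so the current chunk's own contribution is counted exactly once, by the masked intra-chunk term.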
Architecture
Input [batch, seq_len, embed_dim]
|
+---------------------------------------------------+
| Flash Linear Attention Block (x num_layers) |
| |
| LayerNorm → Q, K, V projections |
| phi(Q), phi(K) ← feature map (ELU+1/ReLU/id) |
| Reshape to [batch, heads, chunks, chunk_size, d] |
| |
| Intra-chunk: phi(Q)·phi(K)^T · V (causal masked) |
| Inter-chunk: phi(Q) · cumsum(phi(K)^T · V) |
| Output = intra + inter |
| |
| → Residual → LayerNorm → FFN → Residual |
+---------------------------------------------------+
|
[batch, hidden_size]
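For orientation, a hypothetical Axon wiring of one block matching the diagram; chunked_attention/2 stands in for the feature-map plus intra/inter computation sketched earlier, and the fused QKV projection is a layout assumption, not necessarily what build/1 does:

def block(input, hidden_size, opts) do
  # Pre-norm attention sub-block with a fused Q/K/V projection.
  attn =
    input
    |> Axon.layer_norm()
    |> Axon.dense(3 * hidden_size, name: "qkv")
    |> Axon.nx(&chunked_attention(&1, opts))

  x = Axon.add(input, attn)

  # Pre-norm FFN sub-block; the 4x expansion is also an assumption.
  ffn =
    x
    |> Axon.layer_norm()
    |> Axon.dense(4 * hidden_size, activation: :gelu)
    |> Axon.dense(hidden_size)

  Axon.add(x, ffn)
end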
Feature Maps

- :elu (default): 1 + ELU(x); smooth, always positive, good gradients
- :relu: ReLU(x) + eps; sparse but simple
- :identity: x; no transformation (equivalent to raw linear attention)
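A minimal sketch of the three maps in Nx; the epsilon constant is an assumed value, not necessarily the one the module uses:

# Applied elementwise to Q and K before attention.
def feature_map(x, :elu), do: Nx.add(Nx.select(Nx.greater(x, 0), x, Nx.expm1(x)), 1)
def feature_map(x, :relu), do: Nx.add(Nx.max(x, 0), 1.0e-6)
def feature_map(x, :identity), do: x

Both :elu and :relu keep phi(x) strictly positive, which is what lets phi(Q) phi(K)^T act as a valid unnormalized attention kernel.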
Constraints
seq_len must be divisible by chunk_size.
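If a sequence does not satisfy this, one option is to right-pad it before calling the model. A sketch with an assumed zero-padding strategy (whether the module pads internally is not documented, so this assumes it does not, and masking of padded positions is left to the caller):

def pad_to_chunk(input, chunk_size) do
  # Right-pad the time axis of [batch, seq_len, embed_dim] with zeros up
  # to the next multiple of chunk_size.
  seq_len = Nx.axis_size(input, 1)
  pad = rem(chunk_size - rem(seq_len, chunk_size), chunk_size)
  Nx.pad(input, 0.0, [{0, 0, 0}, {0, pad, 0}, {0, 0, 0}])
end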
Usage
model = FlashLinearAttention.build(
  embed_dim: 287,
  hidden_size: 256,
  num_heads: 4,
  num_layers: 4,
  chunk_size: 64,
  feature_map: :elu
)

References
- "Flash Linear Attention" (Yang et al., 2024)
- flash-linear-attention: https://github.com/fla-org/flash-linear-attention
Summary
Types
@type build_opt() ::
        {:embed_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:chunk_size, pos_integer()}
        | {:feature_map, :elu | :relu | :identity}
        | {:seq_len, pos_integer()}
        | {:window_size, pos_integer()}
        | {:dropout, float()}
Options for build/1.
Functions
Build a Flash Linear Attention model.
Options
- :embed_dim - Input embedding dimension (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :num_heads - Number of attention heads (default: 4)
- :num_layers - Number of blocks (default: 4)
- :chunk_size - Chunk size for block-wise attention (default: 64)
- :feature_map - Feature map type: :elu, :relu, or :identity (default: :elu)
- :dropout - Dropout rate (default: 0.1)
- :seq_len / :window_size - Expected sequence length (default: 64)
Returns
An Axon model outputting [batch, hidden_size].
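A sketch of building and running the model end to end, using standard Axon build/predict conventions; passing a bare tensor to predict_fn assumes the model exposes a single input:

model =
  FlashLinearAttention.build(
    embed_dim: 287,
    hidden_size: 256,
    chunk_size: 64,
    seq_len: 64
  )

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 64, 287}, :f32), %{})

# Batch of 8 sequences of length 64 -> output shape {8, 256}
output = predict_fn.(params, Nx.broadcast(0.0, {8, 64, 287}))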
@spec output_size(keyword()) :: pos_integer()
Get the output size of the model.
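For example, per the Returns note above, the output size should track :hidden_size:

FlashLinearAttention.output_size(hidden_size: 512)
#=> 512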