Edifice.Attention.FlashLinearAttention (Edifice v0.2.0)

Flash Linear Attention — chunked linear attention with feature maps.

Combines the efficiency of linear attention with block-wise computation for better hardware utilization, following the LightningAttention pattern but using explicit feature maps on Q and K.

Key Innovation: Feature-Mapped Chunked Attention

Unlike LightningAttention, which operates on raw Q, K, V, FlashLinearAttention applies a feature map (ELU+1, ReLU+eps, or identity) to Q and K before computing attention. This yields a true linear attention kernel while keeping the chunked computation pattern for efficiency.

  • Intra-chunk: Quadratic attention on phi(Q), phi(K), V (causal masked)
  • Inter-chunk: Linear recurrence via cumulative S_c = S_{c-1} + phi(K_c)^T @ V_c
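
For orientation, the sketch below computes both terms for a single head with plain Nx, assuming q, k, and v have already been feature-mapped and reshaped to [chunks, chunk_size, head_dim]. The module and function names are illustrative, not the library's internals.

defmodule ChunkedLinearAttentionSketch do
  # q, k, v: [chunks, chunk_size, head_dim], with phi already applied to q and k.
  def forward(q, k, v) do
    chunk_size = Nx.axis_size(q, 1)

    # Causal mask inside a chunk: position i may attend to positions j <= i.
    rows = Nx.iota({chunk_size, chunk_size}, axis: 0)
    cols = Nx.iota({chunk_size, chunk_size}, axis: 1)
    mask = Nx.greater_equal(rows, cols)

    # Intra-chunk: quadratic attention restricted to each chunk.
    scores = Nx.dot(q, [2], [0], k, [2], [0])
    scores = Nx.select(mask, scores, Nx.tensor(0.0))
    intra = Nx.dot(scores, [2], [0], v, [1], [0])

    # Inter-chunk: S_c = S_{c-1} + phi(K_c)^T V_c, made exclusive so that
    # chunk c only sees state accumulated from earlier chunks.
    kv = Nx.dot(k, [1], [0], v, [1], [0])
    state = Nx.subtract(Nx.cumulative_sum(kv, axis: 0), kv)
    inter = Nx.dot(q, [2], [0], state, [1], [0])

    Nx.add(intra, inter)
  end
end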

Architecture

Input [batch, seq_len, embed_dim]
      |
+----------------------------------------------------+
|  Flash Linear Attention Block (x num_layers)       |
|                                                    |
|  LayerNorm → Q, K, V projections                   |
|  phi(Q), phi(K) ← feature map (ELU+1/ReLU/id)      |
|  Reshape to [batch, heads, chunks, chunk_size, d]  |
|                                                    |
|  Intra-chunk: phi(Q)·phi(K)^T · V (causal masked)  |
|  Inter-chunk: phi(Q) · cumsum(phi(K)^T · V)        |
|  Output = intra + inter                            |
|                                                    |
|  Residual → LayerNorm → FFN → Residual             |
+----------------------------------------------------+
      |
[batch, hidden_size]
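
For orientation only, one way the residual and FFN wiring of a single block could be expressed with standard Axon combinators is sketched below; the attention step is passed in as a placeholder function, the 4x FFN expansion and GELU activation are conventional assumptions, and none of the names come from the library.

defmodule BlockSketch do
  # attention_fn is a placeholder for the chunked linear attention step
  # (e.g. an Axon.nx/2 layer wrapping the Nx sketch above).
  def block(x, hidden_size, dropout, attention_fn) do
    attn =
      x
      |> Axon.layer_norm()
      |> then(attention_fn)

    x = Axon.add(x, attn)

    ffn =
      x
      |> Axon.layer_norm()
      |> Axon.dense(hidden_size * 4, activation: :gelu)
      |> Axon.dense(hidden_size)
      |> Axon.dropout(rate: dropout)

    Axon.add(x, ffn)
  end
end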

Feature Maps

  • :elu (default) — 1 + ELU(x): smooth, always positive, good gradients
  • :relu — ReLU(x) + eps: sparse but simple
  • :identity — x: no transformation (equivalent to raw linear attention)
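
All three maps are simple element-wise functions. A minimal Nx sketch, where the epsilon value is an assumption (note that 1 + ELU(x) equals exp(x) for x <= 0 when alpha is 1):

phi_elu = fn x ->
  # 1 + ELU(x): x + 1 for positive inputs, exp(x) otherwise; strictly positive
  Nx.select(Nx.greater(x, 0), Nx.add(x, 1.0), Nx.exp(x))
end

phi_relu = fn x -> Nx.add(Nx.max(x, 0), 1.0e-6) end

phi_identity = fn x -> x end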

Constraints

seq_len must be divisible by chunk_size.
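
If incoming sequences do not already satisfy this, one option is to right-pad the time axis up to the next multiple of the chunk size before calling the model. The helper below is illustrative and not part of the library:

pad_to_chunk = fn input, chunk_size ->
  # input: [batch, seq_len, embed_dim]; pad seq_len up to a multiple of chunk_size
  seq_len = Nx.axis_size(input, 1)
  pad = rem(chunk_size - rem(seq_len, chunk_size), chunk_size)
  Nx.pad(input, 0.0, [{0, 0, 0}, {0, pad, 0}, {0, 0, 0}])
end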

Usage

model = FlashLinearAttention.build(
  embed_dim: 287,
  hidden_size: 256,
  num_heads: 4,
  num_layers: 4,
  chunk_size: 64,
  feature_map: :elu
)
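
The result is an ordinary Axon graph, so it can be initialized and run with the usual Axon workflow. The batch size, template, and dummy input below are illustrative:

{init_fn, predict_fn} = Axon.build(model)

template = Nx.template({1, 64, 287}, :f32)
params = init_fn.(template, %{})

input = Nx.iota({1, 64, 287}, type: :f32)
output = predict_fn.(params, input)
# expected shape: {1, 256}, i.e. [batch, hidden_size]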

Summary

Types

build_opt()

Options for build/1.

Functions

build(opts \\ [])

Build a Flash Linear Attention model.

output_size(opts \\ [])

Get the output size of the model.

Types

build_opt()

@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:chunk_size, pos_integer()}
  | {:feature_map, :elu | :relu | :identity}
  | {:dropout, float()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Flash Linear Attention model.

Options

  • :embed_dim - Input embedding dimension (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_heads - Number of attention heads (default: 4)
  • :num_layers - Number of blocks (default: 4)
  • :chunk_size - Chunk size for block-wise attention (default: 64)
  • :feature_map - Feature map type: :elu, :relu, or :identity (default: :elu)
  • :dropout - Dropout rate (default: 0.1)
  • :seq_len / :window_size - Expected sequence length (default: 64)

Returns

An Axon model outputting [batch, hidden_size].

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output size of the model.
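
Given the documented [batch, hidden_size] output, output_size/1 presumably reflects the configured :hidden_size; the call below illustrates that assumption rather than a documented example.

FlashLinearAttention.output_size(hidden_size: 256)
#=> 256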