Based: Linear attention with Taylor expansion feature map.
Replaces quadratic softmax(QK^T) attention with a linear-time approximation built on Taylor-expanded feature maps. Instead of materializing the full n x n attention matrix, Based projects Q and K through a polynomial feature map phi(x) and computes attention in time linear in the sequence length.
Key Innovation
The Taylor feature map approximates softmax attention by expanding exp(q . k):
- phi(x) = [1, x, (x ⊗ x)/sqrt(2!), ...] up to Taylor order N, where x ⊗ x is the flattened outer product, so that phi(q) . phi(k) = 1 + q . k + (q . k)^2/2! + ... ≈ exp(q . k)
- Linear attention: output = phi(Q) @ (phi(K)^T @ V) / (phi(Q) @ sum(phi(K)))
- This avoids the O(n^2) softmax(QK^T) computation: phi(K)^T @ V and sum(phi(K)) are built once and reused for every query (a minimal sketch follows below)
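A minimal Nx sketch of the idea (illustrative only, not this module's internals: BasedSketch, taylor_phi/1, and linear_attention/3 are hypothetical names, and the causal variant used for language modeling, which replaces the sums with cumulative sums over positions, is omitted):

defmodule BasedSketch do
  # Order-2 Taylor feature map: phi(x) = [1, x, (x ⊗ x)/sqrt(2!)], flattened,
  # so that dot(phi(q), phi(k)) = 1 + q.k + (q.k)^2/2 ~= exp(q.k).
  def taylor_phi(x) do
    {b, n, d} = Nx.shape(x)

    x2 =
      Nx.multiply(Nx.new_axis(x, -1), Nx.new_axis(x, -2))
      |> Nx.divide(:math.sqrt(2))
      |> Nx.reshape({b, n, d * d})

    Nx.concatenate([Nx.broadcast(1.0, {b, n, 1}), x, x2], axis: -1)
  end

  # Non-causal linear attention in the kernel form above.
  def linear_attention(q, k, v) do
    phi_q = taylor_phi(q)                           # {b, n, f}
    phi_k = taylor_phi(k)                           # {b, n, f}
    # The {b, f, dv} state phi(K)^T @ V is built once, then reused per query.
    kv = Nx.dot(phi_k, [1], [0], v, [1], [0])
    num = Nx.dot(phi_q, [2], [0], kv, [1], [0])     # {b, n, dv}
    den = Nx.dot(phi_q, [2], [0], Nx.sum(phi_k, axes: [1]), [1], [0])  # {b, n}
    Nx.divide(num, Nx.new_axis(den, -1))
  end
end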
Architecture
Input [batch, seq_len, embed_dim]
|
Input projection to hidden_size
|
+--------------------------------------+
| Based Block (x num_layers) |
| |
| LayerNorm -> Based Linear Attn |
| Q, K projections + Taylor phi() |
| Linear attention via phi(Q/K) |
| -> Residual |
| LayerNorm -> FFN -> Residual |
+--------------------------------------+
|
Final LayerNorm
|
Last timestep -> [batch, hidden_size]
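How one such block might be assembled in Axon (a hedged sketch, not the module's actual code: based_block/3 is a hypothetical name, linear_attention/3 is the sketch from Key Innovation, the 4x FFN expansion and :gelu activation are assumptions, and head splitting, which keeps the per-head dimension small enough for the outer-product feature map, is omitted):

def based_block(input, hidden_size, dropout) do
  # Pre-norm attention sub-layer: LayerNorm -> Q/K/V projections
  normed = Axon.layer_norm(input)
  q = Axon.dense(normed, hidden_size)
  k = Axon.dense(normed, hidden_size)
  v = Axon.dense(normed, hidden_size)

  # Custom Axon layer wrapping the Nx linear-attention sketch above
  attn =
    Axon.layer(
      fn qt, kt, vt, _opts -> BasedSketch.linear_attention(qt, kt, vt) end,
      [q, k, v]
    )

  x = Axon.add(input, attn)

  # Pre-norm FFN sub-layer with residual
  ffn =
    x
    |> Axon.layer_norm()
    |> Axon.dense(hidden_size * 4, activation: :gelu)
    |> Axon.dropout(rate: dropout)
    |> Axon.dense(hidden_size)

  Axon.add(x, ffn)
end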
Complexity
| Mechanism | Time | Space |
|---|---|---|
| Softmax attention | O(n^2 d) | O(n^2) |
| Based (Taylor) | O(n d^2 p) | O(d^2 p) |
Where n is the sequence length, d is the head dimension, and p is the Taylor order (typically 2-3).
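Following directly from the feature map above, the feature dimension for Taylor order p on head dimension d is f = 1 + d + ... + d^p; for example, with d = 16 and p = 2 that is 273:

# f = sum of d^i for i in 0..p (e.g. 1 + 16 + 256 = 273)
feature_dim = fn d, p -> Enum.sum(for i <- 0..p, do: Integer.pow(d, i)) end
feature_dim.(16, 2)
#=> 273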
Usage
model = Based.build(
  embed_dim: 287,
  hidden_size: 256,
  num_heads: 4,
  taylor_order: 2,
  num_layers: 4
)
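Running the built model follows the standard Axon build/predict flow (a hedged sketch; shapes assume the options above, the default window_size of 60, and a batch of 8):

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 60, 287}, :f32), %{})
input = Nx.broadcast(0.0, {8, 60, 287})   # [batch, seq_len, embed_dim]
output = predict_fn.(params, input)       # [batch, hidden_size] = {8, 256}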
- "Simple linear attention language models balance the recall-throughput tradeoff" (Arora et al., 2024)
Summary
Functions
Build a Based linear attention model.
Build the Based linear attention layer with Taylor feature map.
Get the output size of a Based model.
Types
@type build_opt() ::
        {:embed_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:taylor_order, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:dropout, float()}
        | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a Based linear attention model.
Options
- :embed_dim - Size of input embedding per timestep (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :num_heads - Number of attention heads (default: 4)
- :taylor_order - Order of Taylor expansion for feature map (default: 2)
- :num_layers - Number of transformer blocks (default: 4)
- :dropout - Dropout rate (default: 0.1)
- :window_size - Expected sequence length for JIT optimization (default: 60)
Returns
An Axon model that outputs [batch, hidden_size] from the last position.
Build the Based linear attention layer with Taylor feature map.
Projects the input to Q, K, and V, applies the Taylor feature map to Q and K, then computes attention in linear time.
@spec output_size(keyword()) :: pos_integer()
Get the output size of a Based model.
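A hedged usage sketch, assuming output_size/1 simply reflects the configured :hidden_size (the build docs above state the model outputs [batch, hidden_size]):

Based.output_size(hidden_size: 256)
#=> 256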