Linear Transformer: Linear attention using kernel feature maps.
Replaces softmax attention with a kernel-based linear attention mechanism, reducing complexity from O(N^2) to O(N) by avoiding explicit computation of the N x N attention matrix.
Key Innovation: Kernel Feature Maps
Standard attention computes: Attn(Q,K,V) = softmax(QK^T/sqrt(d)) * V
Linear attention rewrites this using a feature map phi:
Attn(Q,K,V) = phi(Q) * (phi(K)^T * V) / (phi(Q) * sum(phi(K)))

By computing phi(K)^T * V first (a d x d matrix), we avoid the N x N attention matrix entirely. The feature map phi(x) = ELU(x) + 1 ensures non-negative attention weights.
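To make the algebra concrete, here is a minimal single-head sketch in Nx (non-causal form; the autoregressive variant in the paper replaces the sums with cumulative sums). The module and function names are illustrative, not part of this library:

defmodule LinearAttentionSketch do
  import Nx.Defn

  # phi(x) = ELU(x) + 1: ELU is x for x > 0 and exp(x) - 1 otherwise,
  # so phi(x) > 0 everywhere, giving non-negative attention weights.
  defn phi(x), do: Nx.select(x > 0, x, Nx.exp(x) - 1) + 1

  # q, k, v: [seq_len, d] tensors (batch and head dims omitted for clarity).
  defn attend(q, k, v) do
    fq = phi(q)                           # [N, d]
    fk = phi(k)                           # [N, d]
    kv = Nx.dot(Nx.transpose(fk), v)      # [d, d]; the N x N matrix is never formed
    z = Nx.dot(fq, Nx.sum(fk, axes: [0])) # [N]; normalizer phi(Q) * sum(phi(K))
    Nx.dot(fq, kv) / Nx.new_axis(z, 1)    # [N, d]
  end
end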
Architecture
Input [batch, seq_len, embed_dim]
        |
        v
+----------------------------------------+
|        Linear Transformer Block        |
|                                        |
| LayerNorm                              |
|   -> Q, K, V projections               |
|   -> phi(Q), phi(K) feature maps       |
|   -> KV = phi(K)^T * V    [d x d]      |
|   -> out = phi(Q) * KV    [N x d]      |
|   -> normalize by phi(Q) * sum(phi(K)) |
|   -> Residual                          |
|                                        |
| LayerNorm -> FFN -> Residual           |
+----------------------------------------+
        | (repeat for num_layers)
        v
Last timestep -> [batch, hidden_size]
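For orientation, one block could be wired in Axon roughly as below. This is a hypothetical sketch, not this module's source: the helper names (block/3, linear_attention/4), the GELU activation, and the 4x FFN expansion are assumptions.

def block(input, hidden_size, dropout) do
  # Attention sub-layer: pre-LayerNorm, then Q/K/V projections.
  normed = Axon.layer_norm(input)
  q = Axon.dense(normed, hidden_size)
  k = Axon.dense(normed, hidden_size)
  v = Axon.dense(normed, hidden_size)

  attn =
    Axon.layer(&linear_attention/4, [q, k, v])  # custom layer; see the Nx sketch above
    |> Axon.dropout(rate: dropout)

  x = Axon.add(input, attn)  # residual connection

  # Feed-forward sub-layer with its own pre-LayerNorm and residual.
  ffn =
    x
    |> Axon.layer_norm()
    |> Axon.dense(hidden_size * 4, activation: :gelu)  # expansion factor assumed
    |> Axon.dense(hidden_size)
    |> Axon.dropout(rate: dropout)

  Axon.add(x, ffn)
end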
Complexity
| Aspect | Standard | Linear |
|---|---|---|
| Attention | O(N^2 * d) | O(N * d^2) |
| Memory | O(N^2) | O(N * d) |
| Best when | N < d | N > d |
Linear attention is most beneficial when sequence length N exceeds the head dimension d.
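A quick back-of-envelope check of the crossover, with illustrative values:

# Per head, N = 1024 timesteps, head dimension d = 64:
n = 1024
d = 64
n * n * d # => 67_108_864 multiply-accumulates for standard attention
n * d * d # => 4_194_304 for linear attention; N/d = 16x fewer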
Usage
model = LinearTransformer.build(
embed_dim: 287,
hidden_size: 256,
num_layers: 4,
dropout: 0.1
)
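Once built, the model follows the usual Axon workflow. A minimal sketch, assuming a [batch, seq_len, embed_dim] input of shape {8, 60, 287}:

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({8, 60, 287}, :f32), %{})
output = predict_fn.(params, Nx.iota({8, 60, 287}, type: :f32))
# Nx.shape(output) => {8, 256}, i.e. [batch, hidden_size]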
References
- Paper: "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" (Katharopoulos et al., 2020)
- Feature map: ELU+1 from the original paper
Summary
Functions
build/1 - Build a Linear Transformer model for sequence processing.
output_size/1 - Get the output size of a Linear Transformer model.
param_count/1 - Calculate approximate parameter count for a Linear Transformer model.
recommended_defaults/0 - Recommended default configuration for sequence processing.
Types
@type build_opt() :: {:dropout, float()} | {:hidden_size, pos_integer()} | {:num_heads, pos_integer()} | {:num_layers, pos_integer()}
Options for build/1.
Functions
Build a Linear Transformer model for sequence processing.
Options
:embed_dim - Size of input embedding per timestep (required)
:hidden_size - Internal hidden dimension (default: 256)
:num_layers - Number of transformer blocks (default: 4)
:num_heads - Number of attention heads (default: 4)
:dropout - Dropout rate (default: 0.1)
:window_size - Expected sequence length for JIT optimization (default: 60)
Returns
An Axon model that outputs [batch, hidden_size] from the last position.
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a Linear Transformer model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for a Linear Transformer model.
@spec recommended_defaults() :: keyword()
Recommended default configuration for sequence processing.
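A sketch of how the helpers compose, assuming recommended_defaults/0 returns options accepted by the other functions (return values shown are illustrative):

opts = LinearTransformer.recommended_defaults() |> Keyword.put(:embed_dim, 287)

LinearTransformer.output_size(opts) # => 256 (the configured hidden size)
LinearTransformer.param_count(opts) # => approximate trainable parameter total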