GLA: Gated Linear Attention with data-dependent gating.
GLA combines the efficiency of linear attention (O(L) complexity) with data-dependent gating for improved expressiveness. It's particularly effective on short sequences (<2K tokens) where it can outperform FlashAttention-2.
Key Innovation: Data-Dependent Gating
Unlike standard linear attention, which uses fixed feature maps, GLA computes gates from the input itself to control how much past information each step retains. In recurrent form, with a running key-value state S:

    S[t] = G[t] * S[t-1] + K[t]^T @ V[t]
    output[t] = Q[t] @ S[t]

where the gate G[t] lies in (0, 1) and is computed from the input at step t.
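For intuition, here is a minimal, unoptimized sketch of that recurrence in plain Nx for a single head (module and function names here are ours, not this library's API; real kernels use chunked parallel forms rather than a per-timestep loop):

```elixir
defmodule GLASketch do
  # q, k, v, g: {seq_len, head_dim} tensors for one head.
  # g holds per-dimension decay gates in (0, 1).
  def forward(q, k, v, g) do
    {seq_len, dim} = Nx.shape(q)
    init_state = Nx.broadcast(0.0, {dim, dim})

    {outputs, _state} =
      Enum.map_reduce(0..(seq_len - 1), init_state, fn t, state ->
        # S[t] = diag(G[t]) * S[t-1] + K[t]^T @ V[t]
        state =
          Nx.new_axis(g[t], 1)
          |> Nx.multiply(state)
          |> Nx.add(Nx.outer(k[t], v[t]))

        # output[t] = Q[t] @ S[t]
        {Nx.dot(q[t], state), state}
      end)

    Nx.stack(outputs)
  end
end
```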
Architecture
Input [batch, seq_len, embed_dim]
|
v
+-------------------------------------+
| GLA Block |
| |
| Q, K, V projections |
| | |
| Data-dependent gating (G) |
| | |
| Linear attention with gates |
| | |
| Output projection |
+-------------------------------------+
| (repeat for num_layers)
v
[batch, hidden_size]
Complexity
| Aspect | Standard Attention | GLA |
|---|---|---|
| Time | O(L^2) | O(L) |
| Space | O(L^2) | O(L) |
| Hardware | IO-aware kernels (e.g. FlashAttention) needed for speed | Native tensor ops |
Usage
model = GLA.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 6,
  num_heads: 4
)
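The result is an ordinary Axon graph, so it can be initialized and run with Axon's standard build/predict flow. A hedged sketch, assuming a single input of shape {batch, seq_len, embed_dim} with seq_len matching the default window_size of 60:

```elixir
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 60, 287}, :f32), %{})

input = Nx.broadcast(0.0, {1, 60, 287})
output = predict_fn.(params, input)
Nx.shape(output)
#=> {1, 256}
```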
Reference
- Paper: "Gated Linear Attention Transformers with Hardware-Efficient Training"
- Implementation: flash-linear-attention (https://github.com/fla-org/flash-linear-attention)
Types
@type build_opt() ::
        {:dropout, float()}
        | {:embed_dim, pos_integer()}
        | {:expand_factor, pos_integer()}
        | {:head_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:seq_len, pos_integer()}
        | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a GLA model for sequence processing.
Options
- :embed_dim - Size of the input embedding per timestep (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :num_layers - Number of GLA blocks (default: 6)
- :num_heads - Number of attention heads (default: 4)
- :head_dim - Dimension per head (default: 64)
- :expand_factor - FFN expansion factor (default: 2)
- :dropout - Dropout rate (default: 0.1)
- :window_size - Expected sequence length for JIT optimization (default: 60)
Returns
An Axon model that outputs [batch, hidden_size] from the last position.
Build the Gated Linear Attention layer.
Key components:
- Q, K, V, G projections (G is the gate; see the sketch after this list)
- Linear attention with data-dependent gating
- Output projection
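A hypothetical sketch of those four projections in Axon (dimensions and layer names are illustrative, not this library's actual internals):

```elixir
input = Axon.input("hidden", shape: {nil, nil, 256})

q = Axon.dense(input, 256, name: "q_proj")
k = Axon.dense(input, 256, name: "k_proj")
v = Axon.dense(input, 256, name: "v_proj")

# Sigmoid keeps the gate in (0, 1) so it can act as a decay factor.
g = Axon.dense(input, 256, name: "g_proj") |> Axon.sigmoid()
```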
Build a single GLA block.
Each block has:
- Gated Linear Attention layer
- Gated FFN (similar to GLU; sketched below)
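A GLU-style gated FFN can be sketched as follows (again an assumption about the block's internals, using expand_factor 2 on a hidden size of 256):

```elixir
h = Axon.input("block_in", shape: {nil, nil, 256})

value = Axon.dense(h, 512, name: "ffn_value")
gate = Axon.dense(h, 512, name: "ffn_gate") |> Axon.sigmoid()

out =
  Axon.multiply(value, gate)
  |> Axon.dense(256, name: "ffn_out")
```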
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a GLA model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for a GLA model.
@spec recommended_defaults() :: keyword()
Recommended default configuration for sequence processing.
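Putting these helpers together, a hedged example session (assuming recommended_defaults/0 does not set the required :embed_dim, so we add it ourselves):

```elixir
opts = GLA.recommended_defaults() |> Keyword.put(:embed_dim, 287)

GLA.output_size(opts)   # size of the model's final output, e.g. 256
GLA.param_count(opts)   # approximate total parameter count
model = GLA.build(opts)
```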