Nystromformer: Nystrom-based approximation of softmax attention with cost linear in the sequence length N.
Approximates the full softmax attention matrix using the Nystrom method with landmark points. Instead of computing the full N x N attention matrix, it pools the sequence into M landmark points and reconstructs the attention through them.
Key Innovation: Nystrom Approximation
The Nystrom method approximates a matrix using a subset of its columns/rows:
Full attention: A = softmax(QK^T / sqrt(d))
Nystrom approx: A ~ F1 * pinv(F2) * F3
Where:
landmarks = downsample(K, M) # M landmark points
F1 = softmax(Q @ landmarks^T) # [N, M] queries-to-landmarks
F2 = softmax(landmarks @ landmarks^T) # [M, M] landmarks-to-landmarks
F3 = softmax(landmarks @ K^T) # [M, N] landmarks-to-keys
Architecture
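The `downsample(K, M)` step in the definitions above is not spelled out. One plausible reading, consistent with the "avg pool" note in the block diagram, is segment-mean pooling: split the N rows into M contiguous segments and average each one. The sketch below shows this on plain lists of floats. The module name `LandmarkSketch` and the requirement that N be divisible by M are illustrative assumptions, not taken from this module's source.

```elixir
defmodule LandmarkSketch do
  @moduledoc false

  # rows: list of N vectors (each a list of floats); m: number of landmarks.
  # Assumes N is non-zero and divisible by m (a simplification for this sketch).
  def segment_means(rows, m) when is_integer(m) and m > 0 do
    n = length(rows)

    if n == 0 or rem(n, m) != 0 do
      raise ArgumentError, "expected a non-empty sequence whose length (#{n}) is divisible by #{m}"
    end

    rows
    |> Enum.chunk_every(div(n, m))
    |> Enum.map(&mean_vector/1)
  end

  # Element-wise mean of a list of equal-length vectors.
  defp mean_vector(segment) do
    count = length(segment)

    segment
    |> Enum.zip_with(&Enum.sum/1)
    |> Enum.map(&(&1 / count))
  end
end

keys = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
landmarks = LandmarkSketch.segment_means(keys, 2)
IO.inspect(landmarks)
```

With four rows and two landmarks, each landmark is the mean of two consecutive rows. A tensor-based implementation would do the same with a reshape and a mean along the segment axis.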
Input [batch, seq_len, embed_dim]
|
v
+-------------------------------------+
| Nystromformer Block |
| |
| LayerNorm |
| -> Q, K, V projections |
| -> Select M landmarks (avg pool) |
| -> Q-to-landmark attention [N,M] |
| -> Landmark kernel [M,M] |
| -> Landmark-to-K attention [M,N] |
| -> Reconstruct: F1*pinv(F2)*F3*V |
| -> Residual |
| |
| LayerNorm -> FFN -> Residual |
+-------------------------------------+
| (repeat for num_layers)
v
Last timestep -> [batch, hidden_size]
Complexity
| Component | Standard attention | Nystromformer |
|---|---|---|
| Attention compute | O(N^2) | O(N*M) |
| Memory | O(N^2) | O(N*M + M^2) |
| Kernel pseudo-inverse | - | O(M^3) |
Here M = num_landmarks, with M << N. Values of M between 32 and 64 are typically sufficient.
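To see where the table's entries come from, it helps to count multiply-adds when the product F1 * pinv(F2) * F3 * V is evaluated right to left: F3 * V first ([M, N] by [N, d]), then pinv(F2) times that ([M, M] by [M, d]), then F1 times the result ([N, M] by [M, d]). The N x N matrix is never formed. The sketch below is a back-of-the-envelope count under that assumption. `CostSketch`, the per-head dimension `d`, and the omission of the softmax and pseudo-inverse costs are illustrative simplifications.

```elixir
defmodule CostSketch do
  @moduledoc false

  # Multiply-adds for standard attention: Q*K^T plus A*V, each N*N*d.
  def standard(n, d), do: 2 * n * n * d

  # Multiply-adds for the Nystrom form, evaluated right to left.
  def nystrom(n, m, d) do
    # Score matrices before softmax: F1 [N, M], F3 [M, N], F2 [M, M].
    scores = 2 * n * m * d + m * m * d
    # F3 * V -> [M, d]
    f3_v = m * n * d
    # pinv(F2) * (F3 * V) -> [M, d]
    inv_times = m * m * d
    # F1 * (...) -> [N, d]
    f1_times = n * m * d

    scores + f3_v + inv_times + f1_times
  end
end

ratio = CostSketch.standard(1024, 64) / CostSketch.nystrom(1024, 32, 64)
IO.puts("standard / nystrom multiply-adds at N=1024, M=32, d=64: #{Float.round(ratio, 1)}")
```

Under this count the Nystrom form costs 4*N*M*d + 2*M^2*d, so the savings only appear once N is several times larger than M. For a sequence length close to M, the approximation can cost as much as or more than standard attention.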
Usage
model = Nystromformer.build(
  embed_dim: 287,
  hidden_size: 256,
  num_landmarks: 32,
  num_layers: 4,
  num_heads: 4
)
References
- Paper: "Nystromformer: A Nystrom-Based Algorithm for Approximating Self-Attention" (Xiong et al., AAAI 2021)
Summary
Functions
build/1: Build a Nystromformer model for sequence processing.
output_size/1: Get the output size of a Nystromformer model.
param_count/1: Calculate approximate parameter count for a Nystromformer model.
recommended_defaults/0: Recommended default configuration for sequence processing.
Types
@type build_opt() :: {:dropout, float()} | {:hidden_size, pos_integer()} | {:num_heads, pos_integer()} | {:num_landmarks, pos_integer()} | {:num_layers, pos_integer()}
Options for build/1.
Functions
Build a Nystromformer model for sequence processing.
Options
- :embed_dim - Size of input embedding per timestep (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :num_landmarks - Number of Nystrom landmark points M (default: 32)
- :num_layers - Number of Nystromformer blocks (default: 4)
- :num_heads - Number of attention heads (default: 4)
- :dropout - Dropout rate (default: 0.1)
- :window_size - Expected sequence length for JIT optimization (default: 60)
Returns
An Axon model that outputs [batch, hidden_size] from the last position.
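As an illustration of how the documented options combine, the sketch below merges caller-supplied options over the listed defaults and enforces the required `:embed_dim`. `OptsSketch` and `resolve/1` are hypothetical names used only here; the module's real option handling may differ.

```elixir
defmodule OptsSketch do
  @moduledoc false

  # Defaults copied from the documented option list.
  @defaults [
    hidden_size: 256,
    num_landmarks: 32,
    num_layers: 4,
    num_heads: 4,
    dropout: 0.1,
    window_size: 60
  ]

  def resolve(opts) do
    # Fills in defaults; raises ArgumentError on unknown keys.
    opts = Keyword.validate!(opts, [:embed_dim | @defaults])
    # Raises KeyError when the required :embed_dim is missing.
    _ = Keyword.fetch!(opts, :embed_dim)
    opts
  end
end

opts = OptsSketch.resolve(embed_dim: 287, num_landmarks: 64)
IO.inspect(Keyword.fetch!(opts, :num_landmarks))
```

`Keyword.validate!/2` requires Elixir 1.13 or later. Rejecting unknown keys up front catches misspelled option names before any model is built.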
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a Nystromformer model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for a Nystromformer model.
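The formula behind this estimate is not given here. A rough sketch of how such a count could be assembled is shown below, under loudly stated assumptions: an input projection from `embed_dim` to `hidden_size`, four `hidden_size` x `hidden_size` attention projections per block (Q, K, V, output), and a two-layer FFN with an assumed expansion factor of 4. Biases and LayerNorm parameters are ignored. `ParamSketch`, `estimate/1`, and `:ffn_mult` are hypothetical and are not part of this module's API.

```elixir
defmodule ParamSketch do
  @moduledoc false

  # Hypothetical estimate; ignores biases and LayerNorm parameters.
  def estimate(opts) do
    embed_dim = Keyword.fetch!(opts, :embed_dim)
    hidden = Keyword.get(opts, :hidden_size, 256)
    layers = Keyword.get(opts, :num_layers, 4)
    # Assumed FFN expansion factor; not a documented option.
    ffn_mult = Keyword.get(opts, :ffn_mult, 4)

    # Input projection: embed_dim -> hidden
    input_proj = embed_dim * hidden
    # Q, K, V and output projections
    attention = 4 * hidden * hidden
    # FFN up-projection and down-projection
    ffn = 2 * hidden * hidden * ffn_mult

    input_proj + layers * (attention + ffn)
  end
end

IO.inspect(ParamSketch.estimate(embed_dim: 287))
```

Under these assumptions `num_landmarks` does not appear in the count, since landmarks obtained by average pooling add no learned parameters.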
@spec recommended_defaults() :: keyword()
Recommended default configuration for sequence processing.