Edifice.Attention.Nystromformer (Edifice v0.2.0)


Nystromformer: Nystrom-based approximation for O(N) attention.

Approximates the full softmax attention matrix using the Nystrom method with landmark points. Instead of computing the full N x N attention matrix, it selects M landmark points (by average-pooling the sequence) and reconstructs the attention through them.

Key Innovation: Nystrom Approximation

The Nystrom method approximates a matrix using a subset of its columns/rows:

Full attention:    A = softmax(Q K^T / sqrt(d))
Nystrom approx:    A ~ F1 * pinv(F2) * F3

Where:
  landmarks = avg_pool(K, M)                       # M landmark points
  F1 = softmax(Q @ landmarks^T / sqrt(d))          # [N, M] queries-to-landmarks
  F2 = softmax(landmarks @ landmarks^T / sqrt(d))  # [M, M] landmarks-to-landmarks
  F3 = softmax(landmarks @ K^T / sqrt(d))          # [M, N] landmarks-to-keys
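
The approximation above can be checked numerically. The following is an illustrative NumPy sketch, not Edifice's Elixir implementation (all names here are hypothetical); it uses a single head and average-pooled landmarks:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
N, M, d = 64, 8, 16                    # sequence length, landmarks, head dim
Q = rng.normal(size=(N, d))
K = rng.normal(size=(N, d))
V = rng.normal(size=(N, d))

# Landmarks via average pooling over contiguous segments of K.
landmarks = K.reshape(M, N // M, d).mean(axis=1)   # [M, d]

scale = 1.0 / np.sqrt(d)
full = softmax(Q @ K.T * scale) @ V                # O(N^2) reference

F1 = softmax(Q @ landmarks.T * scale)              # [N, M]
F2 = softmax(landmarks @ landmarks.T * scale)      # [M, M]
F3 = softmax(landmarks @ K.T * scale)              # [M, N]
approx = F1 @ np.linalg.pinv(F2) @ F3 @ V          # O(N*M) reconstruction

rel_err = np.linalg.norm(full - approx) / np.linalg.norm(full)
```

With random inputs the relative error is nonzero but bounded; in practice attention matrices are closer to low-rank, which is what makes the landmark reconstruction effective.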

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+-------------------------------------+
|  Nystromformer Block                 |
|                                      |
|  LayerNorm                           |
|    -> Q, K, V projections            |
|    -> Select M landmarks (avg pool)  |
|    -> Q-to-landmark attention [N,M]  |
|    -> Landmark kernel [M,M]          |
|    -> Landmark-to-K attention [M,N]  |
|    -> Reconstruct: F1*pinv(F2)*F3*V |
|  -> Residual                         |
|                                      |
|  LayerNorm -> FFN -> Residual        |
+-------------------------------------+
      | (repeat for num_layers)
      v
Last timestep -> [batch, hidden_size]
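
The block layout above amounts to two pre-norm residual sub-layers. A minimal NumPy outline of that wiring (hypothetical names, not Edifice's internals; the attention body is passed in as a stand-in rather than reimplemented here):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def ffn(x, w1, w2):
    return np.maximum(x @ w1, 0.0) @ w2            # ReLU MLP with 4x expansion

def block(x, attend, w1, w2):
    x = x + attend(layer_norm(x))                  # Nystrom attention + residual
    x = x + ffn(layer_norm(x), w1, w2)             # feed-forward + residual
    return x

rng = rng = np.random.default_rng(1)
H = 16
x = rng.normal(size=(10, H))                       # [seq_len, hidden]
w1 = rng.normal(size=(H, 4 * H)) * 0.1
w2 = rng.normal(size=(4 * H, H)) * 0.1
y = block(x, attend=lambda z: z, w1=w1, w2=w2)     # identity attention stand-in
```

Stacking `num_layers` such blocks and taking the last timestep yields the `[batch, hidden_size]` output described above.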

Complexity

| Component      | Standard | Nystromformer |
| -------------- | -------- | ------------- |
| Attention      | O(N^2)   | O(N*M)        |
| Memory         | O(N^2)   | O(N*M + M^2)  |
| Kernel inverse | -        | O(M^3)        |

Where M = num_landmarks << N. Typically M = 32-64 is sufficient.
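
For concreteness, at N = 4096 and M = 32 the number of attention-score entries drops by a factor of N/M:

```python
N, M = 4096, 32
standard = N * N             # entries in the full attention matrix
nystrom = N * M              # queries-to-landmarks map (dominant term)
ratio = standard / nystrom   # = N / M
```

The O(M^3) pseudoinverse is negligible at these sizes: 32^3 = 32,768 operations versus roughly 16.8 million full-attention entries.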

Usage

model = Nystromformer.build(
  embed_dim: 287,
  hidden_size: 256,
  num_landmarks: 32,
  num_layers: 4,
  num_heads: 4
)

References

  • Paper: "Nystromformer: A Nystrom-Based Algorithm for Approximating Self-Attention" (Xiong et al., AAAI 2021)

Summary

Types

Options for build/1.

Functions

Build a Nystromformer model for sequence processing.

Get the output size of a Nystromformer model.

Calculate approximate parameter count for a Nystromformer model.

Recommended default configuration for sequence processing.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_landmarks, pos_integer()}
  | {:num_layers, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Nystromformer model for sequence processing.

Options

  • :embed_dim - Size of input embedding per timestep (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_landmarks - Number of Nystrom landmark points M (default: 32)
  • :num_layers - Number of Nystromformer blocks (default: 4)
  • :num_heads - Number of attention heads (default: 4)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Expected sequence length for JIT optimization (default: 60)

Returns

An Axon model that outputs [batch, hidden_size] from the last position.

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a Nystromformer model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for a Nystromformer model.
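
A back-of-envelope estimate of what such a count looks like (an assumption about the architecture, not Edifice's exact formula): a standard pre-norm block with hidden size H carries roughly 4*H^2 attention-projection parameters (Q, K, V, output) and 8*H^2 FFN parameters (4x expansion), plus an input projection from the embedding:

```python
def rough_param_count(embed_dim, hidden_size, num_layers):
    # Hypothetical estimate; ignores biases, layer norms, and any
    # landmark-specific parameters.
    input_proj = embed_dim * hidden_size        # embed -> hidden projection
    attn = 4 * hidden_size * hidden_size        # Q, K, V, output projections
    ffn = 8 * hidden_size * hidden_size         # 4x-expansion MLP
    return input_proj + num_layers * (attn + ffn)

rough_param_count(287, 256, 4)                  # the Usage example's config
```

Note that the count is independent of `num_landmarks`: average-pooled landmarks add no learned parameters, which is part of the method's appeal.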