Edifice.Attention.Performer (Edifice v0.2.0)


Performer: Fast Attention Via Positive Orthogonal Random Features (FAVOR+).

Performer approximates softmax attention using random feature maps, achieving O(N) time and space complexity. The FAVOR+ mechanism uses orthogonal random features to approximate the exponential kernel.

Key Innovation: FAVOR+ Random Feature Attention

Standard attention: softmax(QK^T / sqrt(d)) * V -- O(N^2) time and memory in sequence length N

FAVOR+ approximates exp(QK^T) using random features:

exp(q^T k) ~ phi(q)^T phi(k)   (unbiased: equality holds in expectation over the w_i)

Where phi(x) = exp(-||x||^2 / 2) / sqrt(m) * [exp(w_1^T x), ..., exp(w_m^T x)]
w_1, ..., w_m ~ iid N(0, I) (orthogonalized)
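The approximation above can be checked numerically. The following is an illustrative NumPy sketch (not the library's Elixir code): with enough random features, phi(q)^T phi(k) concentrates around exp(q^T k).

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 16, 4096  # head dimension, number of random features (large m for a tight estimate)

def phi(x, W):
    # Positive random features for the exponential kernel:
    # phi(x) = exp(-||x||^2 / 2) / sqrt(m) * [exp(w_1^T x), ..., exp(w_m^T x)]
    return np.exp(x @ W.T - np.dot(x, x) / 2.0) / np.sqrt(W.shape[0])

W = rng.standard_normal((m, d))   # w_i ~ iid N(0, I); orthogonalization omitted in this sketch
q = 0.3 * rng.standard_normal(d)
k = 0.3 * rng.standard_normal(d)

exact = np.exp(q @ k)
approx = phi(q, W) @ phi(k, W)    # Monte Carlo estimate of exp(q^T k)
```

Because every feature is positive, the estimate stays positive as well, which is what makes the downstream normalization by D stable.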

This allows rewriting attention as:

Attn(Q,K,V) = D^{-1} * phi(Q) * (phi(K)^T * V)
D = diag(phi(Q) * phi(K)^T * 1)

Computing KV = phi(K)^T V first costs O(N m d) and yields an [m, d] matrix, so the whole attention is O(N m d) instead of O(N^2 d).
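The quadratic and linear orderings of this product are algebraically identical; only the cost differs. An illustrative NumPy sketch (variable names are assumptions, not library identifiers):

```python
import numpy as np

rng = np.random.default_rng(1)
N, d, m = 50, 8, 32

def phi(x, W):
    # Row-wise positive feature map, x: [N, d] -> [N, m]
    sq = (x * x).sum(axis=-1, keepdims=True)
    return np.exp(x @ W.T - sq / 2.0) / np.sqrt(W.shape[0])

W = rng.standard_normal((m, d))
Q, K, V = (0.3 * rng.standard_normal((N, d)) for _ in range(3))
pQ, pK = phi(Q, W), phi(K, W)

# Quadratic order: materialize the [N, N] attention matrix
A = pQ @ pK.T
out_quadratic = (A @ V) / A.sum(axis=1, keepdims=True)

# Linear order: phi(K)^T V is only [m, d]; no [N, N] matrix is ever formed
KV = pK.T @ V                        # O(N m d)
D = (pQ @ pK.sum(axis=0))[:, None]   # D = phi(Q) (phi(K)^T 1)
out_linear = (pQ @ KV) / D
```

The two results agree to floating-point precision; the reordering is pure matrix-multiplication associativity.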

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+--------------------------------------+
|  Performer Block                     |
|                                      |
|  LayerNorm                           |
|    -> Q, K, V projections            |
|    -> Random feature map phi(Q,K)    |
|       (orthogonal random features)   |
|    -> KV = phi(K)^T * V    [m, d]    |
|    -> out = phi(Q) * KV    [N, d]    |
|    -> normalize by D                 |
|  -> Residual                         |
|                                      |
|  LayerNorm -> FFN -> Residual        |
+--------------------------------------+
      | (repeat for num_layers)
      v
Last timestep -> [batch, hidden_size]
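The diagram can be condensed into a minimal single-head block sketch in NumPy (all parameter names and initializations here are illustrative assumptions; the actual module is built with Axon in Elixir):

```python
import numpy as np

rng = np.random.default_rng(2)
N, d, m, d_ff = 10, 32, 16, 64  # seq len, model/head dim, random features, FFN dim

def layer_norm(x, eps=1e-5):
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def phi(x, W):
    sq = (x * x).sum(-1, keepdims=True)
    return np.exp(x @ W.T - sq / 2.0) / np.sqrt(W.shape[0])

def performer_block(x, Wq, Wk, Wv, W_feat, W1, W2):
    # Pre-norm attention sub-layer with FAVOR+ linear attention
    h = layer_norm(x)
    pQ, pK = phi(h @ Wq, W_feat), phi(h @ Wk, W_feat)  # [N, m]
    V = h @ Wv
    KV = pK.T @ V                                      # [m, d]
    D = (pQ @ pK.sum(axis=0))[:, None]                 # normalizer
    x = x + (pQ @ KV) / D                              # residual
    # Pre-norm FFN sub-layer
    h = layer_norm(x)
    return x + np.maximum(h @ W1, 0.0) @ W2            # ReLU FFN + residual

Wq, Wk, Wv = (0.1 * rng.standard_normal((d, d)) for _ in range(3))
W_feat = rng.standard_normal((m, d))
W1 = 0.1 * rng.standard_normal((d, d_ff))
W2 = 0.1 * rng.standard_normal((d_ff, d))
x = rng.standard_normal((N, d))
y = performer_block(x, Wq, Wk, Wv, W_feat, W1, W2)     # [N, d]
```

Stacking `num_layers` such blocks and taking the last timestep yields the `[batch, hidden_size]` output described above.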

Complexity

| Component       | Standard     | Performer      |
| --------------- | ------------ | -------------- |
| Time            | O(N^2 d)     | O(N d m)       |
| Space           | O(N^2 + N d) | O(N (d + m))   |
| Random features | --           | m (default 64) |

Where m = num_features controls approximation quality vs speed tradeoff.
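For a concrete sense of scale (illustrative numbers, not benchmarks from the library), the attention FLOP counts compare as:

```python
# FLOP-count comparison at N = 4096, d = 64, m = 64 (the default num_features)
N, d, m = 4096, 64, 64
standard = N * N * d         # O(N^2 d): materializing the attention matrix
performer = N * d * m        # O(N d m): linear attention
ratio = standard / performer # speedup factor = N / m -> 64.0
```

The speedup grows linearly with sequence length, since the ratio is simply N / m.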

Usage

alias Edifice.Attention.Performer

model = Performer.build(
  embed_dim: 287,
  hidden_size: 256,
  num_features: 64,
  num_layers: 4,
  dropout: 0.1
)

References

  • Paper: "Rethinking Attention with Performers" (Choromanski et al., ICLR 2021)
  • FAVOR+: Fast Attention Via positive Orthogonal Random features

Summary

Types

Options for build/1.

Functions

Build a Performer model for sequence processing.

Generate orthogonal random features for FAVOR+ via QR decomposition.

Get the output size of a Performer model.

Calculate approximate parameter count for a Performer model.

Recommended default configuration for sequence processing.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_features, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Performer model for sequence processing.

Options

  • :embed_dim - Size of input embedding per timestep (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_features - Number of random features m for FAVOR+ (default: 64)
  • :num_layers - Number of Performer blocks (default: 4)
  • :num_heads - Number of attention heads (default: 4)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Expected sequence length for JIT optimization (default: 60)

Returns

An Axon model that outputs [batch, hidden_size] from the last position.

generate_orthogonal_features(head_dim, num_features, opts \\ [])

@spec generate_orthogonal_features(pos_integer(), pos_integer(), keyword()) ::
  Nx.Tensor.t()

Generate orthogonal random features for FAVOR+ via QR decomposition.

Returns a [num_features, head_dim] matrix with orthogonal rows (within blocks of size head_dim). Multiple orthogonal blocks are concatenated if num_features > head_dim.
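The block-wise QR scheme can be sketched in NumPy as follows (an illustrative reimplementation, not the library's Elixir code; rescaling rows to Gaussian-like norms is one common FAVOR+ choice and is an assumption here):

```python
import numpy as np

rng = np.random.default_rng(3)

def orthogonal_features(head_dim, num_features):
    # Stack square orthogonal blocks until num_features rows are collected
    rows = []
    remaining = num_features
    while remaining > 0:
        G = rng.standard_normal((head_dim, head_dim))
        Q, _ = np.linalg.qr(G)                 # square orthogonal matrix
        rows.append(Q[:min(remaining, head_dim)])
        remaining -= head_dim
    W = np.vstack(rows)                        # [num_features, head_dim], unit rows
    # Rescale rows so their norms match those of iid Gaussian vectors
    norms = np.linalg.norm(rng.standard_normal((num_features, head_dim)), axis=1)
    return W * norms[:, None]

W = orthogonal_features(8, 20)                 # three blocks: 8 + 8 + 4 rows
```

Rows are exactly orthogonal within each block of `head_dim`; rows from different blocks are only orthogonal in expectation.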

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a Performer model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for a Performer model.