Edifice.Attention.Perceiver (Edifice v0.2.0)


Perceiver IO: General-purpose architecture with learned latent array.

Perceiver IO uses cross-attention to map arbitrary inputs to a fixed-size latent array, processes latents with self-attention, then optionally cross-attends back for structured output. This decouples compute from input size.

Key Innovation: Latent Bottleneck

Instead of self-attending over the full input (O(N^2)), Perceiver cross-attends inputs to a small learned latent array (M << N), then self-attends over latents (O(M^2)). Total: O(N*M + M^2).

Input [batch, N, input_dim]     Latents [1, M, latent_dim] (learned)
      |                               |
      +-- Cross-Attention(L, Input) --+
                  |
            Latents' [batch, M, latent_dim]
                  |
            Self-Attention x num_layers
                  |
            Latents'' [batch, M, latent_dim]
                  |
            Pool -> [batch, latent_dim]

Architecture

Input [batch, seq_len, input_dim]
      |
      v
+--------------------------------------+
|  Cross-Attention                     |
|  Q = Latent Array (learned, M x D)   |
|  K, V = Input                        |
|  -> Latents absorb input info        |
+--------------------------------------+
      |
      v (repeat num_cross_layers)
+--------------------------------------+
|  Self-Attention Block                |
|  LayerNorm -> Self-Attn -> Residual  |
|  LayerNorm -> FFN -> Residual        |
+--------------------------------------+
      | (repeat num_layers)
      v
Mean pool over latents -> [batch, latent_dim]

Complexity

Component     Standard Transformer   Perceiver
-----------   --------------------   --------------
Self-Attn     O(N^2)                 O(M^2)
Cross-Attn    -                      O(N*M)
Total         O(N^2)                 O(N*M + M^2)

Where M = num_latents << N = input length.
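
To make the table concrete, here is a quick back-of-the-envelope comparison (the input length N = 1024 is an illustrative value, not a default of this module; M = 64 matches the :num_latents default):

```elixir
# Attention cost measured in pairwise score computations.
n = 1024
m = 64

standard  = n * n          # full self-attention: 1_048_576
perceiver = n * m + m * m  # cross + latent self-attention: 69_632

IO.puts("speedup: ~#{Float.round(standard / perceiver, 1)}x")
```

At these sizes the latent bottleneck is roughly a 15x reduction, and the gap widens as N grows because only the O(N*M) cross-attention term scales with input length.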

Usage

model = Perceiver.build(
  input_dim: 287,
  latent_dim: 256,
  num_latents: 64,
  num_layers: 4,
  num_cross_layers: 1,
  num_heads: 4
)
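
Once built, the result is an ordinary Axon graph, so it can be initialized and run with Axon's standard build/predict workflow. A minimal sketch; the batch size, sequence length, and the single-input calling convention shown here are assumptions for illustration, not part of this module's API:

```elixir
model = Perceiver.build(input_dim: 287, latent_dim: 256, num_latents: 64)

# Compile the graph into init and predict functions.
{init_fn, predict_fn} = Axon.build(model)

# Initialize parameters from an input template: [batch, seq_len, input_dim].
params = init_fn.(Nx.template({1, 128, 287}, :f32), %{})

# Forward pass; output shape should be [batch, latent_dim] = {1, 256}.
input = Nx.broadcast(0.0, {1, 128, 287})
output = predict_fn.(params, input)
```

Note that because the latent array is learned (not derived from the input), sequence length can vary between calls without changing the output shape.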

References

  • Paper: "Perceiver IO: A General Architecture for Structured Inputs & Outputs" (Jaegle et al., DeepMind 2021)
  • Original: "Perceiver: General Perception with Iterative Attention" (2021)

Summary

Types

build_opt() - Options for build/1.

Functions

build(opts \\ []) - Build a Perceiver IO model for sequence processing.

build_cross_attention_block(latents, input_kv, opts) - Build a cross-attention block where latents attend to input.

build_self_attention_block(input, opts) - Build a self-attention block over latents.

output_size(opts \\ []) - Get the output size of a Perceiver model.

param_count(opts) - Calculate approximate parameter count for a Perceiver model.

Recommended default configuration for sequence processing.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:input_dim, pos_integer()}
  | {:latent_dim, pos_integer()}
  | {:num_cross_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_latents, pos_integer()}
  | {:num_layers, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Perceiver IO model for sequence processing.

Options

  • :input_dim - Size of input embedding per timestep (required)
  • :latent_dim - Latent array dimension (default: 256)
  • :num_latents - Number of latent vectors M (default: 64)
  • :num_layers - Number of self-attention layers over latents (default: 4)
  • :num_cross_layers - Number of input cross-attention passes (default: 1)
  • :num_heads - Number of attention heads (default: 4)
  • :dropout - Dropout rate (default: 0.1)

Returns

An Axon model that outputs [batch, latent_dim].

build_cross_attention_block(latents, input_kv, opts)

@spec build_cross_attention_block(Axon.t(), Axon.t(), keyword()) :: Axon.t()

Build a cross-attention block where latents attend to input.

Structure: LayerNorm(latents) -> CrossAttn(Q=latents, KV=input) -> Residual -> LayerNorm -> FFN -> Residual

build_self_attention_block(input, opts)

@spec build_self_attention_block(Axon.t(), keyword()) :: Axon.t()

Build a self-attention block over latents.

Structure: LayerNorm -> Self-Attention -> Residual -> LayerNorm -> FFN -> Residual
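
build/1 wires these two block builders together internally; they can also be composed by hand. A sketch only: the option names (:latent_dim, :num_heads) are assumed to mirror build/1 and are not confirmed by this page, and the latent array is modeled as a plain Axon input here for illustration, whereas build/1 uses a learned latent array:

```elixir
input   = Axon.input("input", shape: {nil, nil, 287})
latents = Axon.input("latents", shape: {nil, 64, 256})

# Latents read from the input once, then refine themselves.
latents
|> Perceiver.build_cross_attention_block(input, latent_dim: 256, num_heads: 4)
|> Perceiver.build_self_attention_block(latent_dim: 256, num_heads: 4)
```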

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a Perceiver model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for a Perceiver model.
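
Both helpers are useful when sizing downstream layers or sanity-checking model capacity before training. A sketch, assuming the option names mirror build/1 (not confirmed by this page); param_count/1 is approximate by design, so no exact return value is shown:

```elixir
# Feature size for a head consuming the pooled latents;
# expected to equal :latent_dim, since the model outputs [batch, latent_dim].
head_size = Perceiver.output_size(latent_dim: 256)

# Rough parameter budget for a candidate configuration.
budget =
  Perceiver.param_count(
    input_dim: 287,
    latent_dim: 256,
    num_latents: 64,
    num_layers: 4
  )
```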