Edifice.Attention.GLA (Edifice v0.2.0)


GLA: Gated Linear Attention with data-dependent gating.

GLA combines the efficiency of linear attention (O(L) complexity) with data-dependent gating for improved expressiveness. It is particularly effective on short sequences (<2K tokens), where it can outperform FlashAttention-2.

Key Innovation: Data-Dependent Gating

Unlike standard linear attention, which uses fixed feature maps, GLA computes gates from the input that control how much past information flows into each step:

S[t]      = G[t] * S[t-1] + K[t]^T @ V[t]    (gated state update)
output[t] = Q[t] @ S[t]                      (readout)

Here the gate G[t] is itself computed from the input at step t, so the decay of past context is data-dependent rather than fixed.
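
A minimal Nx sketch of one recurrence step for a single head, as a hedged illustration only (the anonymous step function, variable names, and shapes are assumptions, not Edifice's implementation). Assumed shapes: q_t and k_t are {1, dk}, v_t is {1, dv}, the gate g_t is {dk, 1} with values in (0, 1), and the state s is {dk, dv}.

step = fn s, q_t, k_t, v_t, g_t ->
  s = Nx.add(Nx.multiply(g_t, s), Nx.dot(Nx.transpose(k_t), v_t))  # S[t] = G[t] * S[t-1] + K[t]^T @ V[t]
  o_t = Nx.dot(q_t, s)                                             # output[t] = Q[t] @ S[t]
  {s, o_t}
end

Scanning this step over the sequence visits each position once, which is where the O(L) time and constant-size state in the Complexity table come from.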

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+--------------------------------------+
|  GLA Block                           |
|                                      |
|  Q, K, V projections                 |
|         |                            |
|  Data-dependent gating (G)           |
|         |                            |
|  Linear attention with gates         |
|         |                            |
|  Output projection                   |
+--------------------------------------+
      | (repeat for num_layers)
      v
[batch, hidden_size]

Complexity

Aspect      Standard Attention      GLA
Time        O(L^2)                  O(L)
Space       O(L^2)                  O(L)
Hardware    FlashAttention needed   Native tensor ops

Usage

alias Edifice.Attention.GLA

model = GLA.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 6,
  num_heads: 4
)
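
Building parameters and running a forward pass might look like the following (a hedged sketch: the batch size, sequence length, zero-filled input, and passing a bare tensor to the single-input model are assumptions):

{init_fn, predict_fn} = Axon.build(model)

input = Nx.broadcast(0.0, {8, 60, 287})                  # [batch, seq_len, embed_dim]
params = init_fn.(Nx.template({8, 60, 287}, :f32), %{})  # initialize parameters
output = predict_fn.(params, input)                      # [batch, hidden_size] => {8, 256}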

Summary

Types

build_opt()

Options for build/1.

Functions

build(opts \\ [])

Build a GLA model for sequence processing.

build_gated_linear_attention(input, opts)

Build the Gated Linear Attention layer.

build_gla_block(input, opts)

Build a single GLA block.

output_size(opts \\ [])

Get the output size of a GLA model.

param_count(opts)

Calculate approximate parameter count for a GLA model.

Recommended default configuration for sequence processing.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:head_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a GLA model for sequence processing.

Options

  • :embed_dim - Size of input embedding per timestep (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_layers - Number of GLA blocks (default: 6)
  • :num_heads - Number of attention heads (default: 4)
  • :head_dim - Dimension per head (default: 64)
  • :expand_factor - FFN expansion factor (default: 2)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Expected sequence length for JIT optimization (default: 60)
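
For reference, the defaults above spelled out as a keyword list (only :embed_dim is required; its value here is illustrative, taken from the usage example):

opts = [
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 6,
  num_heads: 4,
  head_dim: 64,
  expand_factor: 2,
  dropout: 0.1,
  window_size: 60
]

model = Edifice.Attention.GLA.build(opts)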

Returns

An Axon model that outputs [batch, hidden_size] from the last position.

build_gated_linear_attention(input, opts)

@spec build_gated_linear_attention(
  Axon.t(),
  keyword()
) :: Axon.t()

Build the Gated Linear Attention layer.

Key components:

  1. Q, K, V, G projections (G = gate)
  2. Linear attention with data-dependent gating
  3. Output projection
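
A hedged sketch of how those three components could be wired in Axon (the function name, layer names, sigmoid gate activation, and the gated_linear_attention_impl/5 callback are hypothetical, not the library's internals):

def sketch_gla_layer(input, hidden_size) do
  q = Axon.dense(input, hidden_size, name: "q_proj")
  k = Axon.dense(input, hidden_size, name: "k_proj")
  v = Axon.dense(input, hidden_size, name: "v_proj")
  g = Axon.dense(input, hidden_size, activation: :sigmoid, name: "g_proj")  # data-dependent gate

  # custom numerical layer applying the gated linear-attention recurrence
  attn = Axon.layer(&gated_linear_attention_impl/5, [q, k, v, g], name: "gla_attn")

  Axon.dense(attn, hidden_size, name: "out_proj")
end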

build_gla_block(input, opts)

@spec build_gla_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single GLA block.

Each block has:

  1. Gated Linear Attention layer
  2. Gated FFN (similar to GLU)
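
A hedged sketch of the GLU-style gated FFN in step 2 (the function name, layer names, sigmoid gate, and expand factor handling are assumptions):

def sketch_gated_ffn(input, hidden_size, expand_factor) do
  inner = hidden_size * expand_factor
  up = Axon.dense(input, inner, name: "ffn_up")
  gate = Axon.dense(input, inner, activation: :sigmoid, name: "ffn_gate")

  gated = Axon.multiply(up, gate)                   # element-wise gating, as in GLU
  Axon.dense(gated, hidden_size, name: "ffn_down")  # project back to hidden_size
end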

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a GLA model.
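
An illustrative call, assuming the result mirrors the [batch, hidden_size] output described for build/1 (the returned value shown is that assumption, not verified against the source):

Edifice.Attention.GLA.output_size(hidden_size: 256)
#=> 256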

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for a GLA model.
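
An illustrative call (option values reuse the usage example above; the exact count depends on internals not shown on this page):

Edifice.Attention.GLA.param_count(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 6,
  num_heads: 4
)
#=> approximate total parameter count (non-negative integer)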