Edifice.Attention.RingAttention (Edifice v0.2.0)


Ring Attention: chunked attention simulating ring-distributed computation (Liu et al., 2023).

Splits the sequence into chunks and processes attention in a rotating pattern: each query chunk attends to the key/value chunks in ring communication order. On a single device this is equivalent to memory-efficient chunked attention, but it is structured as a ring pattern for educational purposes and as a stepping stone to future distributed scaling.

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+---------------------------+
| Input Projection          |  Dense to hidden_size
+---------------------------+
      |
      v
+---------------------------+
| Ring Attention Block x N  |
|                           |
| 1. LayerNorm              |
| 2. QKV projection         |
| 3. Split into num_chunks  |
| 4. Ring attention:        |
|    For each Q chunk:      |
|      attend to all K,V    |
|      chunks in ring order |
| 5. Residual               |
| 6. LayerNorm + FFN        |
| 7. Residual               |
+---------------------------+
      |
      v
+---------------------------+
| Final LayerNorm           |
+---------------------------+
      |
      v
[batch, hidden_size]  (last timestep)
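
Step 4 is the core of each block. Below is a minimal single-head sketch in Nx of that step, assuming Q, K, and V have already been split into lists of [batch, chunk_len, head_dim] chunks. The RingSketch module and ring_attend helper are hypothetical illustrations of the idea, not Edifice's internal code:

defmodule RingSketch do
  def ring_attend(q_chunks, k_chunks, v_chunks, head_dim) do
    scale = 1.0 / :math.sqrt(head_dim)
    n = length(k_chunks)

    q_chunks
    |> Enum.with_index(fn q, i ->
      # Visit K/V chunks in ring order, starting from this chunk's own index.
      order = Enum.map(0..(n - 1), &rem(i + &1, n))

      # Scores of this Q chunk against every K chunk: [batch, chunk_len, n * chunk_len]
      scores =
        order
        |> Enum.map(fn j ->
          Nx.multiply(Nx.dot(q, [2], [0], Enum.at(k_chunks, j), [2], [0]), scale)
        end)
        |> Nx.concatenate(axis: 2)

      weights = Axon.Activations.softmax(scores, axis: 2)
      v = order |> Enum.map(&Enum.at(v_chunks, &1)) |> Nx.concatenate(axis: 1)

      # Weighted sum of values: [batch, chunk_len, head_dim]
      Nx.dot(weights, [2], [0], v, [1], [0])
    end)
    |> Nx.concatenate(axis: 1)
  end
end

On one device the ring order changes nothing numerically, since the softmax sees scores for the full sequence; a truly distributed version would instead keep a running online-softmax accumulator so that no device ever holds all K/V chunks at once.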

Key Insight

Ring attention enables processing sequences much longer than what fits in memory on a single device. The ring pattern naturally maps to distributed settings where each device holds one chunk and passes K,V to the next device in a ring topology.
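
For example, with num_chunks: 4, the visit order for each query chunk simply rotates around the ring (illustrative schedule only):

# Query chunk i visits K/V chunks in the order rem(i + step, 4).
for i <- 0..3, do: Enum.map(0..3, &rem(i + &1, 4))
#=> [[0, 1, 2, 3], [1, 2, 3, 0], [2, 3, 0, 1], [3, 0, 1, 2]]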

Usage

alias Edifice.Attention.RingAttention

model = RingAttention.build(
  embed_dim: 288,
  hidden_size: 256,
  num_heads: 4,
  num_chunks: 4,
  num_layers: 4
)
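
A quick smoke test using the standard Axon workflow (a sketch; shapes assume the options above and the default window_size of 60):

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 60, 288}, :f32), %{})

input = Nx.broadcast(0.0, {2, 60, 288})
output = predict_fn.(params, input)
# Nx.shape(output) => {2, 256}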

References

Liu, H., Zaharia, M., & Abbeel, P. (2023). Ring Attention with Blockwise Transformers for Near-Infinite Context. arXiv:2310.01889.

Summary

Types

Options for build/1.

Functions

Build a Ring Attention model.

Get the output size of a Ring Attention model.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_chunks, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Ring Attention model.

Options

  • :embed_dim - Input embedding dimension (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_heads - Number of attention heads (default: 4)
  • :num_chunks - Number of ring chunks to split sequence into (default: 4)
  • :num_layers - Number of ring attention layers (default: 4)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Input sequence length (default: 60)

Returns

An Axon model outputting [batch, hidden_size] from the last timestep.
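
Taking the last timestep can be pictured as the following reduction (a sketch, not the library's internal code):

# Reduce [batch, seq_len, hidden] to [batch, hidden] at t = seq_len - 1.
take_last = fn t ->
  seq_len = Nx.axis_size(t, 1)

  t
  |> Nx.slice_along_axis(seq_len - 1, 1, axis: 1)
  |> Nx.squeeze(axes: [1])
end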

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output size of a Ring Attention model.
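
Since the model output is [batch, hidden_size], this should mirror the :hidden_size option (assumed example):

RingAttention.output_size(hidden_size: 256)
#=> 256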