Ring Attention: chunked attention simulating ring-distributed computation (Liu et al., 2023).
Splits the sequence into chunks and processes attention in a rotating pattern, where each query chunk attends to key/value chunks in a ring communication order. On a single device, this is equivalent to memory-efficient chunked attention but structured as a ring pattern for educational purposes and future distributed scaling.
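As a toy illustration of that visitation order (a throwaway snippet, not library code):

# Query chunk i visits K,V chunks i, i+1, ... wrapping mod num_chunks,
# exactly the order produced by passing K,V blocks around a device ring.
ring_order = fn i, num_chunks -> Enum.map(0..(num_chunks - 1), &rem(i + &1, num_chunks)) end
ring_order.(1, 4)
#=> [1, 2, 3, 0]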
Architecture
Input [batch, seq_len, embed_dim]
              |
              v
+---------------------------+
|     Input Projection      |  Dense to hidden_size
+---------------------------+
              |
              v
+---------------------------+
| Ring Attention Block x N  |
|                           |
| 1. LayerNorm              |
| 2. QKV projection         |
| 3. Split into num_chunks  |
| 4. Ring attention:        |
|    For each Q chunk:      |
|      attend to all K,V    |
|      chunks in ring order |
| 5. Residual               |
| 6. LayerNorm + FFN        |
| 7. Residual               |
+---------------------------+
              |
              v
+---------------------------+
|      Final LayerNorm      |
+---------------------------+
              |
              v
[batch, hidden_size]
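Composed in Axon, one block might look roughly like the following (a hedged sketch: BlockSketch, block/2, and the identity ring_attn/1 stub are hypothetical stand-ins for this module's internals, the 4x FFN expansion is assumed, and the real chunked attention is sketched under Key Insight below):

defmodule BlockSketch do
  # Placeholder for steps 3-4; see the Nx ring-attention sketch below.
  defp ring_attn(qkv), do: qkv

  def block(x, hidden) do
    attn =
      x
      |> Axon.layer_norm()                          # 1. LayerNorm
      |> Axon.dense(3 * hidden)                     # 2. fused QKV projection
      |> Axon.nx(&ring_attn/1)                      # 3-4. split + ring attention
      |> Axon.dense(hidden)                         # project back to hidden_size

    x = Axon.add(x, attn)                           # 5. residual

    ffn =
      x
      |> Axon.layer_norm()                          # 6. LayerNorm + FFN
      |> Axon.dense(4 * hidden, activation: :gelu)  #    (4x expansion assumed)
      |> Axon.dense(hidden)

    Axon.add(x, ffn)                                # 7. residual
  end
end

Dropout (the :dropout option) is omitted here for brevity; it would typically follow the attention and FFN projections.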
Key Insight

Ring attention enables processing sequences far longer than a single device's memory would otherwise allow: the full seq_len x seq_len attention matrix is never materialized, since only one Q chunk and one K,V chunk are held at a time. The ring pattern also maps naturally to distributed settings, where each device holds one chunk and passes its K,V block to the next device in a ring topology.
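To make step 4 concrete, here is a minimal single-device sketch in Nx using an online (streaming) softmax, so no full score matrix ever exists. RingAttentionSketch and ring_attention/4 are hypothetical names, not this module's internals, and Liu et al.'s distributed algorithm additionally overlaps the K,V rotation with compute:

defmodule RingAttentionSketch do
  # q, k, v: [batch, heads, seq_len, head_dim]; num_chunks must divide seq_len.
  def ring_attention(q, k, v, num_chunks) do
    {batch, heads, seq_len, d} = Nx.shape(q)
    chunk = div(seq_len, num_chunks)
    slice = fn t, j -> Nx.slice_along_axis(t, j * chunk, chunk, axis: 2) end

    outputs =
      for i <- 0..(num_chunks - 1) do
        qi = slice.(q, i)

        # Online-softmax accumulators: running max m, denominator l, weighted sum acc.
        m0 = Nx.broadcast(Nx.Constants.neg_infinity(), {batch, heads, chunk, 1})
        l0 = Nx.broadcast(0.0, {batch, heads, chunk, 1})
        acc0 = Nx.broadcast(0.0, {batch, heads, chunk, d})

        {_m, l, acc} =
          Enum.reduce(0..(num_chunks - 1), {m0, l0, acc0}, fn step, {m, l, acc} ->
            # Ring order: start at our own block, then rotate through the rest.
            j = rem(i + step, num_chunks)
            {kj, vj} = {slice.(k, j), slice.(v, j)}

            # Scores for this chunk pair only: [batch, heads, chunk, chunk].
            s = qi |> Nx.dot([3], [0, 1], kj, [3], [0, 1]) |> Nx.divide(:math.sqrt(d))

            m_new = Nx.max(m, Nx.reduce_max(s, axes: [-1], keep_axes: true))
            corr = Nx.exp(Nx.subtract(m, m_new))
            p = Nx.exp(Nx.subtract(s, m_new))

            {m_new,
             Nx.add(Nx.multiply(l, corr), Nx.sum(p, axes: [-1], keep_axes: true)),
             Nx.add(Nx.multiply(acc, corr), Nx.dot(p, [3], [0, 1], vj, [2], [0, 1]))}
          end)

        Nx.divide(acc, l)
      end

    Nx.concatenate(outputs, axis: 2)
  end
end

In the distributed setting, the inner reduce becomes a send of the local K,V block to the next device (and a receive from the previous one) while the current pair is being computed, which is what hides communication behind compute.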
Usage

model = RingAttention.build(
  embed_dim: 288,
  hidden_size: 256,
  num_heads: 4,
  num_chunks: 4,
  num_layers: 4
)
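To sanity-check shapes end to end, the standard Axon build/predict flow should apply (a hedged sketch; the zero input and batch size of 8 are arbitrary):

{init_fn, predict_fn} = Axon.build(model)

template = Nx.template({1, 60, 288}, :f32)
params = init_fn.(template, %{})

out = predict_fn.(params, Nx.broadcast(0.0, {8, 60, 288}))
# out has shape {8, 256}: [batch, hidden_size] from the last timestep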
References

- Liu et al., "Ring Attention with Blockwise Transformers for Near-Infinite Context" (2023). https://arxiv.org/abs/2310.01889
Summary
Types
@type build_opt() ::
        {:dropout, float()}
        | {:embed_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_chunks, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:window_size, pos_integer()}
Options for build/1.
Functions
build(opts)

Build a Ring Attention model.
Options
:embed_dim - Input embedding dimension (required)
:hidden_size - Internal hidden dimension (default: 256)
:num_heads - Number of attention heads (default: 4)
:num_chunks - Number of ring chunks to split the sequence into (default: 4)
:num_layers - Number of ring attention layers (default: 4)
:dropout - Dropout rate (default: 0.1)
:window_size - Sequence length (default: 60)
Returns
An Axon model outputting [batch, hidden_size] from the last timestep.
@spec output_size(keyword()) :: pos_integer()
Get the output size of a Ring Attention model.
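Given that the model outputs [batch, hidden_size], output_size/1 presumably reports the configured hidden size; a plausible (unverified) call:

# Assuming output_size/1 returns the :hidden_size option (unverified).
RingAttention.output_size(hidden_size: 256)
#=> 256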