Rotary Position Embedding (RoPE).
Encodes position information by rotating query and key vectors in pairs of dimensions. This provides relative position awareness without explicit position embeddings, and naturally extrapolates to longer sequences.
How It Works
RoPE rotates each pair of dimensions (2i, 2i+1) by angle theta_i * position:
[ cos(m*theta_i)  -sin(m*theta_i) ] [ q_{2i}   ]
[ sin(m*theta_i)   cos(m*theta_i) ] [ q_{2i+1} ]

where m is the position and theta_i = base^(-2i/d).
The inner product between rotated Q and K at positions m and n depends only on (m - n), giving relative position sensitivity.
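This relative-position property can be checked numerically. The following is a plain-Python sketch of the math for a single (2i, 2i+1) pair (not the module's Nx implementation); the vectors and theta value are arbitrary illustrations:

```python
import math

def rotate_pair(x, y, m, theta):
    """Rotate the 2-D pair (x, y) by angle m * theta."""
    c, s = math.cos(m * theta), math.sin(m * theta)
    return (c * x - s * y, s * x + c * y)

theta = 0.1
q = (0.3, -0.7)  # one (q_{2i}, q_{2i+1}) pair
k = (0.5, 0.2)   # one (k_{2i}, k_{2i+1}) pair

def rotated_dot(m, n):
    qx, qy = rotate_pair(*q, m, theta)
    kx, ky = rotate_pair(*k, n, theta)
    return qx * kx + qy * ky

# Positions (3, 1) and (10, 8) share the same offset m - n = 2,
# so the rotated inner products agree (up to float error).
a = rotated_dot(3, 1)
b = rotated_dot(10, 8)
print(abs(a - b) < 1e-9)  # True
```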
Usage
# Apply RoPE to query and key tensors
{q_rotated, k_rotated} = RoPE.apply_rotary(query, key, seq_len: 128)
# As an Axon layer
rotated = RoPE.layer(input, dim: 64, seq_len: 128)

References
- "RoFormer: Enhanced Transformer with Rotary Position Embedding" (Su et al., 2021)
- https://arxiv.org/abs/2104.09864
Summary
Functions
Apply rotary position embedding to Q and K tensors.
Build an Axon layer that applies RoPE to the input.
Build precomputed frequency table for RoPE.
Apply YaRN (Yet another RoPE extensioN) frequency scaling.
Functions
Apply rotary position embedding to Q and K tensors.
Parameters
query - Query tensor [batch, seq_len, dim]
key - Key tensor [batch, seq_len, dim]
Options
:seq_len - Sequence length (inferred from the tensor if not provided)
:base - RoPE base frequency (default: 10000.0)
Returns
{rotated_query, rotated_key} with same shapes as input.
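The per-position rotation that apply_rotary performs can be sketched in plain Python on nested lists (an illustration of the math only; the actual function operates on Nx tensors and batches):

```python
import math

def apply_rotary_sketch(x, base=10000.0):
    """Rotate each (2i, 2i+1) pair of row m by angle m * theta_i.

    x is a [seq_len, dim] list of lists with even dim.
    """
    dim = len(x[0])
    thetas = [base ** (-2 * i / dim) for i in range(dim // 2)]
    out = []
    for m, row in enumerate(x):
        new_row = []
        for i, theta in enumerate(thetas):
            a, b = row[2 * i], row[2 * i + 1]
            c, s = math.cos(m * theta), math.sin(m * theta)
            new_row += [c * a - s * b, s * a + c * b]
        out.append(new_row)
    return out

q = [[1.0, 0.0, 1.0, 0.0] for _ in range(4)]  # [seq_len=4, dim=4]
q_rot = apply_rotary_sketch(q)
print(q_rot[0])  # [1.0, 0.0, 1.0, 0.0] — position 0 is unrotated
```

Note that rotations preserve the norm of each pair, so RoPE changes only the directions of Q and K, never their magnitudes.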
Build an Axon layer that applies RoPE to the input.
Options
:dim - Feature dimension (required, must be even)
:name - Layer name prefix (default: "rope")
@spec precompute_freqs(pos_integer(), pos_integer(), keyword()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Build precomputed frequency table for RoPE.
Returns cosine and sine tables of shape [max_seq_len, dim/2].
Options
:base - RoPE base frequency (default: 10000.0)
:scaling_type - :none (default) or :yarn for YaRN context extension
:target_length - Target context length for YaRN scaling (required when :yarn)
:original_length - Original trained context length (default: 4096)
:beta_fast - YaRN high-frequency boundary (default: 32.0)
:beta_slow - YaRN low-frequency boundary (default: 1.0)
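The table construction reduces to evaluating cos(m * theta_i) and sin(m * theta_i) over all positions and frequency bands. A plain-Python sketch (illustrative only; the real function returns Nx tensors):

```python
import math

def precompute_freqs_sketch(max_seq_len, dim, base=10000.0):
    """Build cos and sin tables of shape [max_seq_len, dim // 2]."""
    thetas = [base ** (-2 * i / dim) for i in range(dim // 2)]
    cos = [[math.cos(m * t) for t in thetas] for m in range(max_seq_len)]
    sin = [[math.sin(m * t) for t in thetas] for m in range(max_seq_len)]
    return cos, sin

cos, sin = precompute_freqs_sketch(128, 64)
print(len(cos), len(cos[0]))  # 128 32
```

Precomputing the tables once amortizes the trig calls across every attention layer and forward pass.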
@spec yarn_scale_freqs(Nx.Tensor.t(), pos_integer(), keyword()) :: Nx.Tensor.t()
Apply YaRN (Yet another RoPE extensioN) frequency scaling.
YaRN scales RoPE frequency bands differently based on wavelength:
- High-frequency bands (local position info): left unchanged
- Low-frequency bands (global position info): scaled down by scale = target / original
- Middle bands: linear interpolation between the two
The boundaries between regions are determined by beta_fast and beta_slow.
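The banded scaling above can be sketched as follows, in plain Python rather than Nx; the ramp formulation (counting full rotations over the original context) is an assumption based on the YaRN paper, not necessarily this module's exact implementation:

```python
import math

def yarn_scale_freqs_sketch(thetas, target_length, original_length=4096,
                            beta_fast=32.0, beta_slow=1.0):
    """Blend each theta between unchanged and theta / scale, based on how
    many full rotations that band completes over the original context."""
    scale = target_length / original_length
    out = []
    for theta in thetas:
        rotations = original_length * theta / (2 * math.pi)
        # gamma = 1 above beta_fast (keep), 0 below beta_slow (fully scale)
        gamma = (rotations - beta_slow) / (beta_fast - beta_slow)
        gamma = min(1.0, max(0.0, gamma))
        out.append(gamma * theta + (1.0 - gamma) * theta / scale)
    return out

thetas = [10000.0 ** (-2 * i / 64) for i in range(32)]
scaled = yarn_scale_freqs_sketch(thetas, target_length=16384)
print(scaled[0] == thetas[0])  # True: highest band rotates often, kept as-is
```

With target_length = 16384 and original_length = 4096, scale = 4, so the lowest-frequency bands end up at roughly a quarter of their original theta.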
Options
:target_length - Target context length (required)
:original_length - Original trained context length (default: 4096)
:beta_fast - High-frequency boundary (default: 32.0)
:beta_slow - Low-frequency boundary (default: 1.0)
References
- "YaRN: Efficient Context Window Extension of Large Language Models" (Peng et al., 2023) — https://arxiv.org/abs/2309.00071