YaRN: Yet another RoPE extensioN for context window extension.
YaRN modifies RoPE frequency bands to handle longer sequences than the model was originally trained on. It achieves this by scaling different frequency components based on their wavelength relative to the original context length.
Key Insight
RoPE encodes position via rotations at different frequencies. For context extension, high-frequency (local position) info should be preserved while low-frequency (global position) info needs scaling:
High-frequency bands (short wavelength):
- Capture local positional relationships
- Left unchanged (factor = 1)
Low-frequency bands (long wavelength):
- Capture global position in sequence
- Scaled by a factor of 1/scale
Middle bands:
- Linear interpolation between the two regimes
Formula
For each dimension i, compute a scaling factor:
- If wavelength < high_freq_wavelen: factor = 1 (unchanged)
- If wavelength > low_freq_wavelen: factor = 1/scale (full scaling)
- Otherwise: linear interpolation between 1 and 1/scale
The thresholds are derived from:
high_freq_wavelen = original_max_position / high_freq_factor
low_freq_wavelen = original_max_position / low_freq_factor
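The three cases above can be sketched in plain Elixir. This is a minimal illustration, not the module's actual implementation: the module name YarnFactorSketch is hypothetical, and the exact shape of the middle-band ramp (linear in original_max_position / wavelength) is an assumption, since the text only says the interpolation is linear.

```elixir
defmodule YarnFactorSketch do
  # Hypothetical helper for illustration; not part of the documented YARN module.

  # Returns the multiplier applied to a RoPE frequency whose wavelength
  # (measured in positions) is `wavelen`.
  def scaling_factor(wavelen, opts \\ []) do
    scale = Keyword.get(opts, :scale, 8)
    original = Keyword.get(opts, :original_max_position, 2048)
    low_freq_factor = Keyword.get(opts, :low_freq_factor, 1)
    high_freq_factor = Keyword.get(opts, :high_freq_factor, 4)

    high_freq_wavelen = original / high_freq_factor
    low_freq_wavelen = original / low_freq_factor

    cond do
      # Short wavelength: local position information, left unchanged.
      wavelen < high_freq_wavelen ->
        1.0

      # Long wavelength: global position information, fully scaled.
      wavelen > low_freq_wavelen ->
        1.0 / scale

      # Middle band. ASSUMPTION: the ramp is linear in original / wavelen,
      # reaching 0.0 at low_freq_wavelen and 1.0 at high_freq_wavelen.
      true ->
        smooth = (original / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        (1.0 - smooth) / scale + smooth
    end
  end
end

IO.inspect(YarnFactorSketch.scaling_factor(100.0))
IO.inspect(YarnFactorSketch.scaling_factor(4096.0))
```

With the defaults, the thresholds are 512 and 2048 positions, so the first call falls in the unchanged regime and the second in the fully scaled one.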
Usage
# Build a YaRN-modified RoPE layer
model = YARN.build(
  embed_dim: 64,
  scale: 8,
  original_max_position: 2048
)
# Get the frequency table directly
freqs = YARN.yarn_freqs(64,
  scale: 8,
  original_max_position: 2048
)
# Apply YaRN to query/key tensors
{q_rotated, k_rotated} = YARN.apply_yarn(q, k,
  scale: 8,
  original_max_position: 2048
)
References
- "YaRN: Efficient Context Window Extension of Large Language Models" (Peng et al., 2023) — https://arxiv.org/abs/2309.00071
Summary
Types
yarn_opt() - Options for YaRN functions.
Functions
apply_yarn/3 - Apply YaRN-modified RoPE to query and key tensors.
build/1 - Build an Axon model that applies YaRN-modified RoPE to input.
effective_context_length/2 - Calculate the effective context length after YaRN scaling.
recommended_defaults/0 - Get recommended defaults for YaRN.
yarn_freqs/2 - Compute YaRN-scaled frequency table.
Types
@type yarn_opt() ::
  {:embed_dim, pos_integer()}
  | {:scale, number()}
  | {:original_max_position, pos_integer()}
  | {:low_freq_factor, number()}
  | {:high_freq_factor, number()}
  | {:base, number()}
  | {:name, String.t()}
Options for YaRN functions.
Functions
@spec apply_yarn(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Apply YaRN-modified RoPE to query and key tensors.
Parameters
- query - Query tensor [batch, seq_len, embed_dim]
- key - Key tensor [batch, seq_len, embed_dim]
- opts - Options (see yarn_freqs/2 for available options)
Returns
{rotated_query, rotated_key} with the same shapes as the inputs.
Example
{q_rot, k_rot} = YARN.apply_yarn(query, key, scale: 8)
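To illustrate what applying a rotary embedding means, the sketch below rotates a single vector in plain Elixir lists. It makes several assumptions: the module name RopeRotateSketch is hypothetical, adjacent elements are paired (the library may pair dimensions differently), and the frequencies are arbitrary example values rather than the output of yarn_freqs/2. The documented function works on Nx tensors of shape [batch, seq_len, embed_dim] instead.

```elixir
defmodule RopeRotateSketch do
  # Hypothetical helper for illustration; not part of the documented YARN module.

  # Rotates adjacent pairs (x0, x1), (x2, x3), ... of `vec` by the angle
  # pos * freq, where `freqs` holds one frequency per pair.
  def rotate(vec, pos, freqs) do
    vec
    |> Enum.chunk_every(2)
    |> Enum.zip(freqs)
    |> Enum.flat_map(fn {[a, b], freq} ->
      angle = pos * freq
      c = :math.cos(angle)
      s = :math.sin(angle)
      [a * c - b * s, a * s + b * c]
    end)
  end

  # Plain dot product of two equal-length lists.
  def dot(u, v) do
    u |> Enum.zip(v) |> Enum.map(fn {a, b} -> a * b end) |> Enum.sum()
  end
end

# Example frequencies (arbitrary values, one per dimension pair).
freqs = [1.0, 0.125]
q = [1.0, 0.0, 0.5, -0.5]
k = [0.0, 1.0, 2.0, 1.0]

# Both query/key position pairs are 3 apart, so the scores should agree.
score_a =
  RopeRotateSketch.dot(RopeRotateSketch.rotate(q, 5, freqs), RopeRotateSketch.rotate(k, 2, freqs))

score_b =
  RopeRotateSketch.dot(RopeRotateSketch.rotate(q, 13, freqs), RopeRotateSketch.rotate(k, 10, freqs))

IO.inspect(abs(score_a - score_b) < 1.0e-9)
# => true
```

The final check shows the property that makes rotary embeddings useful for attention: the query-key score depends only on the relative offset between the two positions, not on their absolute values.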
build(opts)
Build an Axon model that applies YaRN-modified RoPE to input.
Options
- :embed_dim - Feature dimension (required, must be even)
- :scale - Context extension scale factor (default: 8). For example, scale=8 extends a 2048 context length to 16384
- :original_max_position - Original trained context length (default: 2048)
- :low_freq_factor - Low frequency boundary factor (default: 1)
- :high_freq_factor - High frequency boundary factor (default: 4)
- :base - RoPE base frequency (default: 10000.0)
- :name - Layer name prefix (default: "yarn")
Returns
An Axon model that applies YaRN-modified RoPE to the input tensor.
@spec effective_context_length(pos_integer(), number()) :: number()
Calculate the effective context length after YaRN scaling.
Example
YARN.effective_context_length(2048, 8)
# => 16384
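The example is consistent with a simple product of the original context length and the scale factor. A minimal sketch under that assumption follows. The module name YarnContextSketch is hypothetical, and the real function may round or validate its inputs differently.

```elixir
defmodule YarnContextSketch do
  # Hypothetical helper for illustration; not part of the documented YARN module.

  # ASSUMPTION: the effective length is the original length times the scale.
  def effective_context_length(original_max_position, scale) do
    original_max_position * scale
  end
end

IO.inspect(YarnContextSketch.effective_context_length(2048, 8))
# => 16384
```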
@spec recommended_defaults() :: keyword()
Get recommended defaults for YaRN.
@spec yarn_freqs(pos_integer(), keyword()) :: Nx.Tensor.t()
Compute YaRN-scaled frequency table.
Returns a tensor of shape [embed_dim / 2] containing the scaled frequencies
for each dimension pair.
Options
- :scale - Context extension scale factor (default: 8)
- :original_max_position - Original trained context length (default: 2048)
- :low_freq_factor - Low frequency boundary factor (default: 1)
- :high_freq_factor - High frequency boundary factor (default: 4)
- :base - RoPE base frequency (default: 10000.0)
Example
freqs = YARN.yarn_freqs(64, scale: 8, original_max_position: 2048)
# => Tensor of shape {32} with scaled frequencies
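To see which of the 32 dimension pairs in this example land in each regime, the sketch below classifies them by wavelength using plain Elixir lists. It assumes the standard RoPE frequency base^(-2i / embed_dim) for pair i, so the wavelength is 2π divided by that frequency. The module name YarnBandsSketch and its functions are hypothetical and not part of the documented API.

```elixir
defmodule YarnBandsSketch do
  # Hypothetical helper for illustration; not part of the documented YARN module.

  # Wavelength (in positions) of dimension pair `i`, assuming the usual RoPE
  # frequency base^(-2i / embed_dim).
  def wavelength(i, embed_dim, base \\ 10_000.0) do
    2.0 * :math.pi() * :math.pow(base, 2 * i / embed_dim)
  end

  # Labels each of the embed_dim / 2 pairs as :high, :middle, or :low frequency.
  def classify(embed_dim, opts \\ []) do
    original = Keyword.get(opts, :original_max_position, 2048)
    low_freq_factor = Keyword.get(opts, :low_freq_factor, 1)
    high_freq_factor = Keyword.get(opts, :high_freq_factor, 4)
    base = Keyword.get(opts, :base, 10_000.0)

    high_freq_wavelen = original / high_freq_factor
    low_freq_wavelen = original / low_freq_factor

    for i <- 0..(div(embed_dim, 2) - 1) do
      wavelen = wavelength(i, embed_dim, base)

      cond do
        wavelen < high_freq_wavelen -> :high
        wavelen > low_freq_wavelen -> :low
        true -> :middle
      end
    end
  end
end

bands = YarnBandsSketch.classify(64, original_max_position: 2048)
counts = Enum.frequencies(bands)
IO.inspect({counts[:high], counts[:middle], counts[:low]})
# => {16, 5, 11}
```

Under these assumptions, half of the pairs keep their original frequency, eleven are divided by the scale factor, and five are interpolated between the two.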