Edifice.Attention.YARN (Edifice v0.2.0)


YaRN: Yet another RoPE extensioN for context window extension.

YaRN modifies RoPE frequency bands to handle longer sequences than the model was originally trained on. It achieves this by scaling different frequency components based on their wavelength relative to the original context length.

Key Insight

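RoPE encodes position via rotations at different frequencies. For context extension, high-frequency (local position) information should be preserved, while low-frequency (global position) information needs scaling. A low-frequency band whose wavelength exceeds the original context length never completes a full rotation during training, so positions beyond that length would produce rotation angles the model has never seen. Dividing that band's frequency by the scale factor keeps those angles within the trained range: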

High-frequency bands (short wavelength):
  - Capture local positional relationships
  - Left unchanged (factor = 1)

Low-frequency bands (long wavelength):
  - Capture global position in sequence
  - Frequency divided by scale (factor = 1/scale)

Middle bands:
  - Linear interpolation between the two regimes
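To make the three regimes concrete, here is a minimal plain-Elixir sketch (lists and :math only, no Nx) that labels each RoPE dimension pair by band. It assumes the standard RoPE frequency base^(-2i/embed_dim), a wavelength of 2π divided by that frequency, and the threshold definitions from the Formula section. The module and function names (YarnBands, classify/2) are hypothetical and not part of Edifice.

```elixir
defmodule YarnBands do
  @moduledoc "Hypothetical helper: label each RoPE dimension pair by YaRN band."

  # Defaults mirror the documented yarn_freqs/2 defaults.
  @defaults [original_max_position: 2048, low_freq_factor: 1, high_freq_factor: 4, base: 10_000.0]

  def classify(embed_dim, opts \\ []) do
    opts = Keyword.merge(@defaults, opts)
    orig = opts[:original_max_position]

    # Wavelengths (in tokens) below this are treated as high-frequency.
    high_freq_wavelen = orig / opts[:high_freq_factor]
    # Wavelengths (in tokens) above this are treated as low-frequency.
    low_freq_wavelen = orig / opts[:low_freq_factor]

    for i <- 0..(div(embed_dim, 2) - 1) do
      # Assumed standard RoPE frequency for pair i, and its wavelength.
      theta = :math.pow(opts[:base], -2 * i / embed_dim)
      wavelen = 2 * :math.pi() / theta

      cond do
        wavelen < high_freq_wavelen -> :high
        wavelen > low_freq_wavelen -> :low
        true -> :mid
      end
    end
  end
end
```

With embed_dim 64 and the defaults, `YarnBands.classify(64) |> Enum.frequencies()` returns `%{high: 16, mid: 5, low: 11}`: pairs 0 to 15 keep their frequency, pairs 21 to 31 are fully scaled, and only pairs 16 to 20 are interpolated.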

Formula

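For each dimension pair i, with RoPE frequency θ_i = base^(-2i/embed_dim) and wavelength λ_i = 2π/θ_i (in tokens), compute a scaling factor that multiplies θ_i: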

  • If wavelength < high_freq_wavelen: factor = 1 (unchanged)
  • If wavelength > low_freq_wavelen: factor = 1/scale (full scaling)
  • Otherwise: linear interpolation between 1 and 1/scale

The thresholds are derived from:

  • high_freq_wavelen = original_max_position / high_freq_factor
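  • low_freq_wavelen = original_max_position / low_freq_factor

With the defaults (original_max_position: 2048, high_freq_factor: 4, low_freq_factor: 1), these thresholds are 512 and 2048 tokens.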
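The rules above can be sketched in plain Elixir (lists and :math only, no Nx). This is a sketch under stated assumptions, not Edifice's implementation: the module name YarnFreqsSketch is hypothetical, frequencies use the standard RoPE form base^(-2i/embed_dim), and the middle-band weight `smooth = (original_max_position / wavelength - low_freq_factor) / (high_freq_factor - low_freq_factor)` is borrowed from the Llama 3.1 RoPE scaling scheme, which uses the same low/high frequency factors. That weight is 1 at high_freq_wavelen and 0 at low_freq_wavelen, so the factor moves from 1 to 1/scale across the middle band.

```elixir
defmodule YarnFreqsSketch do
  @moduledoc """
  Hypothetical pure-Elixir sketch of a YaRN-style frequency table
  (not Edifice's implementation). The middle-band weight is an assumption
  borrowed from the Llama 3.1 RoPE scaling scheme.
  """

  # Same defaults as documented for yarn_freqs/2.
  @defaults [
    scale: 8,
    original_max_position: 2048,
    low_freq_factor: 1,
    high_freq_factor: 4,
    base: 10_000.0
  ]

  def yarn_freqs(embed_dim, opts \\ []) do
    opts = Keyword.merge(@defaults, opts)
    scale = opts[:scale]
    orig = opts[:original_max_position]
    low = opts[:low_freq_factor]
    high = opts[:high_freq_factor]

    # Thresholds from the Formula section, in tokens.
    high_freq_wavelen = orig / high
    low_freq_wavelen = orig / low

    for i <- 0..(div(embed_dim, 2) - 1) do
      # Unscaled RoPE frequency and wavelength for dimension pair i.
      freq = :math.pow(opts[:base], -2 * i / embed_dim)
      wavelen = 2 * :math.pi() / freq

      factor =
        cond do
          wavelen < high_freq_wavelen ->
            1.0

          wavelen > low_freq_wavelen ->
            1.0 / scale

          true ->
            # Assumed weight: 1.0 at high_freq_wavelen, 0.0 at low_freq_wavelen,
            # linear in orig / wavelen in between.
            smooth = (orig / wavelen - low) / (high - low)
            (1 - smooth) / scale + smooth
        end

      freq * factor
    end
  end
end
```

With scale: 1 every factor is 1 and the table reduces to plain RoPE, which is a convenient sanity check. Note that the sketch divides by high_freq_factor - low_freq_factor, so the two factors must differ.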

Usage

alias Edifice.Attention.YARN

# Build a YaRN-modified RoPE layer
model = YARN.build(
  embed_dim: 64,
  scale: 8,
  original_max_position: 2048
)

# Get the frequency table directly
freqs = YARN.yarn_freqs(64,
  scale: 8,
  original_max_position: 2048
)

# Apply YaRN to query/key tensors of shape [batch, seq_len, embed_dim]
{q_rotated, k_rotated} = YARN.apply_yarn(q, k,
  scale: 8,
  original_max_position: 2048
)


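Summary

Types

  • yarn_opt() - Options for YaRN functions.

Functions

  • apply_yarn(query, key, opts \\ []) - Apply YaRN-modified RoPE to query and key tensors.
  • build(opts \\ []) - Build an Axon model that applies YaRN-modified RoPE to input.
  • effective_context_length(original_max_position, scale) - Calculate the effective context length after YaRN scaling.
  • Get recommended defaults for YaRN.
  • yarn_freqs(embed_dim, opts \\ []) - Compute YaRN-scaled frequency table.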

Types

yarn_opt()

@type yarn_opt() ::
  {:embed_dim, pos_integer()}
  | {:scale, number()}
  | {:original_max_position, pos_integer()}
  | {:low_freq_factor, number()}
  | {:high_freq_factor, number()}
  | {:base, number()}
  | {:name, String.t()}

Options for YaRN functions.

Functions

apply_yarn(query, key, opts \\ [])

@spec apply_yarn(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) ::
  {Nx.Tensor.t(), Nx.Tensor.t()}

Apply YaRN-modified RoPE to query and key tensors.

Parameters

  • query - Query tensor [batch, seq_len, embed_dim]
  • key - Key tensor [batch, seq_len, embed_dim]
  • opts - Options (see yarn_freqs/2 for available options)

Returns

{rotated_query, rotated_key} with same shapes as input.

Example

{q_rot, k_rot} = YARN.apply_yarn(query, key, scale: 8)
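To show what "applying" a frequency table means, here is a minimal sketch that rotates one embedding vector at a single position, using plain Elixir lists. It is not Edifice's apply_yarn/3: the module name RopeRotateSketch is hypothetical, and it assumes interleaved pairing (x[2i], x[2i+1]); an implementation may instead pair the first half of the features with the second half. A table from YARN.yarn_freqs/2 could be turned into a list with Nx.to_flat_list/1 before being passed in.

```elixir
defmodule RopeRotateSketch do
  @moduledoc "Hypothetical: rotate one embedding vector at a given position."

  # x: list of embed_dim floats (embed_dim must be even)
  # freqs: list of embed_dim / 2 frequencies, e.g. a YaRN-scaled table
  # pos: integer token position
  # Assumes interleaved pairing (x[2i], x[2i+1]).
  def rotate(x, freqs, pos) do
    x
    |> Enum.chunk_every(2)
    |> Enum.zip(freqs)
    |> Enum.flat_map(fn {[a, b], freq} ->
      # Each pair is rotated by angle pos * freq in its own 2-D plane.
      angle = pos * freq
      c = :math.cos(angle)
      s = :math.sin(angle)
      [a * c - b * s, a * s + b * c]
    end)
  end
end
```

Because each pair is rotated by a plain 2-D rotation, the dot product of a rotated query at position m and a rotated key at position n depends only on m - n. YaRN changes only the frequencies fed into this rotation, not the rotation itself.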

build(opts \\ [])

@spec build([yarn_opt()]) :: Axon.t()

Build an Axon model that applies YaRN-modified RoPE to input.

Options

  • :embed_dim - Feature dimension (required, must be even)
  • :scale - Context extension scale factor (default: 8). For example, scale: 8 extends a 2048-token context to 16384 tokens
  • :original_max_position - Original trained context length (default: 2048)
  • :low_freq_factor - Low frequency boundary factor (default: 1)
  • :high_freq_factor - High frequency boundary factor (default: 4)
  • :base - RoPE base frequency (default: 10000.0)
  • :name - Layer name prefix (default: "yarn")

Returns

An Axon model that applies YaRN-modified RoPE to the input tensor.

effective_context_length(original_max_position, scale)

@spec effective_context_length(pos_integer(), number()) :: number()

Calculate the effective context length after YaRN scaling, i.e. original_max_position * scale.

Example

YARN.effective_context_length(2048, 8)
# => 16384

yarn_freqs(embed_dim, opts \\ [])

@spec yarn_freqs(
  pos_integer(),
  keyword()
) :: Nx.Tensor.t()

Compute YaRN-scaled frequency table.

Returns a tensor of shape [embed_dim / 2] containing the scaled frequencies for each dimension pair.

Options

  • :scale - Context extension scale factor (default: 8)
  • :original_max_position - Original trained context length (default: 2048)
  • :low_freq_factor - Low frequency boundary factor (default: 1)
  • :high_freq_factor - High frequency boundary factor (default: 4)
  • :base - RoPE base frequency (default: 10000.0)

Example

freqs = YARN.yarn_freqs(64, scale: 8, original_max_position: 2048)
# => Tensor of shape {32} with scaled frequencies