Edifice.Blocks.RoPE (Edifice v0.2.0)


Rotary Position Embedding (RoPE).

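Encodes position information by rotating query and key vectors in pairs of dimensions. This gives attention relative position awareness without adding a separate position embedding to the input. The rotation can be computed for any position, including positions beyond the training length, although quality typically degrades past the trained context; YaRN frequency scaling is one way to extend it.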

How It Works

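RoPE rotates each pair of dimensions $(2i, 2i+1)$ by the angle $m\theta_i$, where $m$ is the token position:

$$
\begin{pmatrix} q'_{2i} \\ q'_{2i+1} \end{pmatrix}
=
\begin{pmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{pmatrix}
\begin{pmatrix} q_{2i} \\ q_{2i+1} \end{pmatrix},
\qquad
\theta_i = \mathrm{base}^{-2i/d},
$$

where $d$ is the feature dimension and $i = 0, 1, \dots, d/2 - 1$. Key vectors are rotated the same way.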
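To make the formula concrete, here is a minimal pure-Elixir sketch (plain lists, no Nx) that computes $\theta_i$ and rotates one vector pair by pair. The module `RoPESketch` and its functions are hypothetical illustrations, not part of Edifice, which operates on Nx tensors.

```elixir
defmodule RoPESketch do
  # theta_i = base^(-2i/d) for i = 0 .. d/2 - 1 (d must be even and positive)
  def thetas(d, base \\ 10_000.0) when d > 0 and rem(d, 2) == 0 do
    for i <- 0..(div(d, 2) - 1), do: :math.pow(base, -2 * i / d)
  end

  # Rotate each pair (x_{2i}, x_{2i+1}) of the vector `x` by the angle m * theta_i
  def rotate(x, m, base \\ 10_000.0) do
    x
    |> Enum.chunk_every(2)
    |> Enum.zip(thetas(length(x), base))
    |> Enum.flat_map(fn {[a, b], theta} ->
      angle = m * theta
      c = :math.cos(angle)
      s = :math.sin(angle)
      # 2x2 rotation matrix applied to the pair (a, b)
      [a * c - b * s, a * s + b * c]
    end)
  end
end

# Rotate a 4-dimensional query vector as if it sat at position 3
RoPESketch.rotate([0.3, -1.2, 0.7, 2.0], 3)
```

Because each pair is rotated rather than shifted, the vector's norm is unchanged; only its direction encodes the position.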

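The inner product between the rotated query at position m and the rotated key at position n depends on the positions only through the offset (m - n), so attention scores are sensitive to relative rather than absolute position.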
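The relative-position property can be checked numerically. This sketch (hypothetical module `RoPERelative`, pure Elixir, assuming the pairwise rotation defined above) rotates a query and a key at two different position pairs that share the same offset and compares their dot products.

```elixir
defmodule RoPERelative do
  # Rotate pairs (x_{2i}, x_{2i+1}) by m * theta_i, with theta_i = base^(-2i/d)
  def rotate(x, m, base \\ 10_000.0) do
    d = length(x)

    x
    |> Enum.chunk_every(2)
    |> Enum.with_index()
    |> Enum.flat_map(fn {[a, b], i} ->
      angle = m * :math.pow(base, -2 * i / d)
      c = :math.cos(angle)
      s = :math.sin(angle)
      [a * c - b * s, a * s + b * c]
    end)
  end

  def dot(u, v), do: u |> Enum.zip_with(v, &(&1 * &2)) |> Enum.sum()
end

q = [0.3, -1.2, 0.7, 2.0]
k = [1.1, 0.4, -0.5, 0.9]

# (m, n) = (5, 2) and (105, 102) share the same offset m - n = 3
s1 = RoPERelative.dot(RoPERelative.rotate(q, 5), RoPERelative.rotate(k, 2))
s2 = RoPERelative.dot(RoPERelative.rotate(q, 105), RoPERelative.rotate(k, 102))

IO.puts(abs(s1 - s2) < 1.0e-9)
# → true
```

Algebraically, each pair contributes $q^\top R(m\theta_i)^\top R(n\theta_i)\,k = q^\top R((n - m)\theta_i)\,k$, so the absolute positions cancel and only the offset remains.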

Usage

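```elixir
alias Edifice.Blocks.RoPE

# Apply RoPE to query and key tensors of shape [batch, seq_len, dim]
query = Nx.iota({1, 128, 64}, type: :f32)
key = Nx.iota({1, 128, 64}, type: :f32)
{q_rotated, k_rotated} = RoPE.apply_rotary(query, key, seq_len: 128)

# As an Axon layer (:dim must be even)
input = Axon.input("input", shape: {nil, 128, 64})
rotated = RoPE.layer(input, dim: 64, seq_len: 128)
```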


Summary

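Functions

  • apply_rotary(query, key, arg3 \\ []) - Apply rotary position embedding to Q and K tensors.
  • layer(input, opts \\ []) - Build an Axon layer that applies RoPE to the input.
  • precompute_freqs(dim, max_seq_len, opts \\ []) - Build precomputed frequency table for RoPE.
  • yarn_scale_freqs(freqs, dim, opts) - Apply YaRN (Yet another RoPE extensioN) frequency scaling.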

Functions

apply_rotary(query, key, arg3 \\ [])

Apply rotary position embedding to Q and K tensors.

Parameters

  • query - Query tensor [batch, seq_len, dim]
  • key - Key tensor [batch, seq_len, dim]

Options

  • :seq_len - Sequence length (inferred from tensor if not provided)
  • :base - RoPE base frequency (default: 10000.0)

Returns

{rotated_query, rotated_key}, each with the same shape as the corresponding input.
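The shape contract above can be illustrated with a hypothetical pure-Elixir stand-in, `ApplyRotarySketch.apply_rotary/3`, that works on nested lists shaped `[batch][seq_len][dim]` instead of Nx tensors. It assumes each row is rotated by its index in the sequence (position 0 is left unchanged) and that sequence length is simply the list length; it is a sketch of the idea, not Edifice's implementation.

```elixir
defmodule ApplyRotarySketch do
  # Hypothetical: rotate query and key rows by their sequence index,
  # returning {rotated_query, rotated_key} with the input shapes.
  def apply_rotary(query, key, opts \\ []) do
    base = Keyword.get(opts, :base, 10_000.0)
    {rotate_batch(query, base), rotate_batch(key, base)}
  end

  # Each batch element is a list of rows; row p is rotated by p * theta_i
  defp rotate_batch(batch, base) do
    for seq <- batch do
      seq
      |> Enum.with_index()
      |> Enum.map(fn {row, pos} -> rotate(row, pos, base) end)
    end
  end

  defp rotate(x, m, base) do
    d = length(x)

    x
    |> Enum.chunk_every(2)
    |> Enum.with_index()
    |> Enum.flat_map(fn {[a, b], i} ->
      angle = m * :math.pow(base, -2 * i / d)
      c = :math.cos(angle)
      s = :math.sin(angle)
      [a * c - b * s, a * s + b * c]
    end)
  end
end

# batch = 1, seq_len = 3, dim = 4
query = [[[1.0, 0.0, 1.0, 0.0], [1.0, 0.0, 1.0, 0.0], [1.0, 0.0, 1.0, 0.0]]]
key = query
{rq, rk} = ApplyRotarySketch.apply_rotary(query, key)
```

Identical rows come out different at each position, which is how the position signal enters the attention scores.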

layer(input, opts \\ [])

@spec layer(
  Axon.t(),
  keyword()
) :: Axon.t()

Build an Axon layer that applies RoPE to the input.

Options

  • :dim - Feature dimension (required, must be even)
  • :name - Layer name prefix (default: "rope")

precompute_freqs(dim, max_seq_len, opts \\ [])

@spec precompute_freqs(pos_integer(), pos_integer(), keyword()) ::
  {Nx.Tensor.t(), Nx.Tensor.t()}

Build precomputed frequency table for RoPE.

Returns cosine and sine tables of shape [max_seq_len, dim/2].

Options

  • :base - RoPE base frequency (default: 10000.0)
  • :scaling_type - :none (default) or :yarn for YaRN context extension
  • :target_length - Target context length for YaRN scaling (required when :yarn)
  • :original_length - Original trained context length (default: 4096)
  • :beta_fast - YaRN high-frequency boundary (default: 32.0)
  • :beta_slow - YaRN low-frequency boundary (default: 1.0)
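A minimal sketch of what a `[max_seq_len, dim/2]` table holds, written with plain lists: entry `(m, i)` is the cosine or sine of `m * theta_i`. The module `FreqTableSketch` is hypothetical, covers only the `:base` option (no YaRN scaling), and returns `{cos_table, sin_table}` in that order as an assumption for this sketch; consult the Edifice source for the actual tuple order.

```elixir
defmodule FreqTableSketch do
  # Hypothetical: cos/sin tables of shape [max_seq_len][dim/2],
  # where entry (m, i) uses the angle m * theta_i, theta_i = base^(-2i/dim).
  def precompute_freqs(dim, max_seq_len, opts \\ [])
      when dim > 0 and rem(dim, 2) == 0 and max_seq_len > 0 do
    base = Keyword.get(opts, :base, 10_000.0)
    thetas = for i <- 0..(div(dim, 2) - 1), do: :math.pow(base, -2 * i / dim)
    angles = for m <- 0..(max_seq_len - 1), do: Enum.map(thetas, &(m * &1))

    cos_table = Enum.map(angles, fn row -> Enum.map(row, &:math.cos/1) end)
    sin_table = Enum.map(angles, fn row -> Enum.map(row, &:math.sin/1) end)
    {cos_table, sin_table}
  end
end

# dim = 8, max_seq_len = 16: two tables of 16 rows x 4 columns
{cos_table, sin_table} = FreqTableSketch.precompute_freqs(8, 16)
```

Precomputing the tables once means applying RoPE at runtime is just elementwise multiplies and adds per pair, with no trigonometry in the hot path.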

yarn_scale_freqs(freqs, dim, opts)

@spec yarn_scale_freqs(Nx.Tensor.t(), pos_integer(), keyword()) :: Nx.Tensor.t()

Apply YaRN (Yet another RoPE extensioN) frequency scaling.

YaRN scales RoPE frequency bands differently based on wavelength:

  • High-frequency bands (local position information): left unchanged
  • Low-frequency bands (global position information): divided by scale = target_length / original_length
  • Middle bands: linearly interpolated between the unchanged and the scaled frequency

The boundaries between regions are determined by beta_fast and beta_slow.

Options

  • :target_length - Target context length (required)
  • :original_length - Original trained context length (default: 4096)
  • :beta_fast - High-frequency boundary (default: 32.0)
  • :beta_slow - Low-frequency boundary (default: 1.0)
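A minimal sketch of the three-region rule, following the ramp formulation from the YaRN paper: for each frequency $\theta$, count the rotations $r = \text{original\_length} \cdot \theta / 2\pi$ it completes over the original context, keep $\theta$ when $r > \beta_\text{fast}$, use $\theta / \text{scale}$ when $r < \beta_\text{slow}$, and blend linearly in between. The module `YarnSketch` is hypothetical and omits the `dim` argument; how Edifice computes the boundaries (for example in dimension-index space) is an assumption here and may differ.

```elixir
defmodule YarnSketch do
  # Hypothetical YaRN frequency scaling over a list of frequencies theta_i.
  def scale_freqs(thetas, opts) do
    target = Keyword.fetch!(opts, :target_length)
    original = Keyword.get(opts, :original_length, 4096)
    beta_fast = Keyword.get(opts, :beta_fast, 32.0)
    beta_slow = Keyword.get(opts, :beta_slow, 1.0)
    scale = target / original

    Enum.map(thetas, fn theta ->
      # Full rotations this band completes over the original context length
      rotations = original * theta / (2 * :math.pi())

      # gamma = 1.0 keeps the frequency, gamma = 0.0 divides it by `scale`
      gamma =
        cond do
          rotations > beta_fast -> 1.0
          rotations < beta_slow -> 0.0
          true -> (rotations - beta_slow) / (beta_fast - beta_slow)
        end

      gamma * theta + (1.0 - gamma) * theta / scale
    end)
  end
end

# dim = 128, base = 10000, extending a 4096-token model to 16384 tokens (scale = 4)
thetas = for i <- 0..63, do: :math.pow(10_000.0, -2 * i / 128)
scaled = YarnSketch.scale_freqs(thetas, target_length: 16_384, original_length: 4096)
```

With these settings the fastest band (`theta_0 = 1.0`) is untouched, while the slowest bands, which complete less than one rotation over 4096 tokens, are divided by 4.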
