# `Edifice.Attention.YARN`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/attention/yarn.ex#L1)

YaRN: Yet another RoPE extensioN for context window extension.

<!-- verified: true, date: 2026-02-23 -->

YaRN modifies the RoPE frequency bands so that a model can handle sequences
longer than those it was originally trained on. It does this by rescaling each
frequency component according to its wavelength relative to the original
context length.

## Key Insight

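RoPE encodes position by rotating dimension pairs at different frequencies.
When extending the context, high-frequency components (which encode local,
relative position) should be preserved, while low-frequency components (which
encode global position) need scaling so that positions beyond the original
context map onto rotation angles the model saw during training: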

```
High-frequency bands (short wavelength):
  - Capture local positional relationships
  - Left unchanged (factor = 1)

Low-frequency bands (long wavelength):
  - Capture global position in sequence
  - Divided by scale (factor = 1/scale)

Middle bands:
  - Linear interpolation between the two regimes
```

## Formula

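For each dimension pair i, take the original RoPE frequency
`theta_i = base^(-2i / embed_dim)` and its wavelength `2π / theta_i`, then
compute a scaling factor that multiplies `theta_i`:
- If wavelength < `high_freq_wavelen`: factor = 1 (unchanged)
- If wavelength > `low_freq_wavelen`: factor = 1/scale (full scaling)
- Otherwise: linear interpolation between 1 and 1/scale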

The thresholds are derived from:
- `high_freq_wavelen = original_max_position / high_freq_factor`
- `low_freq_wavelen = original_max_position / low_freq_factor`
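
The piecewise rule above can be sketched in plain Elixir, using lists of floats instead of Nx tensors. `YarnFreqsSketch` is a hypothetical module, not Edifice's implementation. It assumes the standard RoPE frequencies `base^(-2i / embed_dim)` and, for the middle band, a linear blend in `original_max_position / wavelength`, which makes the factor move continuously from 1/scale at `low_freq_wavelen` to 1 at `high_freq_wavelen`.

```elixir
defmodule YarnFreqsSketch do
  # Hypothetical, list-based sketch of a YaRN-style frequency table.

  @doc "Returns div(embed_dim, 2) scaled RoPE frequencies as a list of floats."
  def yarn_freqs(embed_dim, opts \\ []) do
    scale = Keyword.get(opts, :scale, 8)
    orig = Keyword.get(opts, :original_max_position, 2048)
    low = Keyword.get(opts, :low_freq_factor, 1)
    high = Keyword.get(opts, :high_freq_factor, 4)
    base = Keyword.get(opts, :base, 10_000.0)

    # Wavelength thresholds, as defined in the Formula section.
    high_freq_wavelen = orig / high
    low_freq_wavelen = orig / low

    for i <- 0..(div(embed_dim, 2) - 1)//1 do
      # Standard RoPE frequency and wavelength for dimension pair i.
      theta = :math.pow(base, -2 * i / embed_dim)
      wavelen = 2 * :math.pi() / theta

      cond do
        # High-frequency band: keep the original frequency.
        wavelen < high_freq_wavelen ->
          theta

        # Low-frequency band: divide the frequency by the scale factor.
        wavelen > low_freq_wavelen ->
          theta / scale

        # Middle band (ASSUMPTION): blend linearly in orig / wavelen.
        true ->
          smooth = (orig / wavelen - low) / (high - low)
          (1 - smooth) * theta / scale + smooth * theta
      end
    end
  end
end
```

With the defaults, pair 0 (wavelength 2π) is returned unchanged as `1.0`, and the slowest pairs come back divided by 8.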

## Usage

    alias Edifice.Attention.YARN

    # Build a YaRN-modified RoPE layer
    model = YARN.build(
      embed_dim: 64,
      scale: 8,
      original_max_position: 2048
    )

    # Get the frequency table directly
    freqs = YARN.yarn_freqs(64,
      scale: 8,
      original_max_position: 2048
    )

    # Apply YaRN to query/key tensors of shape [batch, seq_len, embed_dim]
    {q_rotated, k_rotated} = YARN.apply_yarn(q, k,
      scale: 8,
      original_max_position: 2048
    )

## References

- "YaRN: Efficient Context Window Extension of Large Language Models"
  (Peng et al., 2023) — https://arxiv.org/abs/2309.00071

# `yarn_opt`

```elixir
@type yarn_opt() ::
  {:embed_dim, pos_integer()}
  | {:scale, number()}
  | {:original_max_position, pos_integer()}
  | {:low_freq_factor, number()}
  | {:high_freq_factor, number()}
  | {:base, number()}
  | {:name, String.t()}
```

Options for YaRN functions.

# `apply_yarn`

```elixir
@spec apply_yarn(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) ::
  {Nx.Tensor.t(), Nx.Tensor.t()}
```

Apply YaRN-modified RoPE to query and key tensors.

## Parameters

  - `query` - Query tensor `[batch, seq_len, embed_dim]`
  - `key` - Key tensor `[batch, seq_len, embed_dim]`
  - `opts` - Options (see `yarn_freqs/2` for available options)

## Returns

  `{rotated_query, rotated_key}`, with the same shapes as the inputs.

## Example

    {q_rot, k_rot} = YARN.apply_yarn(query, key, scale: 8)
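
Conceptually, applying YaRN means performing an ordinary RoPE rotation with the YaRN-scaled frequency table: a vector at sequence position `position` has each dimension pair rotated by `position * freqs[i]`. The plain-Elixir sketch below shows this for a single vector. `RopeRotateSketch` is hypothetical, and it ASSUMES adjacent elements `(x[2i], x[2i+1])` form each pair; some implementations instead pair the first and second halves of the vector.

```elixir
defmodule RopeRotateSketch do
  # Hypothetical sketch: rotate one vector (a list of floats) at a given
  # position using a frequency table such as the one yarn_freqs/2 returns.
  # ASSUMPTION: adjacent-pair layout (x[2i], x[2i+1]).
  def rotate(vector, position, freqs) do
    vector
    |> Enum.chunk_every(2)
    |> Enum.zip(freqs)
    |> Enum.flat_map(fn {[x0, x1], freq} ->
      angle = position * freq
      {c, s} = {:math.cos(angle), :math.sin(angle)}
      # 2D rotation of the pair by `angle`.
      [x0 * c - x1 * s, x0 * s + x1 * c]
    end)
  end
end
```

Because each pair is only rotated, vector norms are preserved, and the dot product of a rotated query and key depends only on their position difference. This is the property YaRN keeps while slowing the low-frequency rotations.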

# `build`

```elixir
@spec build([yarn_opt()]) :: Axon.t()
```

Build an Axon model that applies YaRN-modified RoPE to its input.

## Options

  - `:embed_dim` - Feature dimension (required, must be even)
  - `:scale` - Context extension scale factor (default: 8). For example,
    `scale: 8` extends a 2048-token context to 16384 tokens
  - `:original_max_position` - Original trained context length (default: 2048)
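  - `:low_freq_factor` - Low-frequency boundary factor; wavelengths longer than
    `original_max_position / low_freq_factor` are fully scaled (default: 1)
  - `:high_freq_factor` - High-frequency boundary factor; wavelengths shorter than
    `original_max_position / high_freq_factor` are left unchanged (default: 4)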
  - `:base` - RoPE base frequency (default: 10000.0)
  - `:name` - Layer name prefix (default: "yarn")

## Returns

  An Axon model that applies YaRN-modified RoPE to the input tensor.

# `effective_context_length`

```elixir
@spec effective_context_length(pos_integer(), number()) :: number()
```

Calculate the effective context length after YaRN scaling.

## Example

    YARN.effective_context_length(2048, 8)
    # => 16384
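
The example is consistent with the effective length being `original_max_position * scale`. The sketch below ASSUMES exactly that relationship; `ContextLengthSketch` is hypothetical, not Edifice code. Since `scale` is typed `number()`, a fractional scale yields a float under this assumption.

```elixir
defmodule ContextLengthSketch do
  # Hypothetical: effective length = original trained length * scale factor.
  def effective_context_length(original_max_position, scale),
    do: original_max_position * scale
end

ContextLengthSketch.effective_context_length(2048, 8)
# => 16384

ContextLengthSketch.effective_context_length(4096, 2.5)
# => 10240.0
```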

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Return a keyword list of recommended default options for YaRN.

# `yarn_freqs`

```elixir
@spec yarn_freqs(
  pos_integer(),
  keyword()
) :: Nx.Tensor.t()
```

Compute YaRN-scaled frequency table.

Returns a tensor of shape `[embed_dim / 2]` containing the scaled frequencies
for each dimension pair.

## Options

  - `:scale` - Context extension scale factor (default: 8)
  - `:original_max_position` - Original trained context length (default: 2048)
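  - `:low_freq_factor` - Low-frequency boundary factor; wavelengths longer than
    `original_max_position / low_freq_factor` are fully scaled (default: 1)
  - `:high_freq_factor` - High-frequency boundary factor; wavelengths shorter than
    `original_max_position / high_freq_factor` are left unchanged (default: 4)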
  - `:base` - RoPE base frequency (default: 10000.0)

## Example

    freqs = YARN.yarn_freqs(64, scale: 8, original_max_position: 2048)
    # => Tensor of shape {32} with scaled frequencies
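
To see how the 32 entries in this example split across the three bands, the hypothetical helper below classifies each dimension pair by its wavelength `2π * base^(2i / embed_dim)` against the thresholds `original_max_position / high_freq_factor` (512) and `original_max_position / low_freq_factor` (2048). It does not call Edifice; it only reproduces the band boundaries described in the module overview, assuming standard RoPE frequencies.

```elixir
defmodule YarnBandsSketch do
  # Hypothetical helper: count dimension pairs per YaRN band.
  def band_counts(embed_dim, base, orig, low, high) do
    high_freq_wavelen = orig / high
    low_freq_wavelen = orig / low

    bands =
      for i <- 0..(div(embed_dim, 2) - 1)//1 do
        # Wavelength of pair i under standard RoPE frequencies.
        wavelen = 2 * :math.pi() * :math.pow(base, 2 * i / embed_dim)

        cond do
          wavelen < high_freq_wavelen -> :unchanged
          wavelen > low_freq_wavelen -> :scaled
          true -> :interpolated
        end
      end

    Enum.frequencies(bands)
  end
end

YarnBandsSketch.band_counts(64, 10_000.0, 2048, 1, 4)
# => %{interpolated: 5, scaled: 11, unchanged: 16}
```

So with these settings the 16 fastest pairs keep their frequencies, the 11 slowest are divided by 8, and only 5 pairs fall in the interpolated middle band.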

---

*Consult [api-reference.md](api-reference.md) for the complete listing*
