Edifice.Attention.RetNet (Edifice v0.2.0)


RetNet: Retentive Network - A Successor to Transformer.

Implements the RetNet architecture from "Retentive Network: A Successor to Transformer for Large Language Models" (Sun et al., Microsoft 2023).

Key Innovation: Retention Mechanism

RetNet replaces attention with "retention" - a decay-based mechanism:

Parallel:   Y = ((Q . Theta)(K . conj(Theta))^T ⊙ D) V
Recurrent:  s_n = gamma * s_{n-1} + K_n^T V_n;  o_n = Q_n s_n

Where ⊙ is the elementwise product, Theta is a position-dependent rotation applied to queries and keys (conj is its complex conjugate), and D is the lower-triangular decay matrix: D[n,m] = gamma^(n-m) if n >= m, else 0.
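
As a concrete illustration, here is a minimal eager-Nx sketch of the parallel form for a single head. The Theta rotation is omitted for brevity, and the module and function names are illustrative, not this library's API:

defmodule RetentionSketch do
  # q, k, v: [seq_len, head_dim] tensors; gamma: scalar decay in (0, 1).
  def parallel_retention(q, k, v, gamma) do
    seq_len = Nx.axis_size(q, 0)
    n = Nx.iota({seq_len, 1})
    m = Nx.iota({1, seq_len})
    # Causal decay mask: D[n, m] = gamma^(n - m) when n >= m, else 0.
    d = Nx.select(Nx.greater_equal(n, m), Nx.pow(gamma, Nx.subtract(n, m)), 0.0)
    # Y = ((Q K^T) ⊙ D) V, with the Theta rotation omitted.
    q
    |> Nx.dot(Nx.transpose(k))
    |> Nx.multiply(d)
    |> Nx.dot(v)
  end
end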

Triple Paradigm

The same weights support three computation modes (selected via the :mode option of build/1; see the snippet after this list):

  • Parallel: Training mode, O(L^2) but GPU-parallel
  • Recurrent: Inference mode, O(1) per token
  • Chunkwise: Long sequences, O(L) with chunking
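
Because the mode is just a build option, switching paradigms is a configuration change rather than a retraining step. A minimal sketch:

# Same weights, different computation paths (see :mode under build/1).
train_model = RetNet.build(embed_dim: 287, mode: :parallel)
serve_model = RetNet.build(embed_dim: 287, mode: :recurrent)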

Multi-Scale Retention (MSR)

Different heads use different decay rates for multi-scale modeling (see the numeric sketch after this list):

  • gamma_h = 1 - 2^(-5-h) for head h
  • GroupNorm instead of LayerNorm (handles different head variances)
  • SiLU gating: Y = (SiLU(X W_G) ⊙ Retention(X)) W_O
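
The decay schedule is easy to verify numerically. A small Nx sketch, with values shown for the default 4 heads:

# gamma_h = 1 - 2^(-5 - h) for heads h = 0..3
gammas = Nx.subtract(1.0, Nx.pow(2.0, Nx.negate(Nx.add(Nx.iota({4}), 5))))
# => f32[4] [0.96875, 0.984375, 0.9921875, 0.99609375]

Lower-indexed heads decay faster (shorter memory); higher-indexed heads retain context over longer spans.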

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+--------------------------------------+
|           RetNet Block               |
|  LayerNorm -> MSR -> Residual        |
|  LayerNorm -> FFN -> Residual        |
+--------------------------------------+
      | (repeat for num_layers)
      v
Output [batch, hidden_size]

Usage

# Build RetNet backbone
alias Edifice.Attention.RetNet

model = RetNet.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 6,
  num_heads: 4
)
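
The result is an ordinary Axon graph, so the standard build/init/predict workflow should apply. A sketch (shapes follow the defaults above; passing the tensor directly assumes a single-input model):

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({8, 60, 287}, :f32), %{})

input = Nx.broadcast(0.0, {8, 60, 287})
output = predict_fn.(params, input)
# output: {8, 256}, the last hidden state for each sequence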

Comparison

Mode        Time      Memory   Best For
Parallel    O(L^2)    O(L^2)   Training
Recurrent   O(1)      O(1)     Inference
Chunkwise   O(L)      O(C)     Long sequences

(L = sequence length, C = chunk size)

References

  • Sun et al. (2023). Retentive Network: A Successor to Transformer for Large Language Models. arXiv:2307.08621.

Summary

Types

build_opt()
  Options for build/1.

Functions

build(opts \\ [])
  Build a RetNet model for sequence processing.

build_multi_scale_retention(input, opts \\ [])
  Build Multi-Scale Retention layer.

build_retnet_block(input, opts \\ [])
  Build a single RetNet block.

default_dropout()
  Default dropout rate.

default_expand_factor()
  Default feedforward expansion factor.

default_hidden_size()
  Default hidden dimension.

default_num_heads()
  Default number of retention heads.

default_num_layers()
  Default number of layers.

eps()
  Epsilon for numerical stability.

init_retention_state(batch_size, num_heads, head_dim)
  Initialize retention state for recurrent inference.

output_size(opts \\ [])
  Get the output size of a RetNet model.

param_count(opts)
  Calculate approximate parameter count for a RetNet model.

Recommended default configuration for sequence processing.

recurrent_retention_step(q, k, v, state, gamma)
  Build recurrent retention state update.

Types

build_opt()

@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}
  | {:mode, :parallel | :recurrent | :chunkwise}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a RetNet model for sequence processing.

Options

  • :embed_dim - Size of input embedding per timestep (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_layers - Number of RetNet blocks (default: 6)
  • :num_heads - Number of retention heads (default: 4)
  • :expand_factor - FFN expansion factor (default: 2)
  • :dropout - Dropout rate (default: 0.0)
  • :window_size - Expected sequence length (default: 60)
  • :mode - Computation mode: :parallel, :recurrent, :chunkwise (default: :parallel)

Returns

An Axon model that processes sequences and outputs the last hidden state.

build_multi_scale_retention(input, opts \\ [])

@spec build_multi_scale_retention(
  Axon.t(),
  keyword()
) :: Axon.t()

Build Multi-Scale Retention layer.

MSR uses different decay rates (gamma) per head for multi-scale modeling (a gating sketch follows the list):

  • gamma_h = 1 - 2^(-5-h) for head h
  • SiLU gating: Y = (SiLU(X W_G) ⊙ Retention(X)) W_O
  • GroupNorm for handling different head variances
build_retnet_block(input, opts \\ [])

@spec build_retnet_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single RetNet block.

RetNet block structure (sketched in Axon below):

  1. LayerNorm -> Multi-Scale Retention -> Residual
  2. LayerNorm -> FFN -> Residual

default_dropout()

@spec default_dropout() :: float()

Default dropout rate

default_expand_factor()

@spec default_expand_factor() :: pos_integer()

Default feedforward expansion factor

default_hidden_size()

@spec default_hidden_size() :: pos_integer()

Default hidden dimension

default_num_heads()

@spec default_num_heads() :: pos_integer()

Default number of retention heads

default_num_layers()

@spec default_num_layers() :: pos_integer()

Default number of layers

eps()

@spec eps() :: float()

Epsilon for numerical stability

init_retention_state(batch_size, num_heads, head_dim)

@spec init_retention_state(pos_integer(), pos_integer(), pos_integer()) ::
  Nx.Tensor.t()

Initialize retention state for recurrent inference.

Returns a zero-initialized state tensor of shape [batch, heads, head_dim, head_dim].
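
For example:

state = RetNet.init_retention_state(8, 4, 64)
# Equivalent to: Nx.broadcast(Nx.tensor(0.0), {8, 4, 64, 64})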

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a RetNet model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for a RetNet model.
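
As a back-of-envelope check (my own estimate of the dominant terms, not necessarily this function's exact formula): each block carries roughly five d x d retention projections (Q, K, V, gate, output) plus the FFN.

d = 256           # hidden_size
expand = 2        # :expand_factor default
layers = 6        # :num_layers default
per_block = 5 * d * d + 2 * expand * d * d
layers * per_block
# => 3_538_944 (about 3.5M, ignoring norms, biases, and the input projection)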

recurrent_retention_step(q, k, v, state, gamma)

@spec recurrent_retention_step(
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  Nx.Tensor.t()
) :: {Nx.Tensor.t(), Nx.Tensor.t()}

Build recurrent retention state update.

Recurrent formulation for O(1) inference (sketched below):

  • s_n = gamma * s_{n-1} + K_n^T V_n
  • o_n = Q_n s_n

This is used during inference when processing one token at a time.
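
A minimal eager-Nx sketch of one step. The shapes and the tuple order are assumptions; gamma is a scalar (or a per-head tensor broadcastable to the state):

defmodule RecurrentSketch do
  # q, k, v: [batch, heads, head_dim]; state: [batch, heads, head_dim, head_dim]
  def step(q, k, v, state, gamma) do
    # s_n = gamma * s_{n-1} + K_n^T V_n (outer product per head)
    kv = Nx.multiply(Nx.new_axis(k, 3), Nx.new_axis(v, 2))
    new_state = Nx.add(Nx.multiply(gamma, state), kv)
    # o_n = Q_n s_n (contract over the key dimension)
    out = Nx.sum(Nx.multiply(Nx.new_axis(q, 3), new_state), axes: [2])
    {out, new_state}
  end
end

Because the state is a fixed [head_dim, head_dim] summary per head, each token costs the same regardless of how long the sequence has grown.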