RetNet: Retentive Network - A Successor to Transformer.
Implements the RetNet architecture from "Retentive Network: A Successor to Transformer for Large Language Models" (Sun et al., Microsoft 2023).
Key Innovation: Retention Mechanism
RetNet replaces attention with "retention" - a decay-based mechanism:
Parallel: Y = (Q K^T ⊙ D) V, where Q and K carry the xPos-style rotation Theta (conjugated for K) and ⊙ is element-wise multiplication
Recurrent: s_n = gamma * s_{n-1} + K_n^T * V_n;  o_n = Q_n * s_n
Where D is a causal decay matrix: D[n,m] = gamma^(n-m) if n >= m, else 0.
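The sketch below illustrates the parallel form in plain Nx. It is a minimal sketch under a few assumptions: single head, real-valued tensors, Theta rotation omitted for brevity, and `RetentionSketch.parallel_retention/4` is a hypothetical helper, not part of this module.

```elixir
# Single-head parallel retention sketch (Theta rotation omitted).
# q, k, v are {seq_len, head_dim} tensors; gamma is the scalar decay.
defmodule RetentionSketch do
  def parallel_retention(q, k, v, gamma \\ 0.96875) do
    {len, _head_dim} = Nx.shape(q)

    # Decay mask: D[n, m] = gamma^(n - m) if n >= m, else 0
    n = Nx.iota({len, 1})
    m = Nx.iota({1, len})

    decay =
      Nx.pow(gamma, Nx.subtract(n, m))
      |> Nx.multiply(Nx.greater_equal(n, m))

    # Y = (Q K^T ⊙ D) V
    q
    |> Nx.dot(Nx.transpose(k))
    |> Nx.multiply(decay)
    |> Nx.dot(v)
  end
end
```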
Triple Paradigm
The same weights support three computation modes:
- Parallel: Training mode, O(L^2) but GPU-parallel
- Recurrent: Inference mode, O(1) per token
- Chunkwise: Long sequences, O(L) with chunking
Multi-Scale Retention (MSR)
Different heads use different decay rates for multi-scale modeling:
- gamma_h = 1 - 2^(-5-h) for head h
- GroupNorm instead of LayerNorm (handles different head variances)
- SiLU gating: Y = (SiLU(X W_G) ⊙ Retention(X)) W_O
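For example, with the default four heads the decay formula above works out to the following values (a quick illustrative calculation, not an API call):

```elixir
# gamma_h = 1 - 2^(-5 - h) for heads h = 0..3
gammas = Enum.map(0..3, fn h -> 1.0 - :math.pow(2, -5 - h) end)
# => [0.96875, 0.984375, 0.9921875, 0.99609375]
```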
Architecture
Input [batch, seq_len, embed_dim]
|
v
+-------------------------------------+
| RetNet Block |
| LayerNorm -> MSR -> Residual |
| LayerNorm -> FFN -> Residual |
+-------------------------------------+
| (repeat for num_layers)
v
Output [batch, hidden_size]

Usage
# Build RetNet backbone
model = RetNet.build(
embed_dim: 287,
hidden_size: 256,
num_layers: 6,
num_heads: 4
)
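The returned value is a regular Axon graph, so it can be initialized and run the usual way. A minimal sketch, assuming the standard Axon build/predict workflow and a single [batch, seq_len, embed_dim] input whose shape mirrors the options above:

```elixir
# Sketch: build and run the Axon graph.
{init_fn, predict_fn} = Axon.build(model)

template = Nx.template({1, 60, 287}, :f32)
params = init_fn.(template, %{})

# Single-input models accept a bare tensor; otherwise pass a map keyed by the input name.
input = Nx.broadcast(0.0, {1, 60, 287})
output = predict_fn.(params, input)
# output shape: {1, 256} -> the last hidden state per sequence
```

Comparison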
| Mode | Time | Memory | Best For |
|---|---|---|---|
| Parallel | O(L^2) | O(L^2) | Training |
| Recurrent | O(1) | O(1) | Inference |
| Chunkwise | O(L) | O(C) | Long sequences |
References
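Sun et al. (2023). Retentive Network: A Successor to Transformer for Large Language Models. arXiv:2307.08621.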
Summary
Functions
Build a RetNet model for sequence processing.
Build Multi-Scale Retention layer.
Build a single RetNet block.
Default dropout rate
Default feedforward expansion factor
Default hidden dimension
Default number of retention heads
Default number of layers
Epsilon for numerical stability
Initialize retention state for recurrent inference.
Get the output size of a RetNet model.
Calculate approximate parameter count for a RetNet model.
Recommended default configuration for sequence processing.
Build recurrent retention state update.
Types
@type build_opt() ::
        {:embed_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:expand_factor, pos_integer()}
        | {:dropout, float()}
        | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a RetNet model for sequence processing.
Options
- :embed_dim - Size of input embedding per timestep (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :num_layers - Number of RetNet blocks (default: 6)
- :num_heads - Number of retention heads (default: 4)
- :expand_factor - FFN expansion factor (default: 2)
- :dropout - Dropout rate (default: 0.0)
- :window_size - Expected sequence length (default: 60)
- :mode - Computation mode: :parallel, :recurrent, or :chunkwise (default: :parallel)
Returns
An Axon model that processes sequences and outputs the last hidden state.
Build Multi-Scale Retention layer.
MSR uses different decay rates (gamma) per head for multi-scale modeling:
- gamma_h = 1 - 2^(-5-h) for head h
- SiLU gating: Y = (SiLU(X W_G) ⊙ Retention(X)) W_O
- GroupNorm for handling different head variances
Build a single RetNet block.
RetNet block structure:
- LayerNorm -> Multi-Scale Retention -> Residual
- LayerNorm -> FFN -> Residual
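A structural sketch of this block in Axon (assumptions: a hypothetical multi_scale_retention/2 helper stands in for the MSR layer built by this module, and the FFN uses the default expand factor of 2 with GELU):

```elixir
# Pre-norm residual block: x + MSR(LN(x)), then x + FFN(LN(x))
def retnet_block(input, hidden_size, opts \\ []) do
  retained =
    input
    |> Axon.layer_norm()
    |> multi_scale_retention(opts)   # hypothetical MSR layer builder
    |> Axon.add(input)

  retained
  |> Axon.layer_norm()
  |> Axon.dense(hidden_size * 2, activation: :gelu)
  |> Axon.dense(hidden_size)
  |> Axon.add(retained)
end
```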
@spec default_dropout() :: float()
Default dropout rate
@spec default_expand_factor() :: pos_integer()
Default feedforward expansion factor
@spec default_num_heads() :: pos_integer()
Default number of retention heads
@spec default_num_layers() :: pos_integer()
Default number of layers
@spec eps() :: float()
Epsilon for numerical stability
@spec init_retention_state(pos_integer(), pos_integer(), pos_integer()) :: Nx.Tensor.t()
Initialize retention state for recurrent inference.
Returns a zero-initialized state tensor of shape [batch, heads, head_dim, head_dim].
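Usage sketch (the argument order (batch, num_heads, head_dim) is assumed from the returned shape; head_dim is hidden_size divided by num_heads):

```elixir
# Zero state for batch 2, 4 heads, head_dim 64 -> shape {2, 4, 64, 64}
state = RetNet.init_retention_state(2, 4, 64)
```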
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a RetNet model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for a RetNet model.
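Both helpers take the same keyword options as build/1. A hedged usage sketch (the exact returned numbers depend on the configuration):

```elixir
opts = [embed_dim: 287, hidden_size: 256, num_layers: 6, num_heads: 4]

RetNet.output_size(opts)   # size of the final hidden state, e.g. 256
RetNet.param_count(opts)   # approximate number of trainable parameters
```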
@spec recommended_defaults() :: keyword()
Recommended default configuration for sequence processing.
@spec recurrent_retention_step(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Build recurrent retention state update.
Recurrent formulation for O(1) inference:
- s_n = gamma * s_{n-1} + K_n^T * V_n
- o_n = Q_n * s_n
This is used during inference when processing one token at a time.
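A minimal single-head sketch of this update in plain Nx (assumptions: state is a {head_dim, head_dim} tensor and q_n, k_n, v_n are {head_dim} vectors for the current token; this illustrates the math rather than the module's exact recurrent_retention_step/5 signature):

```elixir
# One recurrent retention step for a single head.
def retention_step(state, q_n, k_n, v_n, gamma) do
  # s_n = gamma * s_{n-1} + k_n^T v_n  (decayed outer-product update)
  state = Nx.add(Nx.multiply(gamma, state), Nx.outer(k_n, v_n))

  # o_n = q_n * s_n  -> {head_dim} output for the current token
  output = Nx.dot(q_n, state)

  {output, state}
end
```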