Gated DeltaNet - Linear Attention with Gated Delta Rule.
Extends DeltaNet with a data-dependent gating mechanism that modulates the state matrix between timesteps. Where vanilla DeltaNet always retains all of S_{t-1} (modulated only by the delta correction), Gated DeltaNet introduces a forget gate alpha_t that controls how much of the previous state to retain before applying the delta update.
This gives the model explicit control over memory erasure, which is critical for tasks that require forgetting stale associations.
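The effect of the forget gate can be seen in a small numerical sketch (NumPy, for illustration only; `gated_delta_step` is a hypothetical helper, not part of this library):

```python
import numpy as np

def gated_delta_step(S, k, v, alpha, beta):
    """One step of the gated delta rule: S_t = alpha_t * S_{t-1} + beta_t * (v_t - S_{t-1} k_t) k_t^T."""
    return alpha * S + beta * np.outer(v - S @ k, k)

S = np.zeros((4, 4))                      # state matrix: maps keys -> values
k1, k2 = np.eye(4)[0], np.eye(4)[1]       # two orthonormal keys
v_old, v_new, v2 = np.eye(4)[0], np.eye(4)[1], np.eye(4)[2]

# Vanilla behaviour (alpha = 1): the delta correction replaces the value
# stored under k1 without touching anything else.
S = gated_delta_step(S, k1, v_old, alpha=1.0, beta=1.0)
S = gated_delta_step(S, k1, v_new, alpha=1.0, beta=1.0)
print(S @ k1)        # retrieves v_new, not v_old

# With alpha < 1 the whole previous state decays before the write, so the
# association under k1 fades as new content arrives under other keys.
S = gated_delta_step(S, k2, v2, alpha=0.1, beta=1.0)
print(S @ k1)        # 0.1 * v_new: the stale association is mostly erased
```

With `alpha = 1` the rule only corrects the slot addressed by the current key; a small `alpha` erases everything else too, which is the explicit memory-erasure control described above.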
Key Innovations
- Gated state transition: S_t = alpha_t * S_{t-1} + beta_t * (v_t - S_{t-1} k_t) k_t^T
- Data-dependent forgetting: alpha_t = sigmoid(W_alpha x_t) controls memory decay
- Short convolution: Optional causal convolution before Q/K/V projections for local context
- Swish gate on output: Gated output projection for expressivity
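Putting these pieces together, a single head over a few timesteps can be sketched in NumPy (random stand-ins for the learned projection weights; toy dimensions, not this library's defaults):

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

rng = np.random.default_rng(0)
d, d_k, d_v = 8, 4, 4                  # toy sizes

# Random stand-ins for the learned W_q, W_k, W_v, W_beta, W_alpha, W_g.
W_q, W_k, W_v = rng.normal(size=(3, d_k, d))
W_beta, W_alpha = rng.normal(size=(2, d))
W_g = rng.normal(size=(d_v, d))

S = np.zeros((d_v, d_k))               # state matrix, carried across timesteps
for x in rng.normal(size=(6, d)):      # 6 timesteps of input
    q, v = W_q @ x, W_v @ x
    k = W_k @ x
    k /= np.linalg.norm(k)                         # L2-normalized key
    beta = sigmoid(W_beta @ x)                     # update gate (write strength)
    alpha = sigmoid(W_alpha @ x)                   # forget gate (retention)
    S = alpha * S + beta * np.outer(v - S @ k, k)  # gated delta rule
    g = W_g @ x
    o = g * sigmoid(g) * (S @ q)                   # swish-gated output
```

The optional short convolution is omitted here; it would mix each `x` with a few preceding inputs before the projections.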
Equations
q_t = W_q x_t # Query projection
k_t = W_k x_t # Key projection (L2 normalized)
v_t = W_v x_t # Value projection
beta_t = sigmoid(W_beta x_t) # Update gate (write strength)
alpha_t = sigmoid(W_alpha x_t) # Forget gate (retention)
S_t = alpha_t * S_{t-1} + beta_t * (v_t - S_{t-1} k_t) * k_t^T # Gated delta rule
o_t = swish(W_g x_t) * (S_t q_t) # Gated output

Architecture
Input [batch, seq_len, embed_dim]
|
v
[Input Projection] -> hidden_size
|
v
+----------------------------------------------+
| Gated DeltaNet Layer |
| Short Conv (optional) for local context |
| Project to Q, K, V, beta, alpha, gate |
| For each timestep: |
| S = alpha * S + beta * (v - S@k) * k^T |
| output = swish(gate) * (S @ q) |
+----------------------------------------------+
| (repeat num_layers)
v
[Layer Norm] -> [Last Timestep]
|
v
Output [batch, hidden_size]

Compared to DeltaNet
| Aspect | DeltaNet | Gated DeltaNet |
|---|---|---|
| State update | S + beta * (v - S k) * k^T | alpha * S + beta * (v - S k) * k^T |
| Forgetting | Implicit (via delta correction) | Explicit (alpha gate) |
| Output gating | None | Swish gate |
| Local context | None | Optional short convolution |
| Expressivity | Lower | Higher (data-dependent dynamics) |
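The first table row can be checked numerically: with alpha fixed at 1, the gated update reduces exactly to the vanilla DeltaNet update (NumPy sketch; `delta_step` and `gated_delta_step` are illustrative names, not this library's API):

```python
import numpy as np

def delta_step(S, k, v, beta):                 # vanilla DeltaNet update
    return S + beta * np.outer(v - S @ k, k)

def gated_delta_step(S, k, v, alpha, beta):    # gated variant
    return alpha * S + beta * np.outer(v - S @ k, k)

rng = np.random.default_rng(1)
S = rng.normal(size=(4, 4))
k = rng.normal(size=4)
k /= np.linalg.norm(k)                         # keys are L2-normalized
v, beta = rng.normal(size=4), 0.7

# alpha = 1 recovers the vanilla delta rule exactly ...
assert np.allclose(delta_step(S, k, v, beta),
                   gated_delta_step(S, k, v, alpha=1.0, beta=beta))

# ... while alpha < 1 uniformly damps the previous state before the write.
S_gated = gated_delta_step(S, k, v, alpha=0.5, beta=beta)
```

So DeltaNet is the special case alpha_t = 1; the gate buys the extra, data-dependent degree of freedom listed in the table.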
Usage
model = GatedDeltaNet.build(
embed_dim: 287,
hidden_size: 256,
num_layers: 4,
use_short_conv: true,
dropout: 0.1
)

References
- "Gated Delta Networks: Improving Mamba2 with Delta Rule" (Yang et al., 2024)
- https://arxiv.org/abs/2412.06464
- Adopted by Qwen3-Next and Kimi Linear (Moonshot AI)
Summary

Functions
- Build a Gated DeltaNet model for sequence processing.
- Build a single Gated DeltaNet block that can be used as a backbone layer in hybrid architectures.
- Default dropout rate
- Default hidden dimension
- Default number of attention heads
- Default number of layers
- Get the output size of a Gated DeltaNet model.

Types

@type build_opt() ::
        {:conv_size, pos_integer()}
        | {:dropout, float()}
        | {:embed_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:seq_len, pos_integer()}
        | {:use_short_conv, boolean()}
        | {:window_size, pos_integer()}

Options for build/1.
Functions
Build a Gated DeltaNet model for sequence processing.
Options
- `:embed_dim` - Size of input embedding per frame (required)
- `:hidden_size` - Internal hidden dimension (default: 256)
- `:num_heads` - Number of independent gated delta rule heads (default: 4)
- `:num_layers` - Number of Gated DeltaNet layers (default: 4)
- `:dropout` - Dropout rate between layers (default: 0.1)
- `:use_short_conv` - Use short causal convolution before projections (default: true)
- `:conv_size` - Kernel size for short convolution (default: 4)
- `:window_size` - Expected sequence length (default: 60)
- `:seq_len` - Alias for `:window_size`
Returns
An Axon model that processes sequences and outputs the last hidden state.
Build a single Gated DeltaNet block that can be used as a backbone layer in hybrid architectures.
Takes input of shape [batch, seq_len, hidden_size] and returns the same shape. Includes pre-norm and residual connection.
Options
- `:hidden_size` - Hidden dimension (default: 256)
- `:num_heads` - Number of heads (default: 4)
- `:use_short_conv` - Use short causal convolution (default: true)
- `:conv_size` - Convolution kernel size (default: 4)
- `:dropout` - Dropout rate (default: 0.1)
- `:name` - Layer name prefix (default: "gated_delta_net_block")
@spec default_dropout() :: float()
Default dropout rate
@spec default_num_heads() :: pos_integer()
Default number of attention heads
@spec default_num_layers() :: pos_integer()
Default number of layers
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a Gated DeltaNet model.