DeltaNet - Linear Attention with Delta Rule.
Implements linear attention with the delta rule update from "Linear Transformers Are Secretly Fast Weight Programmers" (Schlag et al., 2021) and subsequent work.
DeltaNet maintains an associative memory matrix S that is updated using the delta rule, which corrects previous associations rather than blindly accumulating them. This gives it superior retrieval accuracy compared to standard linear attention.
Key Innovations
- Delta rule update: S_t = S_{t-1} + beta_t * (v_t - S_{t-1} k_t) k_t^T
- Error-correcting: Subtracts the current retrieval S_{t-1} k_t before adding
- Learnable beta: Controls update rate per-token via a gate
- Linear complexity: O(n) time in sequence length with a fixed O(d^2) state, vs O(n^2) time and an O(n*d) KV cache for softmax attention
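The error-correcting property can be seen in a minimal NumPy sketch (illustrative only, not the module's code): with beta = 1 and a unit-norm key, the delta rule overwrites a stale association exactly, while plain additive accumulation (S += v k^T) blends old and new values.

```python
import numpy as np

d = 4
k = np.array([1.0, 0.0, 0.0, 0.0])      # unit-norm key
v_old = np.array([1.0, 2.0, 3.0, 4.0])  # first value bound to k
v_new = np.array([9.0, 8.0, 7.0, 6.0])  # updated value for the same key

# Delta rule: subtract the current retrieval S @ k before writing.
S = np.zeros((d, d))
for v in (v_old, v_new):
    beta = 1.0
    S += beta * np.outer(v - S @ k, k)
# S @ k now equals v_new exactly: the stale binding was corrected.

# Plain additive linear attention: S += v k^T accumulates blindly.
S_add = np.zeros((d, d))
for v in (v_old, v_new):
    S_add += np.outer(v, k)
# S_add @ k equals v_old + v_new: the old association still leaks in.
```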
Equations
q_t = W_q x_t # Query projection
k_t = W_k x_t # Key projection (L2 normalized)
v_t = W_v x_t # Value projection
beta_t = sigmoid(W_beta x_t) # Update gate
S_t = S_{t-1} + beta_t * (v_t - S_{t-1} k_t) * k_t^T # Delta rule
o_t = S_t q_t # Output retrieval

Architecture
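The recurrence above can be sketched end-to-end in NumPy (a reference implementation of the math only, not the Axon code; the projection weights here are random placeholders):

```python
import numpy as np

rng = np.random.default_rng(0)
d, T = 8, 5
W_q, W_k, W_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
w_beta = rng.standard_normal(d) / np.sqrt(d)
xs = rng.standard_normal((T, d))  # one d-dim input per timestep

S = np.zeros((d, d))  # associative memory state
outputs = []
for x in xs:
    q, k, v = W_q @ x, W_k @ x, W_v @ x
    k = k / (np.linalg.norm(k) + 1e-6)            # L2-normalize the key
    beta = 1.0 / (1.0 + np.exp(-(w_beta @ x)))    # sigmoid update gate
    S = S + beta * np.outer(v - S @ k, k)         # delta rule update
    outputs.append(S @ q)                         # output retrieval
```

Note that the state S stays a fixed d-by-d matrix regardless of sequence length, which is the source of the linear time complexity.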
Input [batch, seq_len, embed_dim]
|
v
[Input Projection] -> hidden_size
|
v
+----------------------------------+
| DeltaNet Layer |
| Project to Q, K, V, beta |
| For each timestep: |
| error = v - S @ k |
| S += beta * error * k^T |
| output = S @ q |
+----------------------------------+
| (repeat num_layers)
v
[Layer Norm] -> [Last Timestep]
|
v
Output [batch, hidden_size]

Usage
model = DeltaNet.build(
embed_dim: 287,
hidden_size: 256,
num_layers: 4,
dropout: 0.1
)

References
- Paper: https://arxiv.org/abs/2102.11174
- Delta rule RNNs: https://arxiv.org/abs/2310.01655
Summary

Functions
- Build a DeltaNet model for sequence processing.
- Build a single DeltaNet block that can be used as a backbone layer in hybrid architectures.
- Default dropout rate
- Default hidden dimension
- Default number of attention heads
- Default number of layers
- Epsilon for normalization
- Get the output size of a DeltaNet model.
Types
@type build_opt() :: {:dropout, float()} | {:embed_dim, pos_integer()} | {:hidden_size, pos_integer()} | {:num_heads, pos_integer()} | {:num_layers, pos_integer()} | {:seq_len, pos_integer()} | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a DeltaNet model for sequence processing.
Options
- :embed_dim - Size of input embedding per frame (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :num_heads - Number of independent delta rule heads (default: 4)
- :num_layers - Number of DeltaNet layers (default: 4)
- :dropout - Dropout rate between layers (default: 0.1)
- :window_size - Expected sequence length (default: 60)
Returns
An Axon model that processes sequences and outputs the last hidden state.
Build a single DeltaNet block that can be used as a backbone layer in hybrid architectures.
Takes input of shape [batch, seq_len, hidden_size] and returns the same shape. Includes pre-norm and residual connection.
Options
- :hidden_size - Hidden dimension (default: 256)
- :num_heads - Number of heads (default: 4)
- :name - Layer name prefix (default: "delta_net_block")
@spec default_dropout() :: float()
Default dropout rate
@spec default_num_heads() :: pos_integer()
Default number of attention heads
@spec default_num_layers() :: pos_integer()
Default number of layers
@spec norm_eps() :: float()
Epsilon for normalization
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a DeltaNet model.