Edifice.Recurrent.DeltaNet (Edifice v0.2.0)


DeltaNet - Linear Attention with Delta Rule.

Implements linear attention with the delta rule update from "Linear Transformers Are Secretly Fast Weight Programmers" (Schlag et al., 2021) and subsequent work.

DeltaNet maintains an associative memory matrix S that is updated using the delta rule, which corrects previous associations rather than blindly accumulating them. This gives it superior retrieval accuracy compared to standard linear attention.

Key Innovations

  • Delta rule update: S_t = S_{t-1} + beta_t * (v_t - S_{t-1} k_t) k_t^T
  • Error-correcting: Subtracts the current retrieval S_{t-1} k_t before adding
  • Learnable beta: Controls update rate per-token via a gate
  • Linear complexity: constant O(d^2) state vs. an O(n*d) KV cache for softmax attention
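The error-correcting property can be illustrated numerically: after writing two values under the same key, plain linear attention retrieves a blend of both, while the delta rule retrieves only the latest. A minimal NumPy sketch (names are illustrative, not from this library):

```python
import numpy as np

rng = np.random.default_rng(1)
d = 8
k = rng.normal(size=d)
k /= np.linalg.norm(k)                       # keys are L2-normalized
v1, v2 = rng.normal(size=d), rng.normal(size=d)

# Plain linear attention accumulates blindly: stale v1 never leaves.
S_lin = np.outer(v1, k) + np.outer(v2, k)

# Delta rule subtracts the current retrieval before writing the new value.
S_dn = np.outer(v1, k)                       # store v1 under k
S_dn += 1.0 * np.outer(v2 - S_dn @ k, k)     # beta = 1: overwrite with v2

print(np.allclose(S_lin @ k, v1 + v2))       # True: blended retrieval
print(np.allclose(S_dn @ k, v2))             # True: clean overwrite
```

This is exactly the "superior retrieval accuracy" claim above: the correction term `v2 - S_dn @ k` removes the stale association before the new one is added.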

Equations

q_t = W_q x_t                          # Query projection
k_t = W_k x_t                          # Key projection (L2 normalized)
v_t = W_v x_t                          # Value projection
beta_t = sigmoid(W_beta x_t)           # Update gate
S_t = S_{t-1} + beta_t * (v_t - S_{t-1} k_t) * k_t^T   # Delta rule
o_t = S_t q_t                          # Output retrieval
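The recurrence above can be sketched as a single update step in NumPy (an illustrative sketch, not the library's implementation). With beta = 1 and an L2-normalized key, a fresh memory stores the association exactly:

```python
import numpy as np

def delta_net_step(S, k, v, beta):
    """One delta-rule update: correct the stored association for key k."""
    k = k / np.linalg.norm(k)           # keys are L2-normalized
    retrieved = S @ k                   # current prediction S_{t-1} k_t
    error = v - retrieved               # what the memory currently gets wrong
    return S + beta * np.outer(error, k)  # rank-1 correction

d = 8
rng = np.random.default_rng(0)
S = np.zeros((d, d))
k = rng.normal(size=d)
v = rng.normal(size=d)

S = delta_net_step(S, k, v, beta=1.0)
o = S @ (k / np.linalg.norm(k))         # query with the same normalized key
print(np.allclose(o, v))                # True: exact retrieval
```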

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
[Input Projection] -> hidden_size
      |
      v
+----------------------------------+
|      DeltaNet Layer              |
|  Project to Q, K, V, beta        |
|  For each timestep:              |
|    error = v - S @ k             |
|    S += beta * error * k^T       |
|    output = S @ q                |
+----------------------------------+
      | (repeat num_layers)
      v
[Layer Norm] -> [Last Timestep]
      |
      v
Output [batch, hidden_size]
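The per-timestep loop in the diagram can be sketched as a sequential scan over one sequence (illustrative NumPy, single head, random weights; not the Axon implementation):

```python
import numpy as np

def delta_net_layer(X, Wq, Wk, Wv, Wb):
    """Scan the delta rule over a [seq_len, d_in] sequence; returns [seq_len, d]."""
    d, d_k = Wv.shape[0], Wk.shape[0]
    S = np.zeros((d, d_k))                     # associative memory matrix
    outs = []
    for x in X:                                # one timestep at a time
        q, k, v = Wq @ x, Wk @ x, Wv @ x       # project to Q, K, V
        k /= np.linalg.norm(k)                 # L2-normalize the key
        beta = 1.0 / (1.0 + np.exp(-(Wb @ x))) # sigmoid update gate
        S = S + beta * np.outer(v - S @ k, k)  # delta rule update
        outs.append(S @ q)                     # o_t = S_t q_t
    return np.stack(outs)

rng = np.random.default_rng(2)
seq_len, d_in, d = 5, 6, 4
X = rng.normal(size=(seq_len, d_in))
Wq, Wk, Wv = (rng.normal(size=(d, d_in)) for _ in range(3))
Wb = rng.normal(size=d_in)

Y = delta_net_layer(X, Wq, Wk, Wv, Wb)
last_hidden = Y[-1]                            # the model returns the last timestep
print(Y.shape)                                 # (5, 4)
```

The actual model normalizes, batches, and repeats this per head and per layer, but the recurrence per timestep is the one shown.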

Usage

model = DeltaNet.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 4,
  dropout: 0.1
)

References

Summary

Types

build_opt()
Options for build/1.

Functions

build(opts \\ [])
Build a DeltaNet model for sequence processing.

build_block(input, opts \\ [])
Build a single DeltaNet block that can be used as a backbone layer in hybrid architectures.

default_dropout()
Default dropout rate

default_hidden_size()
Default hidden dimension

default_num_heads()
Default number of attention heads

default_num_layers()
Default number of layers

norm_eps()
Epsilon for normalization

output_size(opts \\ [])
Get the output size of a DeltaNet model.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a DeltaNet model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_heads - Number of independent delta rule heads (default: 4)
  • :num_layers - Number of DeltaNet layers (default: 4)
  • :dropout - Dropout rate between layers (default: 0.1)
  • :window_size - Expected sequence length (default: 60)

Returns

An Axon model that processes sequences and outputs the last hidden state.

build_block(input, opts \\ [])

@spec build_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single DeltaNet block that can be used as a backbone layer in hybrid architectures.

Takes input of shape [batch, seq_len, hidden_size] and returns the same shape. Includes pre-norm and residual connection.
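The shape-preserving pre-norm residual pattern can be sketched as follows (illustrative NumPy; `mixer` is a stand-in for the DeltaNet layer, not this library's code):

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize over the feature axis (last dimension).
    mu = x.mean(axis=-1, keepdims=True)
    sd = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sd + eps)

def block(x, mixer):
    """Pre-norm residual: normalize first, mix, then add the input back."""
    return x + mixer(layer_norm(x))

x = np.ones((2, 3, 4))                  # [batch, seq_len, hidden_size]
y = block(x, lambda h: 0.5 * h)         # hypothetical stand-in mixer
print(y.shape == x.shape)               # True: output shape matches input
```

Because the mixer's output is added back onto `x`, the block is a drop-in layer: any stack of such blocks keeps `[batch, seq_len, hidden_size]` end to end.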

Options

  • :hidden_size - Hidden dimension (default: 256)
  • :num_heads - Number of heads (default: 4)
  • :name - Layer name prefix (default: "delta_net_block")

default_dropout()

@spec default_dropout() :: float()

Default dropout rate

default_hidden_size()

@spec default_hidden_size() :: pos_integer()

Default hidden dimension

default_num_heads()

@spec default_num_heads() :: pos_integer()

Default number of attention heads

default_num_layers()

@spec default_num_layers() :: pos_integer()

Default number of layers

norm_eps()

@spec norm_eps() :: float()

Epsilon for normalization

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a DeltaNet model.