# `Edifice.Recurrent.DeltaNet`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/recurrent/delta_net.ex#L1)

DeltaNet - Linear Attention with Delta Rule.

Implements linear attention with the delta rule update from
"Linear Transformers Are Secretly Fast Weight Programmers"
(Schlag et al., 2021) and subsequent work.

DeltaNet maintains an associative memory matrix S that is updated
using the delta rule, which corrects previous associations rather
than blindly accumulating them. This gives it superior retrieval
accuracy compared to standard linear attention.

## Key Innovations

- **Delta rule update**: S_t = S_{t-1} + beta_t * (v_t - S_{t-1} k_t) k_t^T
- **Error-correcting**: Subtracts the current retrieval S_{t-1} k_t before adding
- **Learnable beta**: Controls update rate per-token via a gate
- **Linear complexity**: O(d^2) recurrent state, independent of sequence length, vs the O(n*d) key/value cache of softmax attention

## Equations

```
q_t = W_q x_t                          # Query projection
k_t = W_k x_t                          # Key projection (L2 normalized)
v_t = W_v x_t                          # Value projection
beta_t = sigmoid(W_beta x_t)           # Update gate
S_t = S_{t-1} + beta_t * (v_t - S_{t-1} k_t) * k_t^T   # Delta rule
o_t = S_t q_t                          # Output retrieval
```
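
For a concrete sense of one timestep, here is a minimal sketch of the delta-rule
update written directly in Nx. It is illustrative only (single head, scalar beta,
toy 3-dimensional vectors), not the module's internal implementation:

```elixir
# One delta-rule timestep, written out in plain Nx for illustration.
key = Nx.tensor([1.0, 0.0, 1.0])
key = Nx.divide(key, Nx.LinAlg.norm(key))            # L2-normalize k_t
value = Nx.tensor([0.5, -1.0, 2.0])
query = Nx.tensor([0.9, 0.1, 1.1])
beta = 0.8                                            # update gate in (0, 1)

s_prev = Nx.broadcast(0.0, {3, 3})                    # memory S_{t-1}

retrieved = Nx.dot(s_prev, key)                       # S_{t-1} k_t
error = Nx.subtract(value, retrieved)                 # v_t - S_{t-1} k_t
update = Nx.multiply(beta, Nx.outer(error, key))      # beta_t * error * k_t^T
s_t = Nx.add(s_prev, update)                          # delta-rule update
output = Nx.dot(s_t, query)                           # o_t = S_t q_t
```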

## Architecture

```
Input [batch, seq_len, embed_dim]
      |
      v
[Input Projection] -> hidden_size
      |
      v
+----------------------------------+
|      DeltaNet Layer              |
|  Project to Q, K, V, beta        |
|  For each timestep:              |
|    error = v - S @ k             |
|    S += beta * error * k^T       |
|    output = S @ q                |
+----------------------------------+
      | (repeat num_layers)
      v
[Layer Norm] -> [Last Timestep]
      |
      v
Output [batch, hidden_size]
```

## Usage

    model = DeltaNet.build(
      embed_dim: 287,
      hidden_size: 256,
      num_layers: 4,
      dropout: 0.1
    )
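
A model built this way plugs into the usual Axon workflow. A rough sketch, with
the batch size and sequence length chosen arbitrarily here:

```elixir
model =
  DeltaNet.build(
    embed_dim: 287,
    hidden_size: 256,
    num_layers: 4,
    dropout: 0.1
  )

# Compile, initialize, and run with standard Axon calls.
{init_fn, predict_fn} = Axon.build(model)

input = Nx.broadcast(0.0, {8, 60, 287})                  # [batch, seq_len, embed_dim]
params = init_fn.(Nx.template({8, 60, 287}, :f32), %{})
output = predict_fn.(params, input)                      # [8, 256] per the docs above
```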

## References
- Paper: "Linear Transformers Are Secretly Fast Weight Programmers" (Schlag et al., 2021): https://arxiv.org/abs/2102.11174
- Delta rule RNNs: https://arxiv.org/abs/2310.01655

# `build_opt`

```elixir
@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:window_size, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a DeltaNet model for sequence processing.

## Options
  - `:embed_dim` - Size of input embedding per frame (required)
  - `:hidden_size` - Internal hidden dimension (default: 256)
  - `:num_heads` - Number of independent delta rule heads (default: 4)
  - `:num_layers` - Number of DeltaNet layers (default: 4)
  - `:dropout` - Dropout rate between layers (default: 0.1)
  - `:window_size` - Expected sequence length (default: 60)

## Returns
  An Axon model that processes sequences and outputs the last hidden state.

# `build_block`

```elixir
@spec build_block(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a single DeltaNet block that can be used as a backbone layer
in hybrid architectures.

Takes input of shape [batch, seq_len, hidden_size] and returns the same shape.
Includes pre-norm and residual connection.

## Options
  - `:hidden_size` - Hidden dimension (default: 256)
  - `:num_heads` - Number of heads (default: 4)
  - `:name` - Layer name prefix (default: "delta_net_block")
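
A rough sketch of dropping the block into a hybrid backbone (the surrounding
layers and names here are illustrative, not prescribed by the library):

```elixir
alias Edifice.Recurrent.DeltaNet

# Hypothetical hybrid stack: a DeltaNet block followed by any other
# [batch, seq_len, hidden] -> [batch, seq_len, hidden] Axon layer.
backbone =
  Axon.input("frames", shape: {nil, 60, 256})
  |> DeltaNet.build_block(hidden_size: 256, num_heads: 4, name: "delta_net_block_0")
  |> Axon.dense(256, name: "mixer_0")
```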

# `default_dropout`

```elixir
@spec default_dropout() :: float()
```

Default dropout rate

# `default_hidden_size`

```elixir
@spec default_hidden_size() :: pos_integer()
```

Default hidden dimension

# `default_num_heads`

```elixir
@spec default_num_heads() :: pos_integer()
```

Default number of delta-rule heads

# `default_num_layers`

```elixir
@spec default_num_layers() :: pos_integer()
```

Default number of layers
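
The `default_*` accessors above are assumed to mirror the `build/1` defaults
documented earlier; a quick sketch:

```elixir
alias Edifice.Recurrent.DeltaNet

DeltaNet.default_hidden_size()  # => 256, per the build/1 docs
DeltaNet.default_num_heads()    # => 4
DeltaNet.default_num_layers()   # => 4
DeltaNet.default_dropout()      # => 0.1
```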

# `norm_eps`

```elixir
@spec norm_eps() :: float()
```

Epsilon for normalization

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of a DeltaNet model.
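
Presumably this is the configured hidden size, since `build/1` outputs the last
hidden state of shape [batch, hidden_size]; a hedged example:

```elixir
# Assumed to return the configured hidden size (hidden_size defaults to 256).
Edifice.Recurrent.DeltaNet.output_size(hidden_size: 256)
# => 256
```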

---

*Consult [api-reference.md](api-reference.md) for the complete listing*
