# `Edifice.Attention.RetNet`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/attention/retnet.ex#L1)

RetNet: Retentive Network - A Successor to Transformer.

Implements the RetNet architecture from "Retentive Network: A Successor to
Transformer for Large Language Models" (Sun et al., Microsoft 2023).

## Key Innovation: Retention Mechanism

RetNet replaces attention with "retention" - a decay-based mechanism:

```
Parallel:   Y = ((Q . Theta) . (K . Theta)^T ⊙ D) . V     (⊙ = elementwise)
Recurrent:  s_n = gamma*s_{n-1} + K_n^T*V_n; o_n = Q_n*s_n
```

Where D is a decay matrix: D[n,m] = gamma^(n-m) if n>=m, else 0.
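
A minimal Nx sketch of this decay matrix (the `RetNetDemo` module and its `decay_matrix/2` helper are illustrative, not part of this library):

```elixir
defmodule RetNetDemo do
  # D[n, m] = gamma^(n - m) for n >= m, else 0
  def decay_matrix(seq_len, gamma) do
    n = Nx.iota({seq_len, 1})
    m = Nx.iota({1, seq_len})
    diff = Nx.subtract(n, m)

    gamma
    |> Nx.pow(diff)
    |> Nx.multiply(Nx.greater_equal(diff, 0))
  end
end

# RetNetDemo.decay_matrix(4, 0.9)
# => lower-triangular matrix of powers of 0.9
```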

## Triple Paradigm

The same weights support three computation modes:
- **Parallel**: Training mode, O(L^2) but GPU-parallel
- **Recurrent**: Inference mode, O(1) per token
- **Chunkwise**: Long sequences, O(L) with chunking
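
The parallel and recurrent forms compute the same outputs from the same weights. A single-head Nx sketch of that equivalence (illustrative only; it ignores the Theta rotation and reuses `decay_matrix/2` from the sketch above):

```elixir
defmodule RetentionDemo do
  # Parallel form: Y = (Q K^T ⊙ D) V
  def parallel(q, k, v, gamma) do
    {len, _} = Nx.shape(q)
    d = RetNetDemo.decay_matrix(len, gamma)

    q
    |> Nx.dot(Nx.transpose(k))   # [len, len] scores
    |> Nx.multiply(d)            # causal decay mask
    |> Nx.dot(v)                 # [len, d_v]
  end

  # Recurrent form: s_n = gamma * s_{n-1} + K_n^T V_n;  o_n = Q_n s_n
  def recurrent(q, k, v, gamma) do
    {len, d_k} = Nx.shape(q)
    {_, d_v} = Nx.shape(v)
    init = Nx.broadcast(0.0, {d_k, d_v})

    {outs, _state} =
      Enum.map_reduce(0..(len - 1), init, fn t, state ->
        k_t = Nx.reshape(k[t], {1, d_k})
        v_t = Nx.reshape(v[t], {1, d_v})
        state = Nx.add(Nx.multiply(gamma, state), Nx.dot(Nx.transpose(k_t), v_t))
        {Nx.dot(Nx.reshape(q[t], {1, d_k}), state), state}
      end)

    Nx.concatenate(outs)
  end
end

# With q, k, v of shape {seq_len, dim}, the two agree up to
# floating-point error:
#
#   RetentionDemo.parallel(q, k, v, 0.96875)
#   RetentionDemo.recurrent(q, k, v, 0.96875)
```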

## Multi-Scale Retention (MSR)

Different heads use different decay rates for multi-scale modeling:
- gamma_h = 1 - 2^(-5-h) for head h
- GroupNorm instead of LayerNorm (handles different head variances)
- SiLU gating: Y = (SiLU(X*W_G) ⊙ Retention(X)) * W_O
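
The resulting decay schedule for four heads (illustrative):

```elixir
# gamma_h = 1 - 2^(-5 - h) for h = 0..3
gammas = for h <- 0..3, do: 1.0 - :math.pow(2.0, -5 - h)
# => [0.96875, 0.984375, 0.9921875, 0.99609375]
```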

## Architecture

```
Input [batch, seq_len, embed_dim]
      |
      v
+--------------------------------+
|          RetNet Block          |
|  LayerNorm -> MSR -> Residual  |
|  LayerNorm -> FFN -> Residual  |
+--------------------------------+
      | (repeat for num_layers)
      v
Output [batch, hidden_size]
```

## Usage

    # Build RetNet backbone
    model = RetNet.build(
      embed_dim: 287,
      hidden_size: 256,
      num_layers: 6,
      num_heads: 4
    )
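
A sketch of the rest of the workflow with the standard Axon build/predict functions (batch size 8 and the default window of 60 timesteps are illustrative; the input follows the shape in the Architecture section):

    {init_fn, predict_fn} = Axon.build(model)
    params = init_fn.(Nx.template({8, 60, 287}, :f32), %{})

    input = Nx.broadcast(0.0, {8, 60, 287})
    output = predict_fn.(params, input)
    # expected shape: {8, 256} (last hidden state)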

## Comparison

| Mode | Time | Memory | Best For |
|------|------|--------|----------|
| Parallel | O(L^2) | O(L^2) | Training |
| Recurrent | O(1) per token | O(1) | Inference |
| Chunkwise | O(L) | O(C), C = chunk size | Long sequences |

## References
- Paper: https://arxiv.org/abs/2307.08621
- Code: https://github.com/microsoft/unilm/tree/master/retnet

# `build_opt`

```elixir
@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a RetNet model for sequence processing.

## Options
  - `:embed_dim` - Size of input embedding per timestep (required)
  - `:hidden_size` - Internal hidden dimension (default: 256)
  - `:num_layers` - Number of RetNet blocks (default: 6)
  - `:num_heads` - Number of retention heads (default: 4)
  - `:expand_factor` - FFN expansion factor (default: 2)
  - `:dropout` - Dropout rate (default: 0.0)
  - `:window_size` - Expected sequence length (default: 60)
  - `:mode` - Computation mode: `:parallel`, `:recurrent`, or `:chunkwise` (default: `:parallel`)

## Returns
  An Axon model that processes sequences and outputs the last hidden state.
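
For example, building with the recurrent mode for token-by-token inference (a sketch; other options take their defaults):

    model = RetNet.build(embed_dim: 287, mode: :recurrent)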

# `build_multi_scale_retention`

```elixir
@spec build_multi_scale_retention(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build Multi-Scale Retention layer.

MSR uses different decay rates (gamma) per head for multi-scale modeling:
- gamma_h = 1 - 2^(-5-h) for head h
- SiLU gating: Y = (SiLU(X*W_G) ⊙ Retention(X)) * W_O
- GroupNorm for handling different head variances

# `build_retnet_block`

```elixir
@spec build_retnet_block(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a single RetNet block.

RetNet block structure:
1. LayerNorm -> Multi-Scale Retention -> Residual
2. LayerNorm -> FFN -> Residual
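
An illustrative Axon sketch of this residual pattern, with plain dense layers standing in for the MSR and FFN sub-layers this function actually builds:

```elixir
# Stand-in wiring only: the dense layers and :gelu activation here are
# illustrative; the real block uses Multi-Scale Retention and the
# module's own FFN.
x = Axon.input("state", shape: {nil, 60, 256})

msr = x |> Axon.layer_norm() |> Axon.dense(256)
x = Axon.add(x, msr)                          # residual around MSR

ffn =
  x
  |> Axon.layer_norm()
  |> Axon.dense(512, activation: :gelu)
  |> Axon.dense(256)

block = Axon.add(x, ffn)                      # residual around FFN
```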

# `default_dropout`

```elixir
@spec default_dropout() :: float()
```

Default dropout rate

# `default_expand_factor`

```elixir
@spec default_expand_factor() :: pos_integer()
```

Default feedforward expansion factor

# `default_hidden_size`

```elixir
@spec default_hidden_size() :: pos_integer()
```

Default hidden dimension

# `default_num_heads`

```elixir
@spec default_num_heads() :: pos_integer()
```

Default number of retention heads

# `default_num_layers`

```elixir
@spec default_num_layers() :: pos_integer()
```

Default number of layers

# `eps`

```elixir
@spec eps() :: float()
```

Epsilon for numerical stability

# `init_retention_state`

```elixir
@spec init_retention_state(pos_integer(), pos_integer(), pos_integer()) ::
  Nx.Tensor.t()
```

Initialize retention state for recurrent inference.

Returns a zero-initialized state tensor of shape [batch, heads, head_dim, head_dim].
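
For example (values illustrative):

    # batch = 8, heads = 4, head_dim = 64
    state = RetNet.init_retention_state(8, 4, 64)
    Nx.shape(state)
    # => {8, 4, 64, 64}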

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of a RetNet model.

# `param_count`

```elixir
@spec param_count(keyword()) :: non_neg_integer()
```

Calculate approximate parameter count for a RetNet model.

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Recommended default configuration for sequence processing.
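
For example (a sketch; this assumes the returned keyword list leaves the required `:embed_dim` to the caller):

    opts = Keyword.merge(RetNet.recommended_defaults(), embed_dim: 287)
    model = RetNet.build(opts)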

# `recurrent_retention_step`

```elixir
@spec recurrent_retention_step(
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  Nx.Tensor.t()
) :: {Nx.Tensor.t(), Nx.Tensor.t()}
```

Build recurrent retention state update.

Recurrent formulation for O(1) inference:
- s_n = gamma * s_{n-1} + K_n^T * V_n
- o_n = Q_n * s_n

This is used during inference when processing one token at a time.
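
A self-contained single-head Nx sketch of one such step (it does not call `recurrent_retention_step/5` itself, whose exact argument order is not shown here):

```elixir
gamma = 0.96875
q_n = Nx.broadcast(0.1, {1, 64})
k_n = Nx.broadcast(0.2, {1, 64})
v_n = Nx.broadcast(0.3, {1, 64})
s_prev = Nx.broadcast(0.0, {64, 64})

# s_n = gamma * s_{n-1} + K_n^T * V_n
s_n = Nx.add(Nx.multiply(gamma, s_prev), Nx.dot(Nx.transpose(k_n), v_n))

# o_n = Q_n * s_n
o_n = Nx.dot(q_n, s_n)   # shape {1, 64}
```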

---

*Consult [api-reference.md](api-reference.md) for complete listing*
