# `Edifice.Attention.RWKV`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/attention/rwkv.ex#L1)

RWKV-7 "Goose": Linear attention with a constant-size (O(1)) inference state.

RWKV (Receptance Weighted Key Value) is a linear attention architecture that
combines the parallelizable training of Transformers with the efficient O(1)
inference of RNNs.

## Key Innovation: Generalized Delta Rule

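RWKV-7 updates its recurrent state with a generalized delta rule. The RWKV-7
authors argue that this lets the model express state-tracking computations
beyond TC0, the class that bounds standard Transformers under common complexity
conjectures, and they report results competitive with or better than similarly
sized Transformers on many benchmarks.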

## Architecture

```
Input [batch, seq_len, embed_dim]
      |
      v
+----------------------------------------+
|  RWKV Block                            |
|                                        |
|  +----------------------------------+  |
|  | Time-Mixing (WKV attention)      |  |
|  | - R: receptance gate             |  |
|  | - W: time decay                  |  |
|  | - K, V: key-value pairs          |  |
|  | - u (time_first): bonus for the  |  |
|  |   current token                  |  |
|  +----------------------------------+  |
|                                        |
|  +----------------------------------+  |
|  | Channel-Mixing (gated FFN)       |  |
|  | - sigmoid(R) * W_v(relu(K)^2)    |  |
|  +----------------------------------+  |
+----------------------------------------+
      | (repeat for num_layers)
      v
[batch, hidden_size] (last position)
```

## Complexity

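| Phase | Time | Space |
|-------|------|-------|
| Training | O(L) | O(L) |
| Inference | O(1) per token | O(1) |

L is the sequence length. At inference the state has a fixed size, independent of how many tokens have been processed.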

## Key Differences from Mamba

| Aspect | RWKV | Mamba |
|--------|------|-------|
| Sequence mixing | WKV (weighted key-value) | Selective SSM (state space) |
| Inference state | Fixed size, O(1) in sequence length | Also fixed size, O(1) in sequence length |
| Decay | Learned per-channel | Input-dependent |
| Gating | Sigmoid receptance (R) gate | SiLU gating |

## Usage

    alias Edifice.Attention.RWKV

    model = RWKV.build(
      embed_dim: 287,
      hidden_size: 256,
      num_layers: 6
    )

## References

- RWKV-7 "Goose" architecture wiki: https://wiki.rwkv.com/basic/architecture.html
- Original RWKV paper: "RWKV: Reinventing RNNs for the Transformer Era" (arXiv:2305.13048)
- Deployment: the RWKV project reports that RWKV ships in Windows (about 1.5B devices) for on-device Copilot features

# `build_opt`

```elixir
@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:head_size, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:window_size, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build an RWKV-7 model for sequence processing.

## Options

  - `:embed_dim` - Size of input embedding per timestep (required)
  - `:hidden_size` - Internal hidden dimension (default: 256)
  - `:num_layers` - Number of RWKV blocks (default: 6)
  - `:head_size` - Size per attention head (default: 64)
  - `:dropout` - Dropout rate (default: 0.1)
  - `:window_size` - Expected sequence length for JIT optimization (default: 60)

## Returns

  An Axon model that outputs [batch, hidden_size] from the last position.

# `build_channel_mixing`

```elixir
@spec build_channel_mixing(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build the Channel-Mixing sub-block (FFN with gating).

Channel-mixing uses a gated FFN structure (`*` is elementwise):
```
r = W_r · x
k = relu(W_k · x)^2
v = W_v · k
output = sigmoid(r) * v
```

Where:
- r: receptance gate
- k: key (squared-ReLU activation)
- v: value projection of the activated key
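For intuition, channel mixing can be sketched in plain Elixir on a single token vector, with lists standing in for Nx tensors. This is a hypothetical sketch, not the Edifice implementation: it assumes the reference RWKV layout (squared-ReLU key, value projection of the activated key, sigmoid receptance gate), ignores any input-mixing or bias terms the real layer may use, and the names `ChannelMixSketch`, `w_r`, `w_k`, and `w_v` are illustrative.

```elixir
# Hypothetical sketch of RWKV-style channel mixing for one token vector.
# Plain lists stand in for Nx tensors; matrices are lists of rows.
# Module and weight names are illustrative, not the Edifice API.
defmodule ChannelMixSketch do
  # Matrix (list of rows) times vector.
  def matvec(m, x), do: Enum.map(m, fn row -> dot(row, x) end)

  defp dot(a, b), do: a |> Enum.zip_with(b, &(&1 * &2)) |> Enum.sum()

  defp sigmoid(z), do: 1.0 / (1.0 + :math.exp(-z))

  # Squared ReLU: max(z, 0)^2.
  defp relu_sq(z), do: :math.pow(max(z, 0.0), 2)

  # Shapes (rows x columns): w_r is hidden x hidden,
  # w_k is ffn x hidden, w_v is hidden x ffn.
  def forward(x, w_r, w_k, w_v) do
    # Receptance gate in (0, 1), one value per hidden channel.
    r = w_r |> matvec(x) |> Enum.map(&sigmoid/1)
    # Key projection with squared-ReLU activation (ffn-sized).
    k = w_k |> matvec(x) |> Enum.map(&relu_sq/1)
    # Value projection maps the activated key back to hidden size.
    v = matvec(w_v, k)
    # Gate the value elementwise.
    Enum.zip_with(r, v, &(&1 * &2))
  end
end
```

With 2x2 identity weights, `ChannelMixSketch.forward([1.0, -2.0], id, id, id)` gives roughly `[0.731, 0.0]`: the negative channel is zeroed by the ReLU before the gate is applied.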

# `build_rwkv_block`

```elixir
@spec build_rwkv_block(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a single RWKV block.

Each block has two sub-blocks:
1. Time-mixing: WKV attention mechanism
2. Channel-mixing: Feed-forward with gating
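In the reference RWKV code the two sub-blocks are composed with pre-layer-norm residual connections. The sketch below assumes that layout (it is not taken from the Edifice source), represents a sequence as a list of token vectors, and uses a layer norm without learned scale or shift for brevity; `BlockSketch` and its arguments are hypothetical.

```elixir
# Hypothetical sketch: composing one RWKV block from its two sub-blocks,
# assuming pre-norm residual connections. A sequence is a list of vectors.
defmodule BlockSketch do
  # Layer norm over one vector, without learned scale/shift.
  def layer_norm(x, eps \\ 1.0e-5) do
    n = length(x)
    mean = Enum.sum(x) / n
    var = Enum.sum(Enum.map(x, fn v -> (v - mean) * (v - mean) end)) / n
    Enum.map(x, fn v -> (v - mean) / :math.sqrt(var + eps) end)
  end

  # time_mix: sequence -> sequence (mixes information across positions).
  # channel_mix: vector -> vector (applied to each position independently).
  def forward(seq, time_mix, channel_mix) do
    # Sub-block 1: residual around time mixing of the normalized sequence.
    seq = add_seq(seq, time_mix.(Enum.map(seq, &layer_norm/1)))
    # Sub-block 2: residual around per-position channel mixing.
    add_seq(seq, Enum.map(seq, fn x -> channel_mix.(layer_norm(x)) end))
  end

  # Elementwise addition of two sequences of vectors.
  defp add_seq(a, b) do
    Enum.zip_with(a, b, fn x, y -> Enum.zip_with(x, y, &(&1 + &2)) end)
  end
end
```

If both sub-blocks return zeros, `forward/3` returns its input unchanged, which is the identity path the residual connections provide.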

# `build_time_mixing`

```elixir
@spec build_time_mixing(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build the Time-Mixing sub-block (WKV attention).

Time-mixing implements the WKV (Weighted Key-Value) attention mechanism:

```
wkv[t] = (sum_{i<t} exp(w*(t-1-i) + k[i]) * v[i] + exp(u + k[t]) * v[t]) /
         (sum_{i<t} exp(w*(t-1-i) + k[i]) + exp(u + k[t]))
```

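Where:
- w: learned time decay (per head). The term is exp(w*(t-1-i)), so w must be negative for older tokens to be down-weighted; the RWKV paper writes the same factor as exp(-(t-1-i)*w) with w > 0
- u: learned "time_first" bonus, applied only to the current token
- k, v: keys and values projected from the input
- r: receptance, projected from the input and used as a gate

The sub-block output is `sigmoid(r) * wkv`, computed elementwise.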
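The WKV formula can be checked by evaluating it directly for one channel. The sketch below is a plain-Elixir O(L²) reference written for clarity, not how the Edifice layer computes it. It treats `w` and `u` as scalars, takes `w` exactly as it appears in the formula (so a negative value produces decay), and uses 0-based positions, which leaves the distance `t - 1 - i` unchanged. `WKVSketch` is a hypothetical name.

```elixir
# Hypothetical reference: direct O(L^2) evaluation of the WKV formula for one
# channel, followed by the sigmoid(r) output gate. Plain floats, no Nx.
defmodule WKVSketch do
  def wkv(ks, vs, w, u) do
    kv = Enum.zip(ks, vs)

    kv
    |> Enum.with_index()
    |> Enum.map(fn {{k_t, v_t}, t} ->
      # Decayed sums over strictly earlier positions i < t.
      {num, den} =
        kv
        |> Enum.take(t)
        |> Enum.with_index()
        |> Enum.reduce({0.0, 0.0}, fn {{k_i, v_i}, i}, {n, d} ->
          weight = :math.exp(w * (t - 1 - i) + k_i)
          {n + weight * v_i, d + weight}
        end)

      # The current token gets the u ("time_first") bonus instead of decay.
      bonus = :math.exp(u + k_t)
      (num + bonus * v_t) / (den + bonus)
    end)
  end

  # Output = sigmoid(r) * wkv, elementwise over positions.
  def output(rs, ks, vs, w, u) do
    rs
    |> Enum.zip_with(wkv(ks, vs, w, u), fn r, x -> x / (1.0 + :math.exp(-r)) end)
  end
end
```

With `w = 0.0` and `u = 0.0`, `WKVSketch.wkv([0.0, 0.0], [1.0, 3.0], 0.0, 0.0)` returns `[1.0, 2.0]`: the second output is the plain average of both values. The immediately preceding token has distance 0 and is not decayed; decay starts with the token before it.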

# `init_cache`

```elixir
@spec init_cache(keyword()) :: map()
```

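Initialize the hidden state for O(1) incremental inference.

This is RWKV's key advantage at inference time: the cache has a fixed size, so memory per step does not grow with the number of tokens already processed.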
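The constant-memory property comes from rewriting the WKV sums recurrently: a running numerator and denominator are decayed by `exp(w)` and extended with the current token at every step, so the state never grows with sequence length. The sketch below shows this for one channel. It is hypothetical (the map returned by `init_cache/1` may store different or additional fields), and unlike production RWKV kernels, which typically rescale by a running maximum, it can overflow `:math.exp/1` for large keys.

```elixir
# Hypothetical sketch of the recurrent (RNN-mode) WKV update for one channel.
# The state {a, b} holds the running numerator and denominator sums.
# Names are illustrative and not part of the Edifice API.
defmodule WKVRecurrentSketch do
  # Empty state: nothing has been seen yet.
  def init_state, do: {0.0, 0.0}

  # One step: O(1) time and memory, independent of the position t.
  def step({a, b}, k, v, w, u) do
    bonus = :math.exp(u + k)
    out = (a + bonus * v) / (b + bonus)

    # Decay the past by exp(w), then add the current token without the u bonus.
    decay = :math.exp(w)
    e_k = :math.exp(k)
    {out, {decay * a + e_k * v, decay * b + e_k}}
  end

  # Process a whole sequence step by step, threading the state through.
  def run(ks, vs, w, u) do
    {outs, _final_state} =
      ks
      |> Enum.zip(vs)
      |> Enum.map_reduce(init_state(), fn {k, v}, state -> step(state, k, v, w, u) end)

    outs
  end
end
```

On keys `[0.0, 0.0]` and values `[1.0, 3.0]` with `w = u = 0.0`, `run/4` returns `[1.0, 2.0]`, the same values the WKV formula gives when evaluated directly.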

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of an RWKV model.

# `param_count`

```elixir
@spec param_count(keyword()) :: non_neg_integer()
```

Calculate approximate parameter count for an RWKV model.

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Recommended default configuration for sequence processing.

