Edifice.Attention.RWKV (Edifice v0.2.0)


RWKV-7 "Goose": Linear attention with O(1) space complexity.

RWKV (Receptance Weighted Key Value) is a linear attention architecture that combines the parallelizable training of Transformers with the efficient O(1) inference of RNNs.

Key Innovation: Generalized Delta Rule

RWKV-7 uses a generalized delta rule for its recurrent state update, giving it expressive power beyond the TC0 constraint that bounds standard Transformers and letting it handle state-tracking tasks they cannot.
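
For intuition, the classic delta rule writes a value into an associative state with a rank-one correction toward the target; RWKV-7 generalizes that update with learned per-channel decay. The snippet below is only a background sketch of the classic rule in Nx (hypothetical names, not this module's implementation):

defmodule DeltaRuleSketch do
  # state: [d_v, d_k] associative memory, k: [d_k] key, v: [d_v] value, beta: scalar rate
  def update(state, k, v, beta) do
    prediction = Nx.dot(state, k)              # value the state currently recalls for k
    error = Nx.subtract(v, prediction)         # delta toward the target value
    Nx.add(state, Nx.multiply(beta, Nx.outer(error, k)))
  end
end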

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+--------------------------------------+
|  RWKV Block                          |
|                                      |
|  +--------------------------------+  |
|  | Time-Mixing (WKV Attention)    |  |
|  | - R-gate: receptance           |  |
|  | - W: time decay                |  |
|  | - K, V: key-value pairs        |  |
|  | - time_first: first token bias |  |
|  +--------------------------------+  |
|                                      |
|  +--------------------------------+  |
|  | Channel-Mixing (FFN)           |  |
|  | - R-gate * K-gate              |  |
|  +--------------------------------+  |
+--------------------------------------+
      | (repeat for num_layers)
      v
[batch, hidden_size]

Complexity

| Phase     | Time          | Space |
|-----------|---------------|-------|
| Training  | O(L)          | O(L)  |
| Inference | O(1) per step | O(1)  |

Key Difference from Mamba

| Aspect    | RWKV                     | Mamba                  |
|-----------|--------------------------|------------------------|
| Attention | WKV (weighted key-value) | SSM (state space)      |
| State     | O(1) fixed size          | O(L) for full sequence |
| Decay     | Learned per-channel      | Input-dependent        |
| Gating    | R-gate, K-gate           | SiLU gating            |

Usage

alias Edifice.Attention.RWKV

model = RWKV.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 6
)
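
To run the built graph, a minimal Axon sketch (the batch size, sequence length, and :f32 type below are assumptions for illustration, not requirements stated by this module):

{init_fn, predict_fn} = Axon.build(model)

# Dummy [batch, seq_len, embed_dim] input matching the options above.
input = Nx.broadcast(0.0, {1, 60, 287})
params = init_fn.(Nx.template({1, 60, 287}, :f32), %{})

output = predict_fn.(params, input)
# expected output shape: {1, 256}  ([batch, hidden_size])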

Summary

Types

build_opt()
Options for build/1.

Functions

build(opts \\ [])
Build an RWKV-7 model for sequence processing.

build_channel_mixing(input, opts)
Build the Channel-Mixing sub-block (FFN with gating).

build_rwkv_block(input, opts)
Build a single RWKV block.

build_time_mixing(input, opts)
Build the Time-Mixing sub-block (WKV attention).

init_cache(opts \\ [])
Initialize hidden state for O(1) incremental inference.

output_size(opts \\ [])
Get the output size of an RWKV model.

param_count(opts)
Calculate approximate parameter count for an RWKV model.

Recommended default configuration for sequence processing.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:head_size, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build an RWKV-7 model for sequence processing.

Options

  • :embed_dim - Size of input embedding per timestep (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_layers - Number of RWKV blocks (default: 6)
  • :head_size - Size per attention head (default: 64)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Expected sequence length for JIT optimization (default: 60)

Returns

An Axon model that outputs [batch, hidden_size] from the last position.

build_channel_mixing(input, opts)

@spec build_channel_mixing(
  Axon.t(),
  keyword()
) :: Axon.t()

Build the Channel-Mixing sub-block (FFN with gating).

Channel-mixing uses a gated FFN structure:

output = sigmoid(r) * (k * v)

Where:

  • r: receptance gate
  • k: key (square activation)
  • v: value projection
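
A rough Nx sketch of that gating, assuming the common RWKV formulation where the squared-ReLU key activation feeds the value projection (all names below are illustrative, not the module's parameter names):

# Illustrative only: r_proj and k_proj are learned projections of the input,
# w_value is the value projection matrix.
channel_mix = fn r_proj, k_proj, w_value ->
  k = Nx.max(k_proj, 0)
  k = Nx.multiply(k, k)                     # squared-ReLU key activation
  v = Nx.dot(k, w_value)                    # value projection of the activated key
  Nx.multiply(Nx.sigmoid(r_proj), v)        # receptance gate
end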

build_rwkv_block(input, opts)

@spec build_rwkv_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single RWKV block.

Each block has two sub-blocks:

  1. Time-mixing: WKV attention mechanism
  2. Channel-mixing: Feed-forward with gating
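
Conceptually the block wires together as in the sketch below; the pre-norm residual placement is an assumption about the usual RWKV layout, not a guarantee about this module's internals:

# Hypothetical wiring of one block (pre-norm residuals assumed).
x = Axon.add(input, build_time_mixing(Axon.layer_norm(input), opts))
x = Axon.add(x, build_channel_mixing(Axon.layer_norm(x), opts))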

build_time_mixing(input, opts)

@spec build_time_mixing(
  Axon.t(),
  keyword()
) :: Axon.t()

Build the Time-Mixing sub-block (WKV attention).

Time-mixing implements the WKV (Weighted Key-Value) attention mechanism:

wkv[t] = (sum_{i<t} exp(w*(t-1-i) + k[i]) * v[i] + exp(u + k[t]) * v[t]) /
         (sum_{i<t} exp(w*(t-1-i) + k[i]) + exp(u + k[t]))

Where:

  • w: learned time decay (per head)
  • u: learned "time_first" bias for current token
  • k, v: keys and values from input
  • r: receptance gate

Output = sigmoid(r) * wkv
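
A direct per-channel transcription of that recurrence in plain Elixir, with toy scalar values; a real implementation vectorizes this over heads and channels:

# Toy single-channel WKV over a 3-step sequence (illustrative numbers).
w = -0.5              # learned time decay (more negative = faster forgetting)
u = 0.3               # "time_first" bonus applied to the current token
ks = [0.1, 0.2, 0.3]  # keys
vs = [1.0, 2.0, 3.0]  # values

wkv =
  for t <- 0..2 do
    {num, den} =
      Enum.reduce(0..(t - 1)//1, {0.0, 0.0}, fn i, {n, d} ->
        weight = :math.exp(w * (t - 1 - i) + Enum.at(ks, i))
        {n + weight * Enum.at(vs, i), d + weight}
      end)

    bonus = :math.exp(u + Enum.at(ks, t))
    (num + bonus * Enum.at(vs, t)) / (den + bonus)
  end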

init_cache(opts \\ [])

@spec init_cache(keyword()) :: map()

Initialize hidden state for O(1) incremental inference.

RWKV's key advantage: constant memory per inference step.

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of an RWKV model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for an RWKV model.
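
A hedged usage sketch, assuming param_count/1 accepts the same sizing options as build/1 (that assumption is not confirmed by this page):

# Hypothetical call; option names mirror build/1's documented options.
Edifice.Attention.RWKV.param_count(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 6
)
#=> approximate total parameter count (integer)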