Edifice.Attention.HGRN (Edifice v0.2.0)


HGRN-2: Hierarchically Gated Linear RNN with State Expansion.

HGRN-2 is a linear RNN architecture that combines hierarchical gating with state expansion. It targets strong sequence-modeling performance while keeping time complexity linear, O(L), in the sequence length L.

Key Innovation: State Expansion

HGRN-2 expands the hidden state from D to D*E dimensions during the recurrence, where E is the state expansion factor, and then contracts it back to D. The model can therefore carry a richer internal state while its output stays D-dimensional:

h_expanded = expand(h)  # D -> D*expansion
h_new = gate * h_expanded + (1 - gate) * input
output = contract(h_new)  # D*expansion -> D
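The three lines above leave the sizes implicit: `h_expanded` has D*E entries while `input` is D-dimensional. The Elixir sketch below makes one step concrete on plain lists, under stated assumptions: the expanded state is carried between steps, the input is projected to D*E by a hypothetical matrix `w_up` so the sizes match, and a hypothetical `w_down` performs the contraction. `StateExpansionSketch` and all of its arguments are illustrative names, not part of Edifice.

```elixir
defmodule StateExpansionSketch do
  @moduledoc false
  # Hypothetical sketch: plain lists stand in for tensors, and the weight
  # matrices are supplied by the caller rather than learned.

  # Matrix-vector product: each row of `m` dotted with `v`.
  def matvec(m, v) do
    Enum.map(m, fn row ->
      row |> Enum.zip(v) |> Enum.reduce(0.0, fn {a, b}, acc -> acc + a * b end)
    end)
  end

  # One recurrence step. `h` is the expanded state (size D*E) carried between
  # steps; `x` (size D) is projected to D*E by `w_up` (an assumption, so the
  # sizes match); `gate` holds one value in (0, 1) per expanded unit.
  # Returns {new_expanded_state, contracted_output}.
  def step(h, x, gate, w_up, w_down) do
    x_expanded = matvec(w_up, x)

    h_new =
      [gate, h, x_expanded]
      |> Enum.zip()
      |> Enum.map(fn {g, hv, xv} -> g * hv + (1 - g) * xv end)

    {h_new, matvec(w_down, h_new)}
  end
end

# D = 2, E = 2: w_up duplicates the input, w_down averages the two copies.
w_up = [[1, 0], [0, 1], [1, 0], [0, 1]]
w_down = [[0.5, 0, 0.5, 0], [0, 0.5, 0, 0.5]]
StateExpansionSketch.step([0.0, 0.0, 0.0, 0.0], [2.0, 4.0], [0.5, 0.5, 0.5, 0.5], w_up, w_down)
# => {[1.0, 2.0, 1.0, 2.0], [1.0, 2.0]}
```

The state keeps four numbers while the output has two, which is the point of expanding and then contracting.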

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+--------------------------------------+
|  HGRN-2 Block                        |
|                                      |
|  +- State Expansion --------------+  |
|  |                                |  |
|  |  h_expanded = Linear(h, D*E)   |  |
|  |                                |  |
|  +--------------------------------+  |
|                                      |
|  +- Hierarchical Gating ----------+  |
|  |                                |  |
|  |  forget_gate = sigmoid(Wf*x)   |  |
|  |  input_gate = sigmoid(Wi*x)    |  |
|  |  h = f*h + i*input             |  |
|  |                                |  |
|  +--------------------------------+  |
|                                      |
|  +- State Contraction ------------+  |
|  |                                |  |
|  |  output = Linear(h, D)         |  |
|  |                                |  |
|  +--------------------------------+  |
+--------------------------------------+
      | (repeat for num_layers)
      v
[batch, hidden_size]
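A minimal sketch of the gating stage in the diagram, written in plain Elixir rather than Axon/Nx. It assumes the input has already been brought to the state's size, so expansion and contraction are left out, and it returns the state at the last position, matching the `[batch, hidden_size]` output. `GatedRecurrenceSketch`, `w_f` and `w_i` are hypothetical names.

```elixir
defmodule GatedRecurrenceSketch do
  @moduledoc false
  # Hypothetical sketch of the gating stage in the diagram:
  #   f = sigmoid(Wf * x), i = sigmoid(Wi * x), h = f * h + i * x
  # Plain lists stand in for tensors; x is assumed to already have the
  # state's size, so the expansion and contraction projections are omitted.

  defp sigmoid(z), do: 1.0 / (1.0 + :math.exp(-z))

  defp matvec(m, v) do
    Enum.map(m, fn row ->
      row |> Enum.zip(v) |> Enum.reduce(0.0, fn {a, b}, acc -> acc + a * b end)
    end)
  end

  # Runs the recurrence over `xs` (a list of input vectors) from state `h0`
  # and returns only the final state, i.e. the last position.
  def run(xs, w_f, w_i, h0) do
    Enum.reduce(xs, h0, fn x, h ->
      f = Enum.map(matvec(w_f, x), &sigmoid/1)
      i = Enum.map(matvec(w_i, x), &sigmoid/1)

      [f, h, i, x]
      |> Enum.zip()
      |> Enum.map(fn {fv, hv, iv, xv} -> fv * hv + iv * xv end)
    end)
  end
end

GatedRecurrenceSketch.run([[1.0], [1.0]], [[0.0]], [[0.0]], [0.0])
# => [0.75]
```

With all-zero gate weights both gates equal sigmoid(0) = 0.5, so two inputs of 1.0 give 0.5 and then 0.75. The sketch follows the diagram's separate input gate; the overview pseudocode instead ties the input weight to `1 - gate`.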

Complexity

Aspect            Value
Training Time     O(L)
Training Space    O(L)
Inference Time    O(1) per step
Inference Space   O(1)
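To make the table concrete, here is a scalar sketch with hypothetical fixed gates (0.9 and 0.1). A training-style pass keeps one state per position, so its memory grows with L, whereas incremental inference keeps only the current state and does constant work per token. Both reach the same final state. `ComplexitySketch` is an illustrative name.

```elixir
defmodule ComplexitySketch do
  @moduledoc false
  # Hypothetical fixed scalar gates, chosen only to illustrate cost.
  @forget 0.9
  @input 0.1

  # Incremental inference: one update per token, and the only thing kept is
  # the current state, so time and memory per step do not depend on L.
  def step(h, x), do: @forget * h + @input * x

  # Training-style pass: every intermediate state is materialized (for
  # example, for backpropagation), so memory grows linearly with L.
  def full_run(xs, h0 \\ 0.0), do: Enum.scan(xs, h0, fn x, h -> step(h, x) end)
end

xs = [1.0, 2.0, 3.0]
states = ComplexitySketch.full_run(xs)
streamed = Enum.reduce(xs, 0.0, fn x, h -> ComplexitySketch.step(h, x) end)
List.last(states) == streamed
# => true
```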

Usage

alias Edifice.Attention.HGRN

model = HGRN.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 6,
  state_expansion: 2
)

Reference

  • Paper: "HGRN2: Gated Linear RNNs with State Expansion" (arXiv:2404.07904)

Summary

Types

build_opt()
Options for build/1.

Functions

build(opts \\ [])
Build an HGRN-2 model for sequence processing.

build_hgrn_block(input, opts)
Build a single HGRN-2 block.

build_hgrn_layer(input, opts)
Build the Hierarchical Gated RNN layer with state expansion.

init_cache(opts \\ [])
Initialize hidden state for O(1) incremental inference.

output_size(opts \\ [])
Get the output size of an HGRN model.

param_count(opts)
Calculate approximate parameter count for an HGRN model.

Recommended default configuration for sequence processing.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:state_expansion, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build an HGRN-2 model for sequence processing.

Options

  • :embed_dim - Size of input embedding per timestep (required)
  • :hidden_size - Internal hidden dimension D (default: 256)
  • :num_layers - Number of HGRN blocks (default: 6)
  • :state_expansion - State expansion factor E (default: 2)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Expected sequence length for JIT optimization (default: 60)

Returns

An Axon model that outputs [batch, hidden_size] from the last position.
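One plausible way to handle these options, sketched with `Keyword.validate!/2` (Elixir 1.13+): it rejects unknown keys and fills in the documented defaults, while `:embed_dim` has no default and is read with `Keyword.fetch!/2`. This is an assumption about option handling, not Edifice's actual code. `BuildOptsSketch` is a hypothetical name, and `:seq_len` is accepted without a default because `build_opt()` lists it but the options above do not document it.

```elixir
defmodule BuildOptsSketch do
  @moduledoc false
  # Hypothetical option handling; defaults copied from the documented list.
  @allowed [
    :embed_dim,
    :seq_len,
    hidden_size: 256,
    num_layers: 6,
    state_expansion: 2,
    dropout: 0.1,
    window_size: 60
  ]

  def resolve(opts) do
    # Raises ArgumentError on unknown keys and fills in defaults.
    opts = Keyword.validate!(opts, @allowed)
    # :embed_dim has no default, so raise KeyError if it is missing.
    _ = Keyword.fetch!(opts, :embed_dim)
    opts
  end
end

BuildOptsSketch.resolve(embed_dim: 287, num_layers: 4)
```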

build_hgrn_block(input, opts)

@spec build_hgrn_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single HGRN-2 block.

Each block has:

  1. Hierarchical gated RNN layer with state expansion
  2. Feed-forward network with gating
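The documentation does not say which gated feed-forward variant the block uses. As an illustration only, here is a SwiGLU-style unit on plain lists, one common choice: the hidden layer is `silu(W_gate x)` multiplied element-wise by `W_up x`, then projected back by `W_down`. The module and matrix names are hypothetical, and any residual connections or normalization around the two sublayers are not shown.

```elixir
defmodule GatedFFNSketch do
  @moduledoc false
  # Hypothetical SwiGLU-style gated feed-forward unit on plain lists:
  #   out = W_down * (silu(W_gate * x) .* (W_up * x))
  # Edifice may use a different gating variant; this is illustration only.

  defp matvec(m, v) do
    Enum.map(m, fn row ->
      row |> Enum.zip(v) |> Enum.reduce(0.0, fn {a, b}, acc -> acc + a * b end)
    end)
  end

  # SiLU (swish): z * sigmoid(z).
  defp silu(z), do: z / (1.0 + :math.exp(-z))

  def forward(x, w_gate, w_up, w_down) do
    gate = w_gate |> matvec(x) |> Enum.map(&silu/1)
    up = matvec(w_up, x)
    # Element-wise product: the gate scales each hidden unit.
    hidden = Enum.zip_with(gate, up, fn g, u -> g * u end)
    matvec(w_down, hidden)
  end
end

# A zero gate (silu(0) = 0) closes the unit completely.
GatedFFNSketch.forward([1.0, 2.0], [[0.0, 0.0]], [[1.0, 1.0]], [[1.0], [1.0]])
# => [0.0, 0.0]
```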

build_hgrn_layer(input, opts)

@spec build_hgrn_layer(
  Axon.t(),
  keyword()
) :: Axon.t()

Build the Hierarchical Gated RNN layer with state expansion.

Key components:

  1. State expansion: D -> D*E
  2. Forget and input gates (hierarchical gating)
  3. Recurrent update with parallel scan
  4. State contraction: D*E -> D
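Step 3 works because the gated update has the form h_t = a_t * h_{t-1} + b_t (with a_t the forget gate and b_t the gated input), and composing two such affine maps is associative. The sketch below, on scalars with hypothetical names, shows the combine rule and checks that a tree-shaped reduction, the regrouping a parallel scan exploits, reaches the same final state as the sequential loop. It reduces the whole sequence rather than producing every prefix, which a full scan would also do.

```elixir
defmodule ScanSketch do
  @moduledoc false
  # Elements are {a, b} pairs meaning the map h -> a * h + b.

  # "Apply e1, then e2": a2 * (a1 * h + b1) + b2 = (a1 * a2) * h + (a2 * b1 + b2).
  # This composition is associative, which lets a parallel scan regroup the
  # work into a tree of depth O(log L).
  def combine({a1, b1}, {a2, b2}), do: {a1 * a2, a2 * b1 + b2}

  # Sequential reference: h_t = a_t * h_{t-1} + b_t, returning every state.
  def sequential(pairs, h0) do
    Enum.scan(pairs, h0, fn {a, b}, h -> a * h + b end)
  end

  # Tree reduction of a non-empty sequence, splitting it in half recursively.
  def tree_reduce([single]), do: single

  def tree_reduce(pairs) do
    {left, right} = Enum.split(pairs, div(length(pairs), 2))
    combine(tree_reduce(left), tree_reduce(right))
  end

  # Final state obtained from the tree-combined map.
  def final_state(pairs, h0) do
    {a, b} = tree_reduce(pairs)
    a * h0 + b
  end
end

pairs = [{0.5, 1.0}, {0.5, 2.0}, {0.25, 4.0}]
ScanSketch.final_state(pairs, 8.0)
# => 5.125
```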

init_cache(opts \\ [])

@spec init_cache(keyword()) :: map()

Initialize hidden state for O(1) incremental inference.
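What the cache holds is not documented beyond its `map()` type. A hypothetical layout, assumed for illustration: one zero-initialized expanded state of size `hidden_size * state_expansion` per layer, plus a step counter. Its size depends on the model configuration but not on how many tokens have been processed, which is what makes per-step inference O(1). `CacheSketch` and its keys are illustrative names.

```elixir
defmodule CacheSketch do
  @moduledoc false
  # Hypothetical cache layout: one zero-initialized expanded state per layer.
  def init_cache(opts \\ []) do
    hidden_size = Keyword.get(opts, :hidden_size, 256)
    num_layers = Keyword.get(opts, :num_layers, 6)
    expansion = Keyword.get(opts, :state_expansion, 2)
    state_size = hidden_size * expansion

    layers =
      for layer <- 1..num_layers, into: %{} do
        {layer, List.duplicate(0.0, state_size)}
      end

    %{layers: layers, step: 0}
  end
end

cache = CacheSketch.init_cache(hidden_size: 4, num_layers: 2)
length(cache.layers[1])
# => 8
```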

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of an HGRN model.
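Because `build/1` is documented to return `[batch, hidden_size]`, the output size presumably equals `:hidden_size` (default 256). A one-line sketch under that assumption, with a hypothetical module name:

```elixir
defmodule OutputSizeSketch do
  @moduledoc false
  # Assumption: build/1 returns [batch, hidden_size], so the output size is
  # simply the configured hidden size (documented default: 256).
  def output_size(opts \\ []), do: Keyword.get(opts, :hidden_size, 256)
end

OutputSizeSketch.output_size(hidden_size: 128)
# => 128
```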

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for an HGRN model.
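The documentation does not give the formula. The sketch below shows how such an estimate could be computed by summing linear-layer sizes (`in * out` weights plus `out` biases). Every layer shape in it, including the 4*D feed-forward width, is an assumption chosen for illustration rather than Edifice's actual layout, so its numbers need not match `param_count/1`. `ParamCountSketch` is a hypothetical name.

```elixir
defmodule ParamCountSketch do
  @moduledoc false
  # Hypothetical parameter estimate. The layer shapes below are assumptions
  # made for illustration, not Edifice's actual layout:
  #   * input projection: embed_dim -> D
  #   * per layer: expansion, forget gate and input gate (D -> D*E each),
  #     contraction (D*E -> D), and a gated FFN with inner width 4*D
  #     (two D -> 4D matrices and one 4D -> D matrix).
  # Every linear layer is counted with a bias.

  defp linear(inp, out), do: inp * out + out

  def param_count(opts) do
    embed_dim = Keyword.fetch!(opts, :embed_dim)
    d = Keyword.get(opts, :hidden_size, 256)
    layers = Keyword.get(opts, :num_layers, 6)
    e = Keyword.get(opts, :state_expansion, 2)
    ffn = 4 * d

    recurrent = 3 * linear(d, d * e) + linear(d * e, d)
    feed_forward = 2 * linear(d, ffn) + linear(ffn, d)

    linear(embed_dim, d) + layers * (recurrent + feed_forward)
  end
end

ParamCountSketch.param_count(embed_dim: 4, hidden_size: 2, num_layers: 1, state_expansion: 2)
# => 122
```

For these tiny settings the input projection contributes 10 parameters and the single layer 112 (36 for the three expansion-sized projections, 10 for the contraction, 66 for the feed-forward unit).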