HGRN-2: Hierarchically Gated Linear RNN with State Expansion.
HGRN-2 is a linear RNN architecture that combines hierarchical gating with state expansion to achieve strong performance on sequence modeling tasks while keeping training time and memory linear in the sequence length L, i.e. O(L).
Key Innovation: State Expansion
HGRN-2 expands the hidden state dimension during recurrence, then contracts back. This allows the model to maintain a richer internal representation without increasing output complexity:
h_expanded = expand(h) # D -> D*expansion
h_new = gate * h_expanded + (1 - gate) * input
output = contract(h_new) # D*expansion -> D
Architecture
Input [batch, seq_len, embed_dim]
      |
      v
+-------------------------------------+
|            HGRN-2 Block             |
|                                     |
| +- State Expansion ---------------+ |
| |                                 | |
| |  h_expanded = Linear(h, D*E)    | |
| |                                 | |
| +---------------------------------+ |
|                                     |
| +- Hierarchical Gating -----------+ |
| |                                 | |
| |  forget_gate = sigmoid(Wf*x)    | |
| |  input_gate = sigmoid(Wi*x)     | |
| |  h = f*h + i*input              | |
| |                                 | |
| +---------------------------------+ |
|                                     |
| +- State Contraction -------------+ |
| |                                 | |
| |  output = Linear(h, D)          | |
| |                                 | |
| +---------------------------------+ |
+-------------------------------------+
      | (repeat for num_layers)
      v
[batch, hidden_size]
Complexity
| Aspect | Value |
|---|---|
| Training Time | O(L) |
| Training Space | O(L) |
| Inference Time | O(1) per step |
| Inference Space | O(1) |
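The linear training cost in the table depends on the recurrence being first-order and linear: each step h_t = f_t * h_(t-1) + b_t is an affine map, and affine maps compose associatively, which is what allows a scan. The following is a minimal sketch on scalar states in plain Elixir. The module `ScanSketch` and its functions are hypothetical and are not part of this module's API.

```elixir
defmodule ScanSketch do
  # Hypothetical illustration; not part of the HGRN module's API.
  # One step h -> f * h + b is represented as the pair {f, b}.

  # Compose two affine maps: apply {f1, b1} first, then {f2, b2}.
  def combine({f1, b1}, {f2, b2}), do: {f1 * f2, f2 * b1 + b2}

  # Sequential reference: L steps, carrying a single state value.
  def sequential(steps, h0) do
    Enum.scan(steps, h0, fn {f, b}, h -> f * h + b end)
  end

  # Build the prefix compositions, then apply each one to h0.
  # combine/2 is associative, so the prefixes could be computed by a
  # parallel scan; here they are computed left to right for clarity.
  def via_combine(steps, h0) do
    steps
    |> Enum.scan(fn step, acc -> combine(acc, step) end)
    |> Enum.map(fn {f, b} -> f * h0 + b end)
  end
end

steps = [{2, 1}, {3, 0}, {1, 5}]
IO.inspect(ScanSketch.sequential(steps, 1))   # => [3, 9, 14]
IO.inspect(ScanSketch.via_combine(steps, 1))  # => [3, 9, 14]
```

Both functions return the same states; only the second form exposes the associative operator that a parallel scan would use.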
Usage
model = HGRN.build(
embed_dim: 287,
hidden_size: 256,
num_layers: 6,
state_expansion: 2
)
Reference
- Paper: "HGRN2: Gated Linear RNNs with State Expansion" (arXiv:2404.07904)
Types
@type build_opt() ::
        {:dropout, float()}
        | {:embed_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:seq_len, pos_integer()}
        | {:state_expansion, pos_integer()}
        | {:window_size, pos_integer()}
Options for build/1.
Functions
Build an HGRN-2 model for sequence processing.
Options
- :embed_dim - Size of input embedding per timestep (required)
- :hidden_size - Internal hidden dimension D (default: 256)
- :num_layers - Number of HGRN blocks (default: 6)
- :state_expansion - State expansion factor E (default: 2)
- :dropout - Dropout rate (default: 0.1)
- :window_size - Expected sequence length for JIT optimization (default: 60)
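One way to picture how these options resolve is a keyword merge over the documented defaults, with the required key checked up front. This is only a sketch: `BuildOptsSketch.resolve/1` is a hypothetical helper, and the real build/1 may validate its options differently.

```elixir
defmodule BuildOptsSketch do
  # Hypothetical helper; not the library's option handling.

  # Defaults as documented in the option list.
  @defaults [
    hidden_size: 256,
    num_layers: 6,
    state_expansion: 2,
    dropout: 0.1,
    window_size: 60
  ]

  # Merge caller options over the defaults and require :embed_dim.
  def resolve(opts) do
    # Raises KeyError when :embed_dim is missing.
    _ = Keyword.fetch!(opts, :embed_dim)
    Keyword.merge(@defaults, opts)
  end
end

opts = BuildOptsSketch.resolve(embed_dim: 287, num_layers: 4)
IO.inspect(Keyword.get(opts, :num_layers))
IO.inspect(Keyword.get(opts, :hidden_size))
```

Caller-supplied values override the defaults, and any option left out keeps its documented default.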
Returns
An Axon model that outputs [batch, hidden_size] from the last position.
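To make the output shape concrete, the sketch below reduces a [batch, seq_len, hidden_size] result to [batch, hidden_size] by keeping the last timestep of every sequence. It uses nested lists rather than tensors, and `LastStepSketch` is a hypothetical name used only for illustration.

```elixir
defmodule LastStepSketch do
  # Hypothetical illustration using nested lists instead of tensors.
  # Input shape: [batch][seq_len][hidden]; output shape: [batch][hidden].
  def last_position(batch_of_sequences) do
    Enum.map(batch_of_sequences, &List.last/1)
  end
end

batch = [
  [[1, 2], [3, 4]],
  [[5, 6], [7, 8]]
]

IO.inspect(LastStepSketch.last_position(batch))  # => [[3, 4], [7, 8]]
```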
Build a single HGRN-2 block.
Each block has:
- Hierarchical gated RNN layer with state expansion
- Feed-forward network with gating
Build the Hierarchical Gated RNN layer with state expansion.
Key components:
- State expansion: D -> D*E
- Forget and input gates (hierarchical gating)
- Recurrent update with parallel scan
- State contraction: D*E -> D
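The four components listed above can be sketched end to end on plain lists. The code below follows this document's description (expand, gate, recur, contract) and is not taken from the library or the paper. The module `LayerSketch`, the parameter names `w_expand`, `w_forget`, `w_input`, `w_contract`, and the choice to compute both gates from the input are all assumptions, and the recurrence runs sequentially rather than as a parallel scan.

```elixir
defmodule LayerSketch do
  # Hypothetical, list-based illustration; not the library's implementation.

  def sigmoid(z), do: 1.0 / (1.0 + :math.exp(-z))

  # Matrix-vector product; `matrix` is a list of rows.
  def matvec(matrix, vec) do
    Enum.map(matrix, fn row ->
      row
      |> Enum.zip(vec)
      |> Enum.reduce(0.0, fn {w, v}, acc -> acc + w * v end)
    end)
  end

  # One recurrent step on an expanded state of size D*E.
  # w_expand, w_forget and w_input each have D*E rows of length D.
  def step(h, x, params) do
    # State expansion of the input: D -> D*E
    u = matvec(params.w_expand, x)
    # Forget and input gates, both of size D*E
    f = params.w_forget |> matvec(x) |> Enum.map(&sigmoid/1)
    i = params.w_input |> matvec(x) |> Enum.map(&sigmoid/1)

    # Elementwise update h = f * h + i * u
    [h, f, i, u]
    |> Enum.zip()
    |> Enum.map(fn {h_k, f_k, i_k, u_k} -> f_k * h_k + i_k * u_k end)
  end

  # Run the recurrence over a sequence, then contract each state: D*E -> D.
  def forward(xs, params) do
    h0 = List.duplicate(0.0, length(params.w_expand))

    xs
    |> Enum.scan(h0, fn x, h -> step(h, x, params) end)
    |> Enum.map(fn h -> matvec(params.w_contract, h) end)
  end
end

# D = 1, E = 2; zero gate weights make both gates exactly 0.5.
params = %{
  w_expand: [[1.0], [2.0]],
  w_forget: [[0.0], [0.0]],
  w_input: [[0.0], [0.0]],
  w_contract: [[1.0, 1.0]]
}

IO.inspect(LayerSketch.forward([[2.0], [4.0]], params))  # => [[3.0], [7.5]]
```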
Initialize hidden state for O(1) incremental inference.
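Incremental inference is O(1) per step because the only thing carried between tokens is a fixed-size state, independent of how many tokens have been seen. A minimal sketch, assuming one state vector of size hidden_size * state_expansion per layer and the gated update `gate * h + (1 - gate) * input` from the State Expansion section; `StreamSketch`, `init_state/1`, and `update/3` are hypothetical names, and the library's actual cache layout may differ.

```elixir
defmodule StreamSketch do
  # Hypothetical illustration; names and state layout are assumptions.

  # One zero vector of size hidden_size * state_expansion per layer.
  def init_state(opts) do
    size = Keyword.fetch!(opts, :hidden_size) * Keyword.fetch!(opts, :state_expansion)
    for _layer <- 1..Keyword.fetch!(opts, :num_layers), do: List.duplicate(0.0, size)
  end

  # Elementwise gated update of one layer's state; gate values lie in [0, 1].
  def update(h, gate, input) do
    [h, gate, input]
    |> Enum.zip()
    |> Enum.map(fn {h_k, g_k, x_k} -> g_k * h_k + (1.0 - g_k) * x_k end)
  end
end

state = StreamSketch.init_state(hidden_size: 2, state_expansion: 2, num_layers: 3)
IO.inspect(length(state))

# The state keeps the same size no matter how many steps are applied.
h = StreamSketch.update([0.0, 4.0], [0.5, 0.25], [2.0, 8.0])
IO.inspect(h)  # => [1.0, 7.0]
```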
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of an HGRN model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for an HGRN model.
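As a rough illustration of where the parameters sit, the sketch below counts only the layers named in the architecture diagram: an input projection from embed_dim to D, and per block an expansion, two gates, and a contraction. Every part of this formula is an assumption (dense layers with biases, gates mapping D to D*E, no feed-forward or normalization parameters), so the result will not necessarily match what param_count/1 returns. `ParamCountSketch` is a hypothetical module.

```elixir
defmodule ParamCountSketch do
  # Hypothetical estimate; the layer inventory below is an assumption.

  # Weights plus bias of a dense layer.
  defp dense(inputs, outputs), do: inputs * outputs + outputs

  def estimate(opts) do
    embed_dim = Keyword.fetch!(opts, :embed_dim)
    d = Keyword.get(opts, :hidden_size, 256)
    layers = Keyword.get(opts, :num_layers, 6)
    expanded = d * Keyword.get(opts, :state_expansion, 2)

    # Assumed input projection: embed_dim -> D
    input_proj = dense(embed_dim, d)

    # Assumed per block: expansion and two gates (D -> D*E), contraction (D*E -> D)
    per_layer = 3 * dense(d, expanded) + dense(expanded, d)

    input_proj + layers * per_layer
  end
end

IO.inspect(ParamCountSketch.estimate(embed_dim: 287, hidden_size: 256, num_layers: 6, state_expansion: 2))
```

Under these assumptions the count grows linearly in num_layers and in the expansion factor E, and quadratically in D.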
@spec recommended_defaults() :: keyword()
Recommended default configuration for sequence processing.