# `Edifice.Attention.HGRN`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/attention/hgrn.ex#L1)

HGRN-2: Hierarchically Gated Linear RNN with State Expansion.

HGRN-2 is a linear RNN architecture that combines hierarchical gating with
state expansion. It targets strong performance on sequence modeling tasks
while keeping training time linear, O(L), in the sequence length L.

## Key Innovation: State Expansion

HGRN-2 expands the hidden state from D to D*E dimensions during recurrence
(E is the `:state_expansion` factor), then contracts it back to D. This lets
the model keep a richer internal representation without widening its output:

```
h_expanded = expand(h)  # D -> D*expansion
h_new = gate * h_expanded + (1 - gate) * input
output = contract(h_new)  # D*expansion -> D
```
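
The pseudocode above leaves the shapes implicit: for the gated mix to line up, `input` and `gate` must already have the expanded length D*E. The following is a minimal pure-Elixir sketch of one such step, assuming plain lists for vectors and lists of rows for matrices. The module `StateExpansionSketch` and its functions are hypothetical names for illustration and are not part of Edifice, which builds the real layer from Axon nodes.

```elixir
defmodule StateExpansionSketch do
  @moduledoc false
  # Illustrative only: vectors are lists of floats, matrices are lists of rows.

  # Matrix-vector product: each row dotted with the vector.
  def matvec(rows, vec) do
    Enum.map(rows, fn row ->
      row |> Enum.zip(vec) |> Enum.reduce(0.0, fn {w, v}, acc -> acc + w * v end)
    end)
  end

  # One step of expand -> gated mix -> contract.
  #   h          : previous state, length D
  #   input      : input already projected to length D*E (assumption)
  #   gate       : per-channel gate in [0, 1], length D*E
  #   w_expand   : D*E rows of length D
  #   w_contract : D rows of length D*E
  def step(h, input, gate, w_expand, w_contract) do
    h_expanded = matvec(w_expand, h)

    h_new =
      [gate, h_expanded, input]
      |> Enum.zip()
      |> Enum.map(fn {g, he, x} -> g * he + (1.0 - g) * x end)

    matvec(w_contract, h_new)
  end
end

# D = 2, E = 2: expansion duplicates each channel, contraction averages the copies.
w_expand = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
w_contract = [[0.5, 0.5, 0.0, 0.0], [0.0, 0.0, 0.5, 0.5]]

out =
  StateExpansionSketch.step(
    [1.0, 2.0],
    [0.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.5, 0.5],
    w_expand,
    w_contract
  )

IO.inspect(out)
# => [0.5, 1.0]
```

With a zero input and a gate of 0.5, each state channel is simply halved, which makes the data flow easy to check by hand.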

## Architecture

```
Input [batch, seq_len, embed_dim]
      |
      v
+--------------------------------------+
|  HGRN-2 Block                        |
|                                      |
|  +- State Expansion --------------+  |
|  |                                |  |
|  |  h_expanded = Linear(h, D*E)   |  |
|  |                                |  |
|  +--------------------------------+  |
|                                      |
|  +- Hierarchical Gating ----------+  |
|  |                                |  |
|  |  forget_gate = sigmoid(Wf*x)   |  |
|  |  input_gate = sigmoid(Wi*x)    |  |
|  |  h = f*h + i*input             |  |
|  |                                |  |
|  +--------------------------------+  |
|                                      |
|  +- State Contraction ------------+  |
|  |                                |  |
|  |  output = Linear(h, D)         |  |
|  |                                |  |
|  +--------------------------------+  |
+--------------------------------------+
      | (repeat for num_layers)
      v
Output [batch, hidden_size]
```
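
The gating stage in the diagram can be sketched for a single scalar channel as follows. This is an illustration under simplifying assumptions: `wf` and `wi` are scalar weights standing in for the `Wf` and `Wi` projections, and `GatingSketch` is a hypothetical module, not Edifice code.

```elixir
defmodule GatingSketch do
  @moduledoc false
  # Illustrative only: one scalar channel, scalar weights wf and wi.

  def sigmoid(z), do: 1.0 / (1.0 + :math.exp(-z))

  # Runs h = f * h + i * x over the whole sequence and returns the final
  # state, mirroring how the model reads its output from the last position.
  def last_state(xs, wf, wi, h0 \\ 0.0) do
    Enum.reduce(xs, h0, fn x, h ->
      f = sigmoid(wf * x)
      i = sigmoid(wi * x)
      f * h + i * x
    end)
  end
end

# With zero weights both gates sit at 0.5, so the arithmetic is easy to follow:
# h1 = 0.5 * 0.0 + 0.5 * 2.0 = 1.0, h2 = 0.5 * 1.0 + 0.5 * 4.0 = 2.5
IO.inspect(GatingSketch.last_state([2.0, 4.0], 0.0, 0.0))
# => 2.5
```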

## Complexity

| Aspect | Value |
|--------|-------|
| Training Time | O(L) |
| Training Space | O(L) |
| Inference Time | O(1) per step |
| Inference Space | O(1) |

## Usage

    alias Edifice.Attention.HGRN

    model = HGRN.build(
      embed_dim: 287,
      hidden_size: 256,
      num_layers: 6,
      state_expansion: 2
    )

## Reference

- Paper: "HGRN2: Gated Linear RNNs with State Expansion" (arXiv:2404.07904)

# `build_opt`

```elixir
@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:state_expansion, pos_integer()}
  | {:window_size, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build an HGRN-2 model for sequence processing.

## Options

  - `:embed_dim` - Size of input embedding per timestep (required)
  - `:hidden_size` - Internal hidden dimension D (default: 256)
  - `:num_layers` - Number of HGRN blocks (default: 6)
  - `:state_expansion` - State expansion factor E (default: 2)
  - `:dropout` - Dropout rate (default: 0.1)
  - `:window_size` - Expected sequence length for JIT optimization (default: 60)

## Returns

  An Axon model whose output has shape `[batch, hidden_size]`, taken from the last sequence position.

# `build_hgrn_block`

```elixir
@spec build_hgrn_block(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a single HGRN-2 block.

Each block has:
1. Hierarchical gated RNN layer with state expansion
2. Feed-forward network with gating

# `build_hgrn_layer`

```elixir
@spec build_hgrn_layer(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build the Hierarchical Gated RNN layer with state expansion.

Key components:
1. State expansion: D -> D*E
2. Forget and input gates (hierarchical gating)
3. Recurrent update with parallel scan
4. State contraction: D*E -> D
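
The parallel scan in step 3 is possible because a gated update of the form `h_t = a_t * h_{t-1} + b_t` (here `a_t` plays the role of the forget gate and `b_t` of the gated input) composes associatively. The sketch below shows this with scalars in pure Elixir: a Hillis-Steele style inclusive scan is checked against the step-by-step recurrence. `ScanSketch` is a hypothetical module for illustration; how Edifice implements its scan is not shown in this document.

```elixir
defmodule ScanSketch do
  @moduledoc false
  # Illustrative only: scalar recurrence h_t = a_t * h_{t-1} + b_t with h_0 = 0.

  # Composition of two affine steps: apply {a1, b1} first, then {a2, b2}.
  def combine({a1, b1}, {a2, b2}), do: {a1 * a2, a2 * b1 + b2}

  # Reference: step-by-step recurrence, returns every h_t.
  def sequential(pairs) do
    Enum.scan(pairs, 0.0, fn {a, b}, h -> a * h + b end)
  end

  # Inclusive scan in about log2(L) rounds. The combines inside one round are
  # independent of each other, so a parallel backend could run them together.
  def parallel(pairs) do
    n = length(pairs)

    pairs
    |> rounds(1, n)
    |> Enum.map(fn {_a, b} -> b end)
  end

  defp rounds(pairs, offset, n) when offset >= n, do: pairs

  defp rounds(pairs, offset, n) do
    # Shift right by `offset`, padding with the identity element {1.0, 0.0}.
    shifted = List.duplicate({1.0, 0.0}, offset) ++ Enum.take(pairs, n - offset)

    shifted
    |> Enum.zip(pairs)
    |> Enum.map(fn {left, right} -> combine(left, right) end)
    |> rounds(offset * 2, n)
  end
end

pairs = [{0.5, 1.0}, {0.5, 2.0}, {0.5, 4.0}, {0.25, 8.0}]

IO.inspect(ScanSketch.sequential(pairs))
# => [1.0, 2.5, 5.25, 9.3125]
IO.inspect(ScanSketch.parallel(pairs))
# => [1.0, 2.5, 5.25, 9.3125]
```

On a list this version does more total work than the sequential loop; the benefit only appears when each round runs in parallel, for example on an accelerator.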

# `init_cache`

```elixir
@spec init_cache(keyword()) :: map()
```

Initialize hidden state for O(1) incremental inference.
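
The layout of the map returned by `init_cache/1` is not documented here. To illustrate why incremental inference costs O(1) per step, the sketch below uses a hypothetical cache with assumed keys `:h` and `:step`; `CacheSketch` is not part of Edifice, and the gates are passed in directly rather than computed from the input.

```elixir
defmodule CacheSketch do
  @moduledoc false
  # Illustrative only: the keys :h and :step are assumptions, not Edifice's cache layout.

  # Fresh cache: zero state for `size` channels.
  def init(size), do: %{h: List.duplicate(0.0, size), step: 0}

  # Consume one timestep. Work and memory depend only on the state size,
  # not on how many steps came before.
  def step(%{h: h, step: n}, x, forget, input_gate) do
    h_new =
      [h, x, forget, input_gate]
      |> Enum.zip()
      |> Enum.map(fn {h_c, x_c, f, i} -> f * h_c + i * x_c end)

    %{h: h_new, step: n + 1}
  end
end

cache =
  CacheSketch.init(2)
  |> CacheSketch.step([2.0, 4.0], [0.5, 0.5], [0.5, 0.5])
  |> CacheSketch.step([4.0, 8.0], [0.5, 0.5], [0.5, 0.5])

IO.inspect(cache)
# => %{h: [2.5, 5.0], step: 2}
```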

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of an HGRN model.

# `param_count`

```elixir
@spec param_count(keyword()) :: non_neg_integer()
```

Calculate approximate parameter count for an HGRN model.

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Recommended default configuration for sequence processing.

