# `Edifice.Attention.Griffin`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/attention/griffin.ex#L1)

Griffin: Hybrid RG-LRU + Local Attention Architecture.

Implements the Griffin architecture from "Griffin: Mixing Gated Linear Recurrences
with Local Attention for Efficient Language Models" (De et al., 2024).

## Key Innovation: Real-Gated Linear Recurrent Unit (RG-LRU)

Unlike Mamba's selective SSM, Griffin uses a simpler gated recurrence:

```
r_t = sigma(W_a x_t + b_a)           # Recurrence gate
i_t = sigma(W_x x_t + b_x)           # Input gate
a_t = a^(c * r_t)                    # Gated decay (a = sigma(Lambda), c = 8)
h_t = a_t . h_{t-1} + sqrt(1-a_t^2) . (i_t . x_t)
```

The `sqrt(1-a_t^2)` term preserves the hidden-state norm (analogous to a rotation),
enabling stable training on long sequences.
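
A minimal Nx sketch of one timestep, following the equations above (module and
parameter names are illustrative, not the layer's actual internals):

```elixir
defmodule RGLRU.Sketch do
  import Nx.Defn

  # One RG-LRU timestep. Shapes: x_t and h_prev are {batch, d},
  # w_a and w_x are {d, d}; b_a, b_x, and lambda are {d}.
  defn step(h_prev, x_t, w_a, b_a, w_x, b_x, lambda) do
    c = 8.0
    r_t = Nx.sigmoid(Nx.dot(x_t, w_a) + b_a)     # recurrence gate
    i_t = Nx.sigmoid(Nx.dot(x_t, w_x) + b_x)     # input gate
    a_t = Nx.pow(Nx.sigmoid(lambda), c * r_t)    # gated decay, a = sigma(Lambda)
    a_t * h_prev + Nx.sqrt(1 - a_t * a_t) * (i_t * x_t)
  end
end
```

Because `a_t` does not depend on `h_{t-1}`, this is a first-order linear
recurrence, which is why the comparison table below lists the parallel scan
as optional.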

## Architecture Pattern

Griffin alternates between RG-LRU and local attention blocks:
- Pattern: 2 RG-LRU blocks -> 1 Local Attention block -> repeat
- Local attention uses sliding window (default 1024 tokens)
- Each block: RMSNorm -> temporal mix -> residual -> RMSNorm -> gated MLP -> residual

```
Input [batch, seq_len, embed_dim]
      |
      v
+-------------------------------------+
|       Griffin Block (RG-LRU)        |
|  RMSNorm -> RG-LRU -> Residual      |
|  RMSNorm -> Gated MLP -> Residual   |
+-------------------------------------+
      | (repeat 2x)
      v
+-------------------------------------+
|    Griffin Block (Local Attn)       |
|  RMSNorm -> LocalAttn -> Residual   |
|  RMSNorm -> Gated MLP -> Residual   |
+-------------------------------------+
      |
      v (repeat pattern)
```
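
The 2:1 schedule amounts to making every third block a local-attention block;
a sketch of the assumed assignment (not necessarily the module's exact bookkeeping):

```elixir
num_layers = 6

# Blocks 0 and 1 of each group of three are RG-LRU; block 2 is local attention.
layer_types =
  for i <- 0..(num_layers - 1) do
    if rem(i, 3) == 2, do: :local_attention, else: :rg_lru
  end

# => [:rg_lru, :rg_lru, :local_attention, :rg_lru, :rg_lru, :local_attention]
```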

## Usage

    # Build Griffin backbone
    model = Griffin.build(
      embed_dim: 287,
      hidden_size: 256,
      num_layers: 6,
      window_size: 60,
      local_attn_window: 32
    )
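
From there the standard Axon workflow applies; a sketch with assumed shapes
(`window_size: 60` timesteps of `embed_dim: 287` features):

```elixir
{init_fn, predict_fn} = Axon.build(model)

params = init_fn.(Nx.template({1, 60, 287}, :f32), %{})
input = Nx.broadcast(0.0, {1, 60, 287})

# Per build/1 below, the output is the last hidden state.
last_hidden = predict_fn.(params, input)
```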

## Compared to Mamba

| Aspect | Mamba | Griffin |
|--------|-------|---------|
| Recurrence | SSM with A, B, C, Delta | Simple gated RNN |
| Parallel scan | Required | Optional (can be sequential) |
| Long-range | Pure recurrence | Hybrid with local attention |
| Parameters | Higher (SSM projections) | Lower (just gates) |

## References
- Paper: https://arxiv.org/abs/2402.19427
- Hawk: RG-LRU-only variant (no local attention)

# `build_opt`

```elixir
@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:local_attn_window, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}
  | {:use_local_attention, boolean()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a Griffin model for sequence processing.

## Options
  - `:embed_dim` - Size of input embedding per timestep (required)
  - `:hidden_size` - Internal hidden dimension (default: 256)
  - `:num_layers` - Number of Griffin blocks (default: 6; should be divisible by 3 for the 2:1 RG-LRU/attention pattern)
  - `:expand_factor` - MLP expansion factor (default: 3)
  - `:local_attn_window` - Local attention window size (default: 32)
  - `:num_heads` - Number of attention heads (default: 4)
  - `:dropout` - Dropout rate (default: 0.0)
  - `:window_size` - Expected sequence length (default: 60)
  - `:use_local_attention` - Include local attention blocks (default: true)
    Set to false for Hawk variant (pure RG-LRU)

## Returns
  An Axon model that processes sequences and outputs the last hidden state.

# `build_gated_mlp`

```elixir
@spec build_gated_mlp(Axon.t(), pos_integer(), pos_integer(), String.t()) :: Axon.t()
```

Build a Gated MLP layer (used in Griffin blocks).

Structure: Linear -> split -> (GeLU, Identity) -> multiply -> Linear
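
A sketch of this structure as an Axon graph, using two separate projections
rather than one fused projection plus split (mathematically equivalent; the
layer names are illustrative):

```elixir
hidden_size = 256
inner = hidden_size * 3                      # expand_factor = 3

x = Axon.input("x", shape: {nil, 60, hidden_size})

# Gate branch gets GeLU; value branch is left linear (identity).
gate = x |> Axon.dense(inner, name: "mlp_gate") |> Axon.gelu()
value = Axon.dense(x, inner, name: "mlp_value")

gated_mlp =
  Axon.multiply(gate, value)
  |> Axon.dense(hidden_size, name: "mlp_down")
```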

# `build_griffin_block`

```elixir
@spec build_griffin_block(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a single Griffin block.

Griffin block structure:
1. RMSNorm -> Temporal mixing (RG-LRU or Local Attention) -> Residual
2. RMSNorm -> Gated MLP -> Residual
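
The wiring, sketched with `Axon.layer_norm/2` standing in for RMSNorm and
plain dense layers as placeholder mixers:

```elixir
x = Axon.input("x", shape: {nil, 60, 256})

# Placeholders: the real block uses RG-LRU or local attention as the
# temporal mixer, a gated MLP, and RMSNorm rather than layer norm.
mix = fn h -> Axon.dense(h, 256, name: "mixer") end
mlp = fn h -> Axon.dense(h, 256, name: "mlp") end

h = Axon.add(x, mix.(Axon.layer_norm(x, name: "norm1")))      # step 1
block = Axon.add(h, mlp.(Axon.layer_norm(h, name: "norm2")))  # step 2
```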

# `build_hawk`

```elixir
@spec build_hawk(keyword()) :: Axon.t()
```

Build a Hawk model (Griffin without local attention).

This is a pure RG-LRU model, simpler and faster than Griffin.
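
Usage sketch (assuming `build_hawk/1` accepts the same options as `build/1`):

```elixir
hawk = Griffin.build_hawk(embed_dim: 287, num_layers: 6)
# presumably equivalent to:
# Griffin.build(embed_dim: 287, num_layers: 6, use_local_attention: false)
```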

# `build_local_attention`

```elixir
@spec build_local_attention(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a local (sliding window) attention layer.

Uses windowed attention for computational efficiency while
still capturing short-range dependencies.
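
One common formulation is a causal band mask over the attention scores:
position `i` may attend to position `j` iff `j <= i` and `i - j < window`.
A minimal Nx sketch (assumed formulation, not necessarily this layer's internals):

```elixir
seq_len = 5
window = 2

i = Nx.iota({seq_len, 1})
j = Nx.iota({1, seq_len})

mask = Nx.logical_and(Nx.less_equal(j, i), Nx.less(Nx.subtract(i, j), window))
# 1 0 0 0 0
# 1 1 0 0 0
# 0 1 1 0 0
# 0 0 1 1 0
# 0 0 0 1 1
```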

# `build_rg_lru_layer`

```elixir
@spec build_rg_lru_layer(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build the Real-Gated Linear Recurrent Unit layer.

RG-LRU equations:
- `r_t = sigma(W_a x_t + b_a)` (recurrence gate)
- `i_t = sigma(W_x x_t + b_x)` (input gate)
- `a_t = a^(c * r_t)` (gated decay, where `a = sigma(Lambda)` and `c = 8`)
- `h_t = a_t . h_{t-1} + sqrt(1-a_t^2) . (i_t . x_t)`

The `sqrt(1-a_t^2)` normalization ensures the recurrence preserves
hidden-state magnitude (like a complex rotation).

# `default_dropout`

```elixir
@spec default_dropout() :: float()
```

Default dropout rate.

# `default_expand_factor`

```elixir
@spec default_expand_factor() :: pos_integer()
```

Default MLP expansion factor.

# `default_hidden_size`

```elixir
@spec default_hidden_size() :: pos_integer()
```

Default hidden dimension.

# `default_local_attn_window`

```elixir
@spec default_local_attn_window() :: pos_integer()
```

Default local attention window size.

# `default_num_heads`

```elixir
@spec default_num_heads() :: pos_integer()
```

Default number of attention heads for local attention.

# `default_num_layers`

```elixir
@spec default_num_layers() :: pos_integer()
```

Default number of layers (should be divisible by 3 for the 2:1 RG-LRU/attention pattern).

# `init_lambda`

```elixir
@spec init_lambda(tuple()) :: Nx.Tensor.t()
```

Initialize the lambda parameter for RG-LRU.

Lambda is initialized so that a^c is uniformly distributed in [0.9, 0.999].
Since a = sigmoid(lambda) and a^c should be in [0.9, 0.999]:
- a_min = 0.9^(1/c), a_max = 0.999^(1/c)
- lambda = logit(uniform(a_min, a_max))
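
A sketch of that recipe in Nx with `c = 8` (a hypothetical reimplementation;
the module's initializer may differ in details):

```elixir
c = 8.0
a_min = :math.pow(0.9, 1.0 / c)      # ~0.98692
a_max = :math.pow(0.999, 1.0 / c)    # ~0.99987

{a, _key} = Nx.Random.uniform(Nx.Random.key(0), a_min, a_max, shape: {256})
lambda = Nx.log(Nx.divide(a, Nx.subtract(1.0, a)))   # logit(a)
```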

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of a Griffin model.

# `param_count`

```elixir
@spec param_count(keyword()) :: non_neg_integer()
```

Calculate approximate parameter count for a Griffin model.

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Recommended default configuration for sequence processing.
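
A plausible way to combine these defaults with the one required option
(assuming, per `build/1`, that `:embed_dim` has no default):

```elixir
opts = Keyword.merge(Griffin.recommended_defaults(), embed_dim: 287)
model = Griffin.build(opts)
```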

# `rg_lru_c`

```elixir
@spec rg_lru_c() :: float()
```

RG-LRU gate constant `c` (controls the decay-rate range).

---

*Consult [api-reference.md](api-reference.md) for the complete listing*
