Griffin: Hybrid RG-LRU + Local Attention Architecture.
Implements the Griffin architecture from "Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models" (De et al., 2024).
Key Innovation: Real-Gated Linear Recurrent Unit (RG-LRU)
Unlike Mamba's selective SSM, Griffin uses a simpler gated recurrence:
r_t = sigma(W_a x_t + b_a) # Recurrence gate
i_t = sigma(W_x x_t + b_x) # Input gate
a_t = a^(c * r_t) # Gated decay (a = sigma(Lambda), c = 8)
h_t = a_t . h_{t-1} + sqrt(1-a_t^2) . (i_t . x_t)

The sqrt(1-a_t^2) term ensures the hidden state norm is preserved (like a rotation),
enabling stable training at long sequences.
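A minimal sketch of a single recurrence step written directly against Nx (illustrative only, not this module's internal kernel; `w_a`, `b_a`, `w_x`, `b_x`, and `lambda` name the learned parameters):

```elixir
defmodule RgLruStep do
  import Nx.Defn

  @c 8.0

  # One element-wise recurrence step; w_a/b_a and w_x/b_x are the gate
  # projections, lambda the learnable decay parameter (a = sigmoid(lambda)).
  defn step(x_t, h_prev, w_a, b_a, w_x, b_x, lambda) do
    r_t = Nx.sigmoid(Nx.dot(x_t, w_a) + b_a)   # recurrence gate
    i_t = Nx.sigmoid(Nx.dot(x_t, w_x) + b_x)   # input gate
    a_t = Nx.pow(Nx.sigmoid(lambda), @c * r_t) # gated decay a^(c * r_t)

    a_t * h_prev + Nx.sqrt(1 - a_t * a_t) * (i_t * x_t)
  end
end
```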
Architecture Pattern
Griffin alternates between RG-LRU and local attention blocks:
- Pattern: 2 RG-LRU blocks -> 1 Local Attention block -> repeat
- Local attention uses sliding window (default 1024 tokens)
- Each block: RMSNorm -> temporal mix -> residual -> RMSNorm -> gated MLP -> residual
Input [batch, seq_len, embed_dim]
|
v
+-------------------------------------+
| Griffin Block (RG-LRU) |
| RMSNorm -> RG-LRU -> Residual |
| RMSNorm -> Gated MLP -> Residual |
+-------------------------------------+
| (repeat 2x)
v
+-------------------------------------+
| Griffin Block (Local Attn) |
| RMSNorm -> LocalAttn -> Residual |
| RMSNorm -> Gated MLP -> Residual |
+-------------------------------------+
|
v (repeat pattern)
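For the default 6 layers, this 2:1 alternation can be sketched as follows (illustrative only; the builder derives the pattern internally):

```elixir
num_layers = 6

# Every third block is local attention; the other two are RG-LRU.
layer_types =
  for i <- 0..(num_layers - 1) do
    if rem(i, 3) == 2, do: :local_attention, else: :rg_lru
  end

# => [:rg_lru, :rg_lru, :local_attention, :rg_lru, :rg_lru, :local_attention]
```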
Usage

# Build Griffin backbone
model = Griffin.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 6,
  window_size: 60,
  local_attn_window: 32
)

Compared to Mamba
| Aspect | Mamba | Griffin |
|---|---|---|
| Recurrence | SSM with A,B,C,Delta | Simple gated RNN |
| Parallel scan | Required | Optional (can be sequential) |
| Long-range | Pure recurrence | Hybrid with local attention |
| Parameters | Higher (SSM projections) | Lower (just gates) |
References
- Paper: https://arxiv.org/abs/2402.19427
- Hawk: RG-LRU-only variant (no local attention)
Summary
Functions
Build a Griffin model for sequence processing.
Build a Gated MLP layer (used in Griffin blocks).
Build a single Griffin block.
Build a Hawk model (Griffin without local attention).
Build a local (sliding window) attention layer.
Build the Real-Gated Linear Recurrent Unit layer.
Default dropout rate
Default MLP expansion factor
Default hidden dimension
Default local attention window size
Number of attention heads for local attention
Default number of layers (should be divisible by 3 for 2:1 pattern)
Initialize the lambda parameter for RG-LRU.
Get the output size of a Griffin model.
Calculate approximate parameter count for a Griffin model.
Recommended default configuration for sequence processing.
RG-LRU gate constant c (controls decay rate range)
Types
@type build_opt() ::
        {:embed_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:expand_factor, pos_integer()}
        | {:local_attn_window, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:dropout, float()}
        | {:window_size, pos_integer()}
        | {:use_local_attention, boolean()}
Options for build/1.
Functions
Build a Griffin model for sequence processing.
Options
- :embed_dim - Size of input embedding per timestep (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :num_layers - Number of Griffin blocks (default: 6, divisible by 3)
- :expand_factor - MLP expansion factor (default: 3)
- :local_attn_window - Local attention window size (default: 32)
- :num_heads - Number of attention heads (default: 4)
- :dropout - Dropout rate (default: 0.0)
- :window_size - Expected sequence length (default: 60)
- :use_local_attention - Include local attention blocks (default: true). Set to false for Hawk variant (pure RG-LRU)
Returns
An Axon model that processes sequences and outputs the last hidden state.
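For reference, a sketch of running the returned model with the standard Axon build/predict flow (option values taken from the usage example above; shapes follow Input [batch, seq_len, embed_dim]):

```elixir
model = Griffin.build(embed_dim: 287, hidden_size: 256, num_layers: 6)

# Standard Axon two-step: build init/predict functions, then initialize
# parameters from an input template and run a forward pass.
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 60, 287}, :f32), %{})

input = Nx.broadcast(0.0, {1, 60, 287})
last_hidden = predict_fn.(params, input)
```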
@spec build_gated_mlp(Axon.t(), pos_integer(), pos_integer(), String.t()) :: Axon.t()
Build a Gated MLP layer (used in Griffin blocks).
Structure: Linear -> split -> (GeLU, Identity) -> multiply -> Linear
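A minimal sketch of that shape built from standard Axon layers (illustrative, not this module's exact implementation; assumes a rank-3 input {batch, seq, features} with features == hidden):

```elixir
gated_mlp = fn input, hidden, expand ->
  inner = hidden * expand

  # One up-projection producing 2 * inner features, then split in half.
  up = Axon.dense(input, 2 * inner, name: "mlp_up")
  gate = Axon.nx(up, &Nx.slice_along_axis(&1, 0, inner, axis: 2))
  lin = Axon.nx(up, &Nx.slice_along_axis(&1, inner, inner, axis: 2))

  # GeLU on the gate branch, identity on the other, multiply, project down.
  Axon.multiply(Axon.activation(gate, :gelu), lin)
  |> Axon.dense(hidden, name: "mlp_down")
end
```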
Build a single Griffin block.
Griffin block structure:
- RMSNorm -> Temporal mixing (RG-LRU or Local Attention) -> Residual
- RMSNorm -> Gated MLP -> Residual
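The residual wiring can be sketched as below (hypothetical helper names; `mixer` and `mlp` stand for the temporal-mixing and gated-MLP builders, and `Axon.layer_norm` stands in for RMSNorm):

```elixir
griffin_block = fn input, mixer, mlp ->
  mixed =
    input
    |> Axon.layer_norm(name: "pre_mix_norm")  # stand-in for RMSNorm
    |> mixer.()

  x = Axon.add(input, mixed)

  x
  |> Axon.layer_norm(name: "pre_mlp_norm")
  |> mlp.()
  |> then(&Axon.add(x, &1))
end
```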
Build a Hawk model (Griffin without local attention).
This is a pure RG-LRU model, simpler and faster than Griffin.
Build a local (sliding window) attention layer.
Uses windowed attention for computational efficiency while still capturing short-range dependencies.
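A sketch of the causal sliding-window constraint (illustrative, not the layer's internal masking code): position i may attend to positions j with i - window < j <= i.

```elixir
# Build a boolean {seq, seq} mask where entry (i, j) is true when
# j is within the last `window` positions up to and including i.
seq = 6
window = 3

row = Nx.iota({seq, seq}, axis: 0)  # i
col = Nx.iota({seq, seq}, axis: 1)  # j
diff = Nx.subtract(row, col)        # i - j

mask = Nx.logical_and(Nx.greater_equal(diff, 0), Nx.less(diff, window))
```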
Build the Real-Gated Linear Recurrent Unit layer.
RG-LRU equations:
- r_t = sigma(W_a x_t + b_a) # Recurrence gate
- i_t = sigma(W_x x_t + b_x) # Input gate
- a_t = a^(c * r_t) # Gated decay (a = sigma(Lambda), c = 8)
- h_t = a_t . h_{t-1} + sqrt(1-a_t^2) . (i_t . x_t)
The sqrt(1-a_t^2) normalization ensures the recurrence preserves hidden state magnitude (like a complex rotation).
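As noted in the comparison table, the recurrence can also be evaluated sequentially; a rough sketch over precomputed per-step gates (names are hypothetical):

```elixir
# `gates` is assumed to be a list of {a_t, i_t, x_t} tensors, one per timestep;
# `hidden_size` is the state width.
h0 = Nx.broadcast(0.0, {hidden_size})

{hidden_states, _h_final} =
  Enum.map_reduce(gates, h0, fn {a_t, i_t, x_t}, h_prev ->
    h_t =
      Nx.add(
        Nx.multiply(a_t, h_prev),
        Nx.multiply(Nx.sqrt(Nx.subtract(1.0, Nx.multiply(a_t, a_t))), Nx.multiply(i_t, x_t))
      )

    {h_t, h_t}
  end)
```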
@spec default_dropout() :: float()
Default dropout rate
@spec default_expand_factor() :: pos_integer()
Default MLP expansion factor
@spec default_local_attn_window() :: pos_integer()
Default local attention window size
@spec default_num_heads() :: pos_integer()
Number of attention heads for local attention
@spec default_num_layers() :: pos_integer()
Default number of layers (should be divisible by 3 for 2:1 pattern)
@spec init_lambda(tuple()) :: Nx.Tensor.t()
Initialize the lambda parameter for RG-LRU.
Lambda is initialized so that a^c is uniformly distributed in [0.9, 0.999]. Since a = sigmoid(lambda) and a^c should be in [0.9, 0.999]:
- a_min = 0.9^(1/c), a_max = 0.999^(1/c)
- lambda = logit(uniform(a_min, a_max))
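A worked sketch of that initialization for a single element (plain Erlang/Elixir math, with c = 8 as above):

```elixir
c = 8.0
a_min = :math.pow(0.9, 1.0 / c)    # ~0.98692
a_max = :math.pow(0.999, 1.0 / c)  # ~0.99987

# Draw a uniformly in [a_min, a_max], then map through the logit so that
# sigmoid(lambda) recovers a and a^c lands in [0.9, 0.999].
a = a_min + :rand.uniform() * (a_max - a_min)
lambda = :math.log(a / (1.0 - a))
```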
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a Griffin model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for a Griffin model.
@spec recommended_defaults() :: keyword()
Recommended default configuration for sequence processing.
@spec rg_lru_c() :: float()
RG-LRU gate constant c (controls decay rate range)