Edifice.Attention.Griffin (Edifice v0.2.0)


Griffin: Hybrid RG-LRU + Local Attention Architecture.

Implements the Griffin architecture from "Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models" (De et al., 2024).

Key Innovation: Real-Gated Linear Recurrent Unit (RG-LRU)

Unlike Mamba's selective SSM, Griffin uses a simpler gated recurrence:

r_t = sigma(W_a x_t + b_a)           # Recurrence gate
i_t = sigma(W_x x_t + b_x)           # Input gate
a_t = a^(c * r_t)                    # Gated decay (a = sigma(Lambda), c = 8)
h_t = a_t . h_{t-1} + sqrt(1-a_t^2) . (i_t . x_t)

The sqrt(1-a_t^2) term ensures hidden state norm is preserved (like a rotation), enabling stable training at long sequences.
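For concreteness, here is a minimal Nx sketch of a single recurrence step. It is illustrative only, not this module's internal implementation; the parameter names (w_a, b_a, w_x, b_x, lambda) are assumptions taken from the equations above.

defmodule RGLRUSketch do
  import Nx.Defn

  @c 8.0

  # One RG-LRU step. Shapes: x_t, h_prev are [batch, hidden];
  # w_a, w_x are [hidden, hidden]; b_a, b_x, lambda are [hidden].
  defn step(x_t, h_prev, w_a, b_a, w_x, b_x, lambda) do
    r_t = Nx.sigmoid(Nx.dot(x_t, w_a) + b_a)    # recurrence gate
    i_t = Nx.sigmoid(Nx.dot(x_t, w_x) + b_x)    # input gate
    a_t = Nx.pow(Nx.sigmoid(lambda), @c * r_t)  # gated decay, a = sigma(lambda)
    # sqrt(1 - a_t^2) rescales the gated input so the update preserves norm
    a_t * h_prev + Nx.sqrt(1 - a_t * a_t) * (i_t * x_t)
  end
end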

Architecture Pattern

Griffin alternates between RG-LRU and local attention blocks:

  • Pattern: 2 RG-LRU blocks -> 1 Local Attention block -> repeat
  • Local attention uses sliding window (default 1024 tokens)
  • Each block: RMSNorm -> temporal mix -> residual -> RMSNorm -> gated MLP -> residual
Input [batch, seq_len, embed_dim]
      |
      v
+-------------------------------------+
|       Griffin Block (RG-LRU)        |
|  RMSNorm -> RG-LRU -> Residual      |
|  RMSNorm -> Gated MLP -> Residual   |
+-------------------------------------+
      | (repeat 2x)
      v
+-------------------------------------+
|    Griffin Block (Local Attn)       |
|  RMSNorm -> LocalAttn -> Residual   |
|  RMSNorm -> Gated MLP -> Residual   |
+-------------------------------------+
      |
      v (repeat pattern)
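The 2:1 interleaving reduces to a simple index rule. The snippet below sketches one plausible layout; the exact position of the attention block within each group of three is an assumption.

# With num_layers = 6, every third block (index mod 3 == 2) is local attention
block_types =
  for layer <- 0..5 do
    if rem(layer, 3) == 2, do: :local_attention, else: :rg_lru
  end
# => [:rg_lru, :rg_lru, :local_attention, :rg_lru, :rg_lru, :local_attention]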

Usage

# Build Griffin backbone
model = Griffin.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 6,
  window_size: 60,
  local_attn_window: 32
)
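The result is an ordinary Axon graph. A sketch of initializing and running it, with shapes following the options above (per build/1, the output is the last hidden state, so roughly [batch, hidden_size]):

{init_fn, predict_fn} = Axon.build(model)

# [batch, seq_len, embed_dim] input
params = init_fn.(Nx.template({8, 60, 287}, :f32), %{})
output = predict_fn.(params, Nx.iota({8, 60, 287}, type: :f32))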

Compared to Mamba

Aspect         Mamba                      Griffin
-------------  -------------------------  ----------------------------
Recurrence     SSM with A, B, C, Delta    Simple gated RNN
Parallel scan  Required                   Optional (can be sequential)
Long-range     Pure recurrence            Hybrid with local attention
Parameters     Higher (SSM projections)   Lower (just gates)

References

  • De, S., et al. (2024). "Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models." arXiv:2402.19427.

Summary

Types

build_opt()
Options for build/1.

Functions

build(opts \\ [])
Build a Griffin model for sequence processing.

build_gated_mlp(input, hidden_size, expand_factor, name)
Build a Gated MLP layer (used in Griffin blocks).

build_griffin_block(input, opts \\ [])
Build a single Griffin block.

build_hawk(opts \\ [])
Build a Hawk model (Griffin without local attention).

build_local_attention(input, opts \\ [])
Build a local (sliding window) attention layer.

build_rg_lru_layer(input, opts \\ [])
Build the Real-Gated Linear Recurrent Unit layer.

default_dropout()
Default dropout rate (0.0).

default_expand_factor()
Default MLP expansion factor (3).

default_hidden_size()
Default hidden dimension (256).

default_local_attn_window()
Default local attention window size (32).

default_num_heads()
Number of attention heads for local attention (4).

default_num_layers()
Default number of layers (6; should be divisible by 3 for the 2:1 pattern).

init_lambda(shape)
Initialize the lambda parameter for RG-LRU.

output_size(opts \\ [])
Get the output size of a Griffin model.

param_count(opts)
Calculate approximate parameter count for a Griffin model.

Recommended default configuration for sequence processing.

rg_lru_c()
RG-LRU gate constant c (controls decay rate range; c = 8).

Types

build_opt()

@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:local_attn_window, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}
  | {:use_local_attention, boolean()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Griffin model for sequence processing.

Options

  • :embed_dim - Size of input embedding per timestep (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_layers - Number of Griffin blocks (default: 6, divisible by 3)
  • :expand_factor - MLP expansion factor (default: 3)
  • :local_attn_window - Local attention window size (default: 32)
  • :num_heads - Number of attention heads (default: 4)
  • :dropout - Dropout rate (default: 0.0)
  • :window_size - Expected sequence length (default: 60)
  • :use_local_attention - Include local attention blocks (default: true). Set to false for the Hawk variant (pure RG-LRU).

Returns

An Axon model that processes sequences and outputs the last hidden state.

build_gated_mlp(input, hidden_size, expand_factor, name)

@spec build_gated_mlp(Axon.t(), pos_integer(), pos_integer(), String.t()) :: Axon.t()

Build a Gated MLP layer (used in Griffin blocks).

Structure: Linear -> split -> (GeLU, Identity) -> multiply -> Linear
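A hedged Axon sketch of that structure; the layer names and the split-by-slicing approach are assumptions, not this module's internals.

hidden_size = 256
expand_factor = 3
inner = hidden_size * expand_factor

input = Axon.input("seq", shape: {nil, 60, hidden_size})

# One up-projection, then split into a GeLU gate and a linear value path
up = Axon.dense(input, inner * 2, name: "up_proj")
gate = Axon.nx(up, &Nx.slice_along_axis(&1, 0, inner, axis: 2))
value = Axon.nx(up, &Nx.slice_along_axis(&1, inner, inner, axis: 2))

mlp =
  Axon.activation(gate, :gelu)
  |> Axon.multiply(value)
  |> Axon.dense(hidden_size, name: "down_proj")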

build_griffin_block(input, opts \\ [])

@spec build_griffin_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single Griffin block.

Griffin block structure:

  1. RMSNorm -> Temporal mixing (RG-LRU or Local Attention) -> Residual
  2. RMSNorm -> Gated MLP -> Residual
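A sketch of the residual wiring using this module's builders. Axon.layer_norm stands in for RMSNorm here (plain Axon has no built-in RMSNorm), and the :hidden_size option passed to build_rg_lru_layer is an assumption.

input = Axon.input("seq", shape: {nil, 60, 256})

# 1. Norm -> temporal mixing -> residual
mixed =
  input
  |> Axon.layer_norm(name: "pre_mix_norm")
  |> Griffin.build_rg_lru_layer(hidden_size: 256)
  |> Axon.add(input)

# 2. Norm -> gated MLP -> residual
block =
  mixed
  |> Axon.layer_norm(name: "pre_mlp_norm")
  |> Griffin.build_gated_mlp(256, 3, "mlp")
  |> Axon.add(mixed)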

build_hawk(opts \\ [])

@spec build_hawk(keyword()) :: Axon.t()

Build a Hawk model (Griffin without local attention).

This is a pure RG-LRU model, simpler and faster than Griffin.
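Per the build/1 options, the same stack should also be reachable through the :use_local_attention flag:

# These two should be equivalent (modulo any Hawk-specific defaults)
hawk = Griffin.build_hawk(embed_dim: 64, num_layers: 6)
hawk = Griffin.build(embed_dim: 64, num_layers: 6, use_local_attention: false)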

build_local_attention(input, opts \\ [])

@spec build_local_attention(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a local (sliding window) attention layer.

Uses windowed attention for computational efficiency while still capturing short-range dependencies.
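For intuition, a sliding-window causal mask can be built from two iotas. This is a sketch of the masking idea; the layer's actual implementation may differ.

# Position i may attend to positions j with i - window < j <= i
seq_len = 60
window = 32

i = Nx.iota({seq_len, 1})
j = Nx.iota({1, seq_len})

mask = Nx.logical_and(Nx.less_equal(j, i), Nx.greater(j, Nx.subtract(i, window)))
# mask[i][j] == 1 exactly when j falls inside i's window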

build_rg_lru_layer(input, opts \\ [])

@spec build_rg_lru_layer(
  Axon.t(),
  keyword()
) :: Axon.t()

Build the Real-Gated Linear Recurrent Unit layer.

RG-LRU equations:

  • r_t = sigma(W_a x_t + b_a) # Recurrence gate
  • i_t = sigma(W_x x_t + b_x) # Input gate
  • a_t = a^(c * r_t) # Gated decay (a = sigma(Lambda), c = 8)
  • h_t = a_t . h_{t-1} + sqrt(1-a_t^2) . (i_t . x_t)

The sqrt(1-a_t^2) normalization ensures the recurrence preserves hidden state magnitude (like a complex rotation).

default_dropout()

@spec default_dropout() :: float()

Default dropout rate (0.0).

default_expand_factor()

@spec default_expand_factor() :: pos_integer()

Default MLP expansion factor (3).

default_hidden_size()

@spec default_hidden_size() :: pos_integer()

Default hidden dimension (256).

default_local_attn_window()

@spec default_local_attn_window() :: pos_integer()

Default local attention window size (32).

default_num_heads()

@spec default_num_heads() :: pos_integer()

Number of attention heads for local attention (4).

default_num_layers()

@spec default_num_layers() :: pos_integer()

Default number of layers (6; should be divisible by 3 for the 2:1 pattern).

init_lambda(shape)

@spec init_lambda(tuple()) :: Nx.Tensor.t()

Initialize the lambda parameter for RG-LRU.

Lambda is initialized so that a^c lies in [0.9, 0.999]. Since a = sigmoid(lambda), this gives:

  • a_min = 0.9^(1/c), a_max = 0.999^(1/c)
  • lambda = logit(uniform(a_min, a_max))
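A sketch of that recipe with Nx.Random; the key and shape are illustrative.

c = 8.0
a_min = :math.pow(0.9, 1.0 / c)
a_max = :math.pow(0.999, 1.0 / c)

key = Nx.Random.key(0)
{a, _key} = Nx.Random.uniform(key, a_min, a_max, shape: {256})

# logit(a) = log(a / (1 - a))
lambda = Nx.subtract(Nx.log(a), Nx.log(Nx.subtract(1.0, a)))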

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a Griffin model.
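Given that build/1 outputs the last hidden state, the output size presumably tracks :hidden_size (an assumption based on the Returns note above):

Griffin.output_size(hidden_size: 256)
# => 256 (assumed)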

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for a Griffin model.

rg_lru_c()

@spec rg_lru_c() :: float()

RG-LRU gate constant c (controls decay rate range; c = 8).