Edifice.SSM.Hybrid (Edifice v0.2.0)


Configurable Hybrid Backbone+Attention architecture for efficient sequence modeling.

Originally based on AI21's Jamba architecture, this module interleaves a configurable backbone (Mamba by default) with periodic attention layers. The backbone handles local/sequential context efficiently, while attention captures long-range dependencies.

Supported Backbones

Backbone           Module          Key Mechanism
:mamba (default)   GatedSSM        Selective state space model
:gru               Axon.gru        Gated recurrent unit
:rwkv              RWKV            Linear attention (WKV)
:delta_net         DeltaNet        Delta rule linear attention
:gated_delta_net   GatedDeltaNet   Gated delta rule
:griffin_lru       Griffin         Real-gated linear recurrent unit (RG-LRU)
:custom            User-provided   Any (Axon.t(), keyword()) -> Axon.t()
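
For the :custom case, any function matching that contract works. A minimal sketch; MyBackbone and build_block/2 are hypothetical names, and the block assumes the input's hidden dimension already equals :hidden_size so the residual shapes line up:

defmodule MyBackbone do
  # Hypothetical custom backbone: pre-norm + gated dense + residual.
  # Takes [batch, seq, hidden] and returns the same shape.
  def build_block(input, opts) do
    hidden_size = Keyword.get(opts, :hidden_size, 256)

    out =
      input
      |> Axon.layer_norm()
      |> Axon.dense(hidden_size, activation: :silu)

    Axon.add(input, out)
  end
end

model = Hybrid.build(
  embed_dim: 256,
  num_layers: 6,
  backbone: {MyBackbone, :build_block},
  attention_every: 3
)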

Architecture Pattern

Input [batch, seq_len, embed_dim]
        |
        v
+------------------------------------+
| Backbone Block 1                   |
+------------------------------------+
| Backbone Block 2                   |
+------------------------------------+
| Backbone Block 3                   |
+------------------------------------+
| Attention Block (every N layers)   |  <- Long-range dependencies
+------------------------------------+
| Backbone Block 4                   |
+------------------------------------+
| ...                                |
+------------------------------------+
        |
        v
[batch, hidden_size]

Key Advantages

  1. Efficiency: Most layers are O(L) backbone blocks
  2. Long-range: Periodic attention captures distant dependencies
  3. Flexible: Swap backbone without changing the hybrid structure
  4. Memory: Far less than pure attention, slightly more than pure backbone
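
The efficiency ratio is easy to read off with layer_pattern/1 (documented below). With attention_every: 4, only two of eight layers pay the quadratic attention cost:

Hybrid.layer_pattern(num_layers: 8, attention_every: 4)
#=> [:mamba, :mamba, :mamba, :attention, :mamba, :mamba, :mamba, :attention]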

Usage

# Default hybrid (Mamba backbone, 3:1 ratio)
model = Hybrid.build(
  embed_dim: 256,
  hidden_size: 256,
  num_layers: 8,
  attention_every: 4
)

# GRU backbone (classic RNN + attention)
model = Hybrid.build(
  embed_dim: 256,
  num_layers: 6,
  backbone: :gru,
  attention_every: 3
)

# Gated DeltaNet backbone (linear attention + full attention)
model = Hybrid.build(
  embed_dim: 256,
  num_layers: 6,
  backbone: :gated_delta_net,
  attention_every: 3
)
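
Running a built model follows the standard Axon build/init/predict flow. A minimal sketch, assuming a single-input model fed a [batch, seq_len, embed_dim] tensor (shapes match the defaults above: seq_len defaults to window_size, i.e. 60):

{init_fn, predict_fn} = Axon.build(model)

# Initialize parameters from an input template.
params = init_fn.(Nx.template({1, 60, 256}, :f32), %{})

# Predict on a dummy batch; output is [batch, hidden_size].
frames = Nx.broadcast(0.0, {1, 60, 256})
output = predict_fn.(params, frames)
# => shape {1, 256}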

Summary

Types

build_opt()

Options for build/1.

Functions

build(opts \\ [])

Build a hybrid backbone+attention model.

build_attention_layer(input, opts)

Build an attention layer with residual connection and FFN.

build_backbone_layer(input, backbone, opts)

Build a backbone layer based on the configured backbone type.

build_delta_net_layer(input, opts)

Build a DeltaNet backbone layer.

build_gated_delta_net_layer(input, opts)

Build a Gated DeltaNet backbone layer with pre-norm and residual connection.

build_gru_layer(input, opts)

Build a GRU backbone layer with pre-norm and residual connection.

build_mamba_layer(input, opts)

Build a Mamba layer with residual connection.

build_rwkv_layer(input, opts)

Build an RWKV backbone layer with pre-norm and residual connection.

layer_pattern(opts \\ [])

Get the layer pattern for a given configuration.

output_size(opts \\ [])

Get the output size of a hybrid model.

param_count(opts)

Calculate approximate parameter count for a hybrid model.

Get recommended defaults for real-time sequence processing.

Types

build_opt()

@type build_opt() ::
  {:attention_every, pos_integer()}
  | {:backbone, atom() | {module(), atom()}}
  | {:chunk_size, pos_integer()}
  | {:chunked_attention, boolean()}
  | {:conv_size, pos_integer()}
  | {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:head_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:memory_efficient_attention, boolean()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:pre_norm, boolean()}
  | {:qk_layernorm, boolean()}
  | {:seq_len, pos_integer()}
  | {:state_size, pos_integer()}
  | {:use_sliding_window, boolean()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a hybrid backbone+attention model.

Options

Architecture:

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_layers - Total number of layers (default: 6)
  • :attention_every - Insert attention every N layers (default: 3)
  • :backbone - Backbone type for non-attention layers (default: :mamba). Supported: :mamba, :gru, :rwkv, :delta_net, :gated_delta_net, :griffin_lru, or a {module, function} tuple for custom backbones. Custom functions must have the signature (input :: Axon.t(), opts :: keyword()) :: Axon.t().

Mamba-specific (when backbone: :mamba):

  • :state_size - SSM state dimension (default: 16)
  • :expand_factor - Mamba expansion factor (default: 2)
  • :conv_size - Causal conv kernel size (default: 4)

Attention-specific:

  • :num_heads - Number of attention heads (default: 4)
  • :head_dim - Dimension per attention head (default: 64)
  • :window_size - Attention window size (default: 60)
  • :use_sliding_window - Use sliding window vs full attention (default: true)
  • :qk_layernorm - Normalize Q and K before attention (default: true, stabilizes training)

General:

  • :dropout - Dropout rate (default: 0.1)
  • :seq_len - Fixed sequence length for JIT optimization (default: window_size)
  • :pre_norm - Use Pre-LayerNorm (default: true, more stable than Post-LayerNorm)

Returns

An Axon model that outputs [batch, hidden_size] from the last position.

Examples

# Mamba backbone (default, Jamba-style)
model = Hybrid.build(
  embed_dim: 256,
  hidden_size: 256,
  num_layers: 6,
  attention_every: 3
)

# GRU backbone
model = Hybrid.build(
  embed_dim: 256,
  num_layers: 6,
  backbone: :gru,
  attention_every: 3
)

build_attention_layer(input, opts)

@spec build_attention_layer(
  Axon.t(),
  keyword()
) :: Axon.t()

Build an attention layer with residual connection and FFN.

Options

  • :pre_norm - If true, apply LayerNorm before block (Pre-LN, more stable).
  • :qk_layernorm - If true, normalize Q and K before attention (stabilizes training).
  • :chunked - If true, use chunked attention for lower memory usage (default: false).
  • :memory_efficient - If true, use memory-efficient attention with online softmax for true O(n) memory (default: false).
  • :chunk_size - Chunk size when using chunked or memory-efficient attention (default: 32).
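
A sketch of calling build_attention_layer/2 directly with the documented options; the input name and shape here are illustrative:

input = Axon.input("frames", shape: {nil, 60, 256})

attn =
  Hybrid.build_attention_layer(input,
    pre_norm: true,
    qk_layernorm: true,
    chunked: true,
    chunk_size: 32
  )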

build_backbone_layer(input, backbone, opts)

@spec build_backbone_layer(Axon.t(), atom() | {module(), atom()}, keyword()) ::
  Axon.t()

Build a backbone layer based on the configured backbone type.

Dispatches to the appropriate block builder. All backbone types follow the same contract: (Axon.t(), keyword()) -> Axon.t(), taking [batch, seq, hidden] input and returning the same shape with a residual connection.

Supported Backbones

  • :mamba - GatedSSM Mamba block (default)
  • :gru - GRU recurrent layer
  • :rwkv - RWKV linear attention block
  • :delta_net - DeltaNet delta rule block
  • :gated_delta_net - Gated DeltaNet block
  • :griffin_lru - Griffin RG-LRU block
  • {module, function} - Custom backbone; called as module.function(input, opts)
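
Because every builder shares the same contract, one call site can produce any supported block. A sketch (option names follow build/1; MyBackbone is the hypothetical module from the overview):

input = Axon.input("frames", shape: {nil, 60, 256})

rwkv_block = Hybrid.build_backbone_layer(input, :rwkv, hidden_size: 256)
gru_block = Hybrid.build_backbone_layer(input, :gru, hidden_size: 256)
custom_block = Hybrid.build_backbone_layer(input, {MyBackbone, :build_block}, hidden_size: 256)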

build_delta_net_layer(input, opts)

@spec build_delta_net_layer(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a DeltaNet backbone layer. Delegates to DeltaNet.build_block/2, which includes pre-norm and residual connection.

build_gated_delta_net_layer(input, opts)

@spec build_gated_delta_net_layer(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a Gated DeltaNet backbone layer with pre-norm and residual connection.

build_gru_layer(input, opts)

@spec build_gru_layer(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a GRU backbone layer with pre-norm and residual connection.

Wraps Axon.gru in the same pre-norm + residual pattern as Mamba layers.

build_mamba_layer(input, opts)

@spec build_mamba_layer(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a Mamba layer with residual connection.

Options

  • :pre_norm - If true, apply LayerNorm before block (Pre-LN, more stable).
              If false, apply after residual (Post-LN, original transformer style).
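
In pseudo-Axon terms the two placements differ as follows; mamba_block/1 is a stand-in for the actual block builder, not a real function:

# Pre-LN (pre_norm: true): normalize inside the residual branch
out = Axon.add(input, mamba_block(Axon.layer_norm(input)))

# Post-LN (pre_norm: false): normalize after the residual sum
out = Axon.layer_norm(Axon.add(input, mamba_block(input)))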

build_rwkv_layer(input, opts)

@spec build_rwkv_layer(
  Axon.t(),
  keyword()
) :: Axon.t()

Build an RWKV backbone layer with pre-norm and residual connection.

layer_pattern(opts \\ [])

@spec layer_pattern(keyword()) :: [atom()]

Get the layer pattern for a given configuration.

Returns a list describing each layer type for debugging/visualization.

Examples

iex> Hybrid.layer_pattern(num_layers: 6, attention_every: 3)
[:mamba, :mamba, :attention, :mamba, :mamba, :attention]

iex> Hybrid.layer_pattern(num_layers: 6, attention_every: 3, backbone: :gru)
[:gru, :gru, :attention, :gru, :gru, :attention]

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a hybrid model.
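
Since build/1 outputs [batch, hidden_size], this presumably mirrors the :hidden_size option:

Hybrid.output_size(hidden_size: 512)
#=> 512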

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for a hybrid model.
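
A sketch of using the estimate to compare configurations before building; the returned counts depend on the backbone and are not shown here:

small = Hybrid.param_count(embed_dim: 256, hidden_size: 256, num_layers: 6)
large = Hybrid.param_count(embed_dim: 256, hidden_size: 512, num_layers: 12)

IO.puts("small: #{small} params, large: #{large} params")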