Configurable Hybrid Backbone+Attention architecture for efficient sequence modeling.
Originally based on AI21's Jamba architecture, this module interleaves a configurable backbone (Mamba by default) with periodic attention layers. The backbone handles local/sequential context efficiently, while attention captures long-range dependencies.
Supported Backbones
| Backbone | Module | Key Mechanism |
|---|---|---|
| `:mamba` (default) | GatedSSM | Selective state space model |
| `:gru` | Axon.gru | Gated recurrent unit |
| `:rwkv` | RWKV | Linear attention (WKV) |
| `:delta_net` | DeltaNet | Delta rule linear attention |
| `:gated_delta_net` | GatedDeltaNet | Gated delta rule |
| `:griffin_lru` | Griffin RG-LRU | Real-gated linear recurrent unit |
| `{module, function}` | User-provided | Any `(Axon.t(), keyword()) -> Axon.t()` |
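A minimal sketch of a custom backbone satisfying this contract (the module name, function name, and the inner dense/SiLU computation are all illustrative, not part of this library):

```elixir
defmodule MyBackbone do
  # Contract: (Axon.t(), keyword()) -> Axon.t(); takes [batch, seq, hidden]
  # and returns the same shape, with its own residual connection.
  def build_block(input, opts) do
    hidden = Keyword.get(opts, :hidden_size, 256)

    branch =
      input
      |> Axon.layer_norm()
      |> Axon.dense(hidden, activation: :silu)
      |> Axon.dense(hidden)

    Axon.add(input, branch)
  end
end

# Used via: Hybrid.build(embed_dim: 256, backbone: {MyBackbone, :build_block})
```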
Architecture Pattern
```
Input [batch, seq_len, embed_dim]
                  │
                  ▼
┌─────────────────────────────────────┐
│ Backbone Block 1                    │
├─────────────────────────────────────┤
│ Backbone Block 2                    │
├─────────────────────────────────────┤
│ Backbone Block 3                    │
├─────────────────────────────────────┤
│ Attention Block (every N layers)    │ <- Long-range dependencies
├─────────────────────────────────────┤
│ Backbone Block 4                    │
├─────────────────────────────────────┤
│ ...                                 │
└─────────────────────────────────────┘
                  │
                  ▼
          [batch, hidden_size]
```
Key Advantages
- Efficiency: Most layers are O(L) backbone blocks (see the pattern sketch below)
- Long-range: Periodic attention captures distant dependencies
- Flexibility: Swap the backbone without changing the hybrid structure
- Memory: Far less than pure attention, slightly more than a pure backbone
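To make the efficiency claim concrete, here is the layer interleaving for an 8-layer model with attention every 4th layer (extrapolated from the `layer_pattern/1` examples later in this document):

```elixir
Hybrid.layer_pattern(num_layers: 8, attention_every: 4)
#=> [:mamba, :mamba, :mamba, :attention, :mamba, :mamba, :mamba, :attention]
# 6 of the 8 layers are O(L) backbone blocks; only 2 pay attention's quadratic cost.
```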
Usage
```elixir
# Default hybrid (Mamba backbone, 3:1 backbone-to-attention ratio)
model = Hybrid.build(
  embed_dim: 256,
  hidden_size: 256,
  num_layers: 8,
  attention_every: 4
)

# GRU backbone (classic RNN + attention)
model = Hybrid.build(
  embed_dim: 256,
  num_layers: 6,
  backbone: :gru,
  attention_every: 3
)

# Gated DeltaNet backbone (linear attention + full attention)
model = Hybrid.build(
  embed_dim: 256,
  num_layers: 6,
  backbone: :gated_delta_net,
  attention_every: 3
)
```
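A sketch of running a built model directly with Axon (shapes follow the documented contract of `[batch, seq_len, embed_dim]` in and `[batch, hidden_size]` out; passing a bare tensor assumes the model exposes a single input):

```elixir
model =
  Hybrid.build(
    embed_dim: 256,
    hidden_size: 256,
    num_layers: 8,
    attention_every: 4
  )

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 60, 256}, :f32), %{})

output = predict_fn.(params, Nx.broadcast(0.0, {1, 60, 256}))
# output shape: {1, 256}
```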
Types
```elixir
@type build_opt() ::
        {:attention_every, pos_integer()}
        | {:backbone, atom() | {module(), atom()}}
        | {:chunk_size, pos_integer()}
        | {:chunked_attention, boolean()}
        | {:conv_size, pos_integer()}
        | {:dropout, float()}
        | {:embed_dim, pos_integer()}
        | {:expand_factor, pos_integer()}
        | {:head_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:memory_efficient_attention, boolean()}
        | {:num_heads, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:pre_norm, boolean()}
        | {:qk_layernorm, boolean()}
        | {:seq_len, pos_integer()}
        | {:state_size, pos_integer()}
        | {:use_sliding_window, boolean()}
        | {:window_size, pos_integer()}
```
Options for build/1.
Functions
Build a hybrid backbone+attention model.
Options
Architecture:
- `:embed_dim` - Size of input embedding per frame (required)
- `:hidden_size` - Internal hidden dimension (default: 256)
- `:num_layers` - Total number of layers (default: 6)
- `:attention_every` - Insert attention every N layers (default: 3)
- `:backbone` - Backbone type for non-attention layers (default: `:mamba`). Supported: `:mamba`, `:gru`, `:rwkv`, `:delta_net`, `:gated_delta_net`, `:griffin_lru`, or a `{module, function}` tuple for custom backbones. Custom functions must accept `(input :: Axon.t(), opts :: keyword()) :: Axon.t()`.
Mamba-specific (when backbone: :mamba):
- `:state_size` - SSM state dimension (default: 16)
- `:expand_factor` - Mamba expansion factor (default: 2)
- `:conv_size` - Causal conv kernel size (default: 4)
Attention-specific:
- `:num_heads` - Number of attention heads (default: 4)
- `:head_dim` - Dimension per attention head (default: 64)
- `:window_size` - Attention window size (default: 60)
- `:use_sliding_window` - Use sliding window instead of full attention (default: true)
- `:qk_layernorm` - Normalize Q and K before attention (default: true; stabilizes training)
General:
- `:dropout` - Dropout rate (default: 0.1)
- `:seq_len` - Fixed sequence length for JIT optimization (default: window_size)
- `:pre_norm` - Use Pre-LayerNorm (default: true; more stable than Post-LayerNorm)
Returns
An Axon model that outputs `[batch, hidden_size]` taken from the last sequence position.
Examples
```elixir
# Mamba backbone (default, Jamba-style)
model = Hybrid.build(
  embed_dim: 256,
  hidden_size: 256,
  num_layers: 6,
  attention_every: 3
)

# GRU backbone
model = Hybrid.build(
  embed_dim: 256,
  num_layers: 6,
  backbone: :gru,
  attention_every: 3
)
```
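The attention-specific options compose the same way; for example, full (non-windowed) attention with QK LayerNorm disabled (option values illustrative):

```elixir
model = Hybrid.build(
  embed_dim: 256,
  num_layers: 6,
  use_sliding_window: false,
  qk_layernorm: false
)
```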
Build an attention layer with residual connection and FFN.
Options
- `:pre_norm` - If true, apply LayerNorm before the block (Pre-LN, more stable).
- `:qk_layernorm` - If true, normalize Q and K before attention (stabilizes training).
- `:chunked` - If true, use chunked attention for lower memory usage (default: false).
- `:memory_efficient` - If true, use memory-efficient attention with online softmax for true O(n) memory (default: false).
- `:chunk_size` - Chunk size when using chunked or memory-efficient attention (default: 32).
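A back-of-envelope comparison of peak score-buffer sizes (illustrative arithmetic, not measurements from this module):

```elixir
l = 1024                    # sequence length
full_entries = l * l        # full attention materializes L x L scores: 1_048_576
chunked_entries = 32 * l    # :chunked with chunk_size: 32 holds chunk x L: 32_768
# :memory_efficient keeps only running softmax statistics, for O(n) memory overall
```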
Build a backbone layer based on the configured backbone type.
Dispatches to the appropriate block builder. All backbone types follow
the same contract: (Axon.t(), keyword()) -> Axon.t(), taking
[batch, seq, hidden] input and returning the same shape with a residual connection.
Supported Backbones
- `:mamba` - GatedSSM Mamba block (default)
- `:gru` - GRU recurrent layer
- `:rwkv` - RWKV linear attention block
- `:delta_net` - DeltaNet delta rule block
- `:gated_delta_net` - Gated DeltaNet block
- `:griffin_lru` - Griffin RG-LRU block
- `{module, function}` - Custom backbone; called as `module.function(input, opts)`
Build a DeltaNet backbone layer. Delegates to DeltaNet.build_block/2
which includes pre-norm and residual connection.
Build a Gated DeltaNet backbone layer with pre-norm and residual connection.
Build a GRU backbone layer with pre-norm and residual connection.
Wraps Axon.gru in the same pre-norm + residual pattern as Mamba layers.
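A hedged sketch of that wrapper (an approximation of the described pattern, not this module's actual code; the residual requires the GRU width to match the input width):

```elixir
defmodule GruBlockSketch do
  def gru_block(input, opts) do
    hidden = Keyword.get(opts, :hidden_size, 256)

    # Axon.gru returns {output_sequence, hidden_state}; only the
    # full output sequence feeds the residual.
    {seq, _state} =
      input
      |> Axon.layer_norm()
      |> Axon.gru(hidden)

    Axon.add(input, seq)
  end
end
```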
Build a Mamba layer with residual connection.
Options
- `:pre_norm` - If true, apply LayerNorm before the block (Pre-LN, more stable). If false, apply after the residual (Post-LN, original transformer style).
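The two wirings differ only in where the LayerNorm sits relative to the residual (`block_fn` is a placeholder for the block's inner computation):

```elixir
pre_ln  = fn x, block_fn -> Axon.add(x, block_fn.(Axon.layer_norm(x))) end
post_ln = fn x, block_fn -> Axon.layer_norm(Axon.add(x, block_fn.(x))) end
```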
Build an RWKV backbone layer with pre-norm and residual connection.
Get the layer pattern for a given configuration.
Returns a list describing each layer type for debugging/visualization.
Examples
```elixir
iex> Hybrid.layer_pattern(num_layers: 6, attention_every: 3)
[:mamba, :mamba, :attention, :mamba, :mamba, :attention]

iex> Hybrid.layer_pattern(num_layers: 6, attention_every: 3, backbone: :gru)
[:gru, :gru, :attention, :gru, :gru, :attention]
```
```elixir
@spec output_size(keyword()) :: non_neg_integer()
```
Get the output size of a hybrid model.
```elixir
@spec param_count(keyword()) :: non_neg_integer()
```
Calculate approximate parameter count for a hybrid model.
```elixir
@spec recommended_defaults() :: keyword()
```
Get recommended defaults for real-time sequence processing.
Optimized for:
- Real-time inference (~10ms budget)
- 1-second context window
- Balance between local patterns (Mamba) and long-range context (Attention)
- Training stability (pre-norm + QK LayerNorm)
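A sketch of starting from these defaults and overriding only what a deployment needs (`embed_dim` is required and so supplied explicitly; the `param_count/1` call follows its spec above):

```elixir
opts = Keyword.merge(Hybrid.recommended_defaults(), embed_dim: 256)
model = Hybrid.build(opts)

# Rough capacity estimate before committing to training:
params = Hybrid.param_count(opts)
```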