Flexible hybrid architecture builder for combining different layer types.
This module provides a declarative way to construct hybrid models by specifying a sequence of layer types. Unlike Jamba and Zamba, which use fixed layer patterns, it accepts any combination of the supported layer types below.
Supported Layer Types
| Type | Module | Complexity (sequence length L) | Best For |
|---|---|---|---|
| :mamba | GatedSSM | O(L) | Long sequences |
| :attention | Attention | O(L^2) | Global context |
| :gla | GLA | O(L) | Fast linear attention |
| :rwkv | RWKV | O(L) | Linear RNN |
| :ffn | Dense+GELU | O(1) per token | Feature transform |
| :kan | KAN | O(1) per token | Learnable activations |
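The complexity column determines how a whole pattern scales. A single `:attention` entry makes the stack quadratic in sequence length. Per-token layers such as `:ffn` and `:kan` add only O(L) work over the full sequence. The sketch below reads a pattern through that lens. `HybridComplexitySketch` and `worst_case/1` are hypothetical names and are not part of HybridBuilder.

```elixir
defmodule HybridComplexitySketch do
  @moduledoc false
  # Hypothetical helper (not part of HybridBuilder). Exponent of L for one
  # layer over a whole sequence, taken from the table above; per-token
  # layers (:ffn, :kan) cost O(1) per token, i.e. O(L) over the sequence.
  @scaling %{mamba: 1, attention: 2, gla: 1, rwkv: 1, ffn: 1, kan: 1}

  @doc "Worst sequence-length scaling in a non-empty pattern; raises on unknown types."
  def worst_case(pattern) do
    exponent =
      pattern
      |> Enum.map(&Map.fetch!(@scaling, &1))
      |> Enum.max()

    if exponent == 2, do: "O(L^2)", else: "O(L)"
  end
end
```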
Usage
```elixir
# Custom hybrid: [Mamba, Mamba, Attention, Mamba, GLA, FFN]
pattern = [:mamba, :mamba, :attention, :mamba, :gla, :ffn]
model = HybridBuilder.build(pattern, embed_dim: 287, hidden_size: 256)

# With shared layers (like Zamba)
model = HybridBuilder.build(
  [:mamba, :mamba, :mamba, :mamba, :mamba, :mamba],
  embed_dim: 287,
  # Apply the shared attention layer after these positions
  shared_layers: %{attention: [2, 4, 6]}
)
```

Predefined Patterns

```elixir
HybridBuilder.pattern(:jamba_like, 6)   # [M, A, M, A, M, A]
HybridBuilder.pattern(:zamba_like, 6)   # [M, M, M, M, M, M]; combine with shared_layers for attention
HybridBuilder.pattern(:mamba_gla, 6)    # [M, M, GLA, M, M, GLA]
HybridBuilder.pattern(:full_hybrid, 6)  # [M, A, GLA, RWKV, M, A]
```
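Each 6-layer example above can be read as a short base unit repeated until the requested depth is reached. For example, `:full_hybrid` cycles `[M, A, GLA, RWKV]`, so six layers end with `M, A`. The following is a minimal sketch under that assumption. The base units are inferred from the comments above and from the "Pure SSM (Mamba only)" description of `:ssm_stack`. `:rwkv_attention` is omitted because its ratio is not documented. `PatternSketch` is a hypothetical name, and the real `HybridBuilder.pattern/2` may build its lists differently.

```elixir
defmodule PatternSketch do
  @moduledoc false
  # Hypothetical sketch: each named pattern is a base unit cycled until
  # `num_layers` entries exist. Units are inferred from the documented
  # examples; they are not taken from HybridBuilder's source.
  @units %{
    jamba_like: [:mamba, :attention],
    zamba_like: [:mamba],
    mamba_gla: [:mamba, :mamba, :gla],
    full_hybrid: [:mamba, :attention, :gla, :rwkv],
    ssm_stack: [:mamba]
  }

  def pattern(name, num_layers) when is_integer(num_layers) and num_layers > 0 do
    @units
    |> Map.fetch!(name)
    # Repeat the unit lazily and keep only the first `num_layers` entries
    |> Stream.cycle()
    |> Enum.take(num_layers)
  end
end
```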
Summary
Functions
Build a hybrid model from a layer pattern.
Build with a named pattern.
Estimate parameter count for a hybrid model.
Get a predefined layer pattern.
Visualize a layer pattern as a string diagram.
Types
@type layer_type() :: :mamba | :attention | :gla | :rwkv | :ffn | :kan
@type pattern() :: [layer_type()]
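A pattern is a list of atoms drawn from the `layer_type()` union, so a typo such as `:mambaa` is easy to catch before a model is built. The sketch below checks a pattern against that union. `PatternValidationSketch.validate/1` and its error tuple are hypothetical, and HybridBuilder may report invalid layer types differently.

```elixir
defmodule PatternValidationSketch do
  @moduledoc false
  # Hypothetical helper mirroring the layer_type() union; not HybridBuilder's API.
  @layer_types [:mamba, :attention, :gla, :rwkv, :ffn, :kan]

  @spec validate([atom()]) ::
          {:ok, [atom()]} | {:error, {:unknown_layer_types, [atom()]}}
  def validate(pattern) when is_list(pattern) and pattern != [] do
    # Collect every entry that is not a supported layer type
    case Enum.reject(pattern, &(&1 in @layer_types)) do
      [] -> {:ok, pattern}
      unknown -> {:error, {:unknown_layer_types, Enum.uniq(unknown)}}
    end
  end
end
```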
Functions
Build a hybrid model from a layer pattern.
Options
Required:
- :embed_dim - Input embedding dimension

Architecture:
- :hidden_size - Internal dimension (default: 256)
- :dropout - Dropout rate (default: 0.1)
- :shared_layers - Map of layer_type => list of positions at which a single shared-weight layer of that type is applied

Layer-specific options (prefixed by layer type):
- :mamba_state_size - SSM state dimension (default: 16)
- :mamba_expand_factor - Expansion factor (default: 2)
- :attention_num_heads - Number of attention heads (default: 4)
- :attention_window_size - Sliding window size (default: 60)
- :gla_num_heads - Number of GLA heads (default: 4)
- :rwkv_head_size - RWKV head size (default: 64)
- :kan_grid_size - KAN grid size (default: 5)
Returns
An Axon model outputting [batch, hidden_size].
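The `:shared_layers` option interleaves extra steps into the pattern. In the Zamba-style usage example, `%{attention: [2, 4, 6]}` adds a shared attention step after the 2nd, 4th, and 6th Mamba layers. The sketch below expands a pattern and a `shared_layers` map into an execution order. It rests on assumptions the docs do not confirm: positions are 1-based, the shared layer runs after the layer at that position, and each `{:shared, type}` step reuses one set of weights. `SharedLayersSketch` is a hypothetical name.

```elixir
defmodule SharedLayersSketch do
  @moduledoc false
  # Hypothetical expansion of `shared_layers` (not HybridBuilder's code).
  # Assumptions: 1-based positions; the shared layer runs after the layer at
  # that position; all {:shared, type} steps would reuse the same weights.
  def expand(pattern, shared_layers \\ %{}) do
    pattern
    |> Enum.with_index(1)
    |> Enum.flat_map(fn {type, pos} ->
      # Sort the map so shared steps at one position appear in a stable order
      shared =
        for {shared_type, positions} <- Enum.sort(shared_layers),
            pos in positions,
            do: {:shared, shared_type}

      [{:layer, type, pos} | shared]
    end)
  end
end
```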
@spec build_pattern(atom(), pos_integer(), keyword()) :: Axon.t()
Build with a named pattern.
Convenience function combining pattern/2 and build/2: it looks up the named pattern and builds it with the given options.

```elixir
HybridBuilder.build_pattern(:jamba_like, 6, embed_dim: 287)
```
@spec param_count(pattern(), keyword()) :: non_neg_integer()
Estimate parameter count for a hybrid model.
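A parameter estimate can be written as a sum of per-layer terms. The sketch below shows that idea for two layer types whose textbook formulas are well known. It is not `HybridBuilder.param_count/2`. It assumes dense layers with biases, attention with four d → d projections (Q, K, V, output), and an FFN that expands to 4·d. These are assumptions, not the module's documented architecture. Norm layers are ignored, and other layer types are deliberately left unmodelled. `ParamCountSketch` is a hypothetical name.

```elixir
defmodule ParamCountSketch do
  @moduledoc false
  # Hypothetical back-of-the-envelope estimate, NOT HybridBuilder.param_count/2.
  # Assumes biased dense layers, 4 attention projections of width d, and an
  # FFN that expands to 4 * d. Norm layers and other layer types are ignored.

  # Dense layer from `input` to `output` features, including bias.
  def dense(input, output), do: input * output + output

  # Q, K, V and output projections, each d -> d.
  def attention(d), do: 4 * dense(d, d)

  # d -> 4d -> d feed-forward block (the 4x expansion is an assumption).
  def ffn(d), do: dense(d, 4 * d) + dense(4 * d, d)

  # Sum over a pattern; layer types without a formula here raise.
  def estimate(pattern, hidden_size) do
    Enum.reduce(pattern, 0, fn
      :attention, acc -> acc + attention(hidden_size)
      :ffn, acc -> acc + ffn(hidden_size)
      other, _acc -> raise ArgumentError, "no estimate for #{inspect(other)} in this sketch"
    end)
  end
end
```

With `hidden_size: 256`, these assumptions give 263,168 parameters per attention layer and 525,568 per FFN layer.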
@spec pattern(atom(), pos_integer()) :: pattern()
Get a predefined layer pattern.
Available Patterns
- :jamba_like - Interleaved Mamba + Attention
- :zamba_like - All Mamba (use with shared_layers)
- :mamba_gla - Mamba + Gated Linear Attention
- :rwkv_attention - RWKV + Sparse Attention
- :full_hybrid - Mix of all layer types
- :ssm_stack - Pure SSM (Mamba only)
Visualize a layer pattern as a string diagram.
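The diagram format is not documented here. The sketch below renders a pattern using the abbreviations from the predefined-pattern comments (M, A, GLA, RWKV), plus FFN and KAN. The bracketed arrow layout and the `VisualizeSketch` name are hypothetical choices for illustration.

```elixir
defmodule VisualizeSketch do
  @moduledoc false
  # Hypothetical rendering; the real diagram format may differ.
  @abbrev %{mamba: "M", attention: "A", gla: "GLA", rwkv: "RWKV", ffn: "FFN", kan: "KAN"}

  def visualize(pattern) do
    # Wrap each abbreviation in brackets and join the layers with arrows
    Enum.map_join(pattern, " -> ", fn type -> "[" <> Map.fetch!(@abbrev, type) <> "]" end)
  end
end
```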