Edifice.SSM.HybridBuilder (Edifice v0.2.0)

Flexible hybrid architecture builder for combining different layer types.

This module provides a declarative way to construct hybrid models by specifying a sequence of layer types. Unlike Jamba and Zamba, which use fixed layer patterns, it accepts any combination of the layer types listed below.

Supported Layer Types

| Type | Module | Complexity | Best for |
|---|---|---|---|
| :mamba | GatedSSM | O(L) | Long sequences |
| :attention | Attention | O(L^2) | Global context |
| :gla | GLA | O(L) | Fast linear attention |
| :rwkv | RWKV | O(L) | Linear RNN |
| :ffn | Dense+GELU | O(1) | Feature transform |
| :kan | KAN | O(1) | Learnable activations |
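To see how the Complexity column adds up over a whole pattern, the sketch below tallies the big-O terms exactly as the table states them, for a chosen sequence length L. MixingCostSketch, term/2 and tally/2 are hypothetical names, not part of Edifice, and the result only compares growth rates; it is not a runtime or FLOP estimate.

```elixir
defmodule MixingCostSketch do
  # Hypothetical helper, not part of Edifice: tallies the big-O terms from the
  # Complexity column for a pattern, ignoring all constant factors.

  @doc "Cost term for one layer at sequence length `seq_len`, as listed in the table."
  def term(:attention, seq_len), do: seq_len * seq_len
  def term(type, seq_len) when type in [:mamba, :gla, :rwkv], do: seq_len
  def term(type, _seq_len) when type in [:ffn, :kan], do: 1

  @doc "Sum of the per-layer terms over a whole pattern."
  def tally(pattern, seq_len) do
    pattern
    |> Enum.map(&term(&1, seq_len))
    |> Enum.sum()
  end
end

pattern = [:mamba, :mamba, :attention, :mamba, :gla, :ffn]
MixingCostSketch.tally(pattern, 1024)
# => 1052673
```

At L = 1024 the single :attention layer contributes 1,048,576 of the 1,052,673 total, which illustrates why hybrid patterns tend to keep attention layers few.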

Usage

alias Edifice.SSM.HybridBuilder

# Custom hybrid: [Mamba, Mamba, Attention, Mamba, GLA, FFN]
pattern = [:mamba, :mamba, :attention, :mamba, :gla, :ffn]
model = HybridBuilder.build(pattern, embed_dim: 287, hidden_size: 256)

# With shared layers (like Zamba)
model = HybridBuilder.build(
  [:mamba, :mamba, :mamba, :mamba, :mamba, :mamba],
  embed_dim: 287,
  shared_layers: %{attention: [2, 4, 6]}  # One shared-weight attention layer, applied after each of these positions
)
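The comment on :shared_layers suggests that the shared layer runs after the listed positions rather than replacing them. Here is a minimal sketch of the resulting execution order, assuming 1-indexed positions (the docs do not say which) and one weight set per shared layer type. SharedLayerSketch and execution_order/2 are hypothetical names, not Edifice functions.

```elixir
defmodule SharedLayerSketch do
  # Hypothetical helper, not part of Edifice. Assumes :shared_layers positions
  # are 1-indexed and that a shared layer runs right after the layer at each
  # listed position.

  @doc """
  Returns the execution order as a list of `{:own, type}` (a layer with its own
  weights) and `{:shared, type}` (a reuse of the single shared weight set).
  """
  def execution_order(pattern, shared_layers \\ %{}) do
    pattern
    |> Enum.with_index(1)
    |> Enum.flat_map(fn {type, pos} ->
      # Sorting the map keeps the order deterministic when several types are shared.
      shared_here =
        for {shared_type, positions} <- Enum.sort(shared_layers),
            pos in positions,
            do: {:shared, shared_type}

      [{:own, type} | shared_here]
    end)
  end
end

SharedLayerSketch.execution_order(List.duplicate(:mamba, 6), %{attention: [2, 4, 6]})
```

For six :mamba layers with %{attention: [2, 4, 6]}, this yields nine steps, three of which reuse the same attention weights.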

Predefined Patterns

HybridBuilder.pattern(:jamba_like, 6)   # [M, A, M, A, M, A]
HybridBuilder.pattern(:zamba_like, 6)   # [M, M, M, M, M, M]; add shared attention via :shared_layers
HybridBuilder.pattern(:mamba_gla, 6)    # [M, M, GLA, M, M, GLA]
HybridBuilder.pattern(:full_hybrid, 6)  # [M, A, GLA, RWKV, M, A]

Summary

Functions

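build(pattern, opts)
Build a hybrid model from a layer pattern.

build_pattern(pattern_name, num_layers, opts)
Build with a named pattern.

param_count(pattern, opts)
Estimate parameter count for a hybrid model.

pattern(atom, num_layers)
Get a predefined layer pattern.

visualize(pattern)
Visualize a layer pattern as a string diagram.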

Types

layer_type()

@type layer_type() :: :mamba | :attention | :gla | :rwkv | :ffn | :kan

pattern()

@type pattern() :: [layer_type()]
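Because pattern() is just a list of layer_type() atoms, a pattern can be checked before it is handed to a builder. A small sketch follows; PatternCheckSketch and validate/1 are hypothetical, and whether build/2 performs its own validation is not stated here.

```elixir
defmodule PatternCheckSketch do
  # Hypothetical helper, not part of Edifice: checks a list against the
  # layer_type() union before it is passed to a builder.
  @layer_types [:mamba, :attention, :gla, :rwkv, :ffn, :kan]

  @doc "Returns :ok, or {:error, unknown} listing each entry outside layer_type() once."
  def validate(pattern) when is_list(pattern) do
    case Enum.reject(pattern, &(&1 in @layer_types)) do
      [] -> :ok
      unknown -> {:error, Enum.uniq(unknown)}
    end
  end
end

PatternCheckSketch.validate([:mamba, :lstm, :attention])
# => {:error, [:lstm]}
```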

Functions

build(pattern, opts)

@spec build(
  pattern(),
  keyword()
) :: Axon.t()

Build a hybrid model from a layer pattern.

Options

Required:

  • :embed_dim - Input embedding dimension

Architecture:

  • :hidden_size - Internal dimension (default: 256)
  • :dropout - Dropout rate (default: 0.1)
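  • :shared_layers - Map of layer type => list of positions; a single layer of that type, with one shared set of weights, is applied after each listed position (e.g. %{attention: [2, 4, 6]})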

Layer-specific options (prefixed by layer type):

  • :mamba_state_size - SSM state dimension (default: 16)
  • :mamba_expand_factor - Expansion factor (default: 2)
  • :attention_num_heads - Number of attention heads (default: 4)
  • :attention_window_size - Sliding window size (default: 60)
  • :gla_num_heads - GLA heads (default: 4)
  • :rwkv_head_size - RWKV head size (default: 64)
  • :kan_grid_size - KAN grid size (default: 5)
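The layer-specific options follow a <layer type>_<option> naming scheme. One way such options could be gathered per layer is sketched below: the defaults are copied from the list above, while LayerOptsSketch, for_layer/2 and the unprefixed key names (state_size, num_heads and so on) are assumptions, not Edifice's internals.

```elixir
defmodule LayerOptsSketch do
  # Hypothetical helper, not Edifice's implementation. Defaults mirror the
  # documented option list; the unprefixed key names are assumptions.
  @defaults %{
    mamba: [state_size: 16, expand_factor: 2],
    attention: [num_heads: 4, window_size: 60],
    gla: [num_heads: 4],
    rwkv: [head_size: 64],
    kan: [grid_size: 5],
    ffn: []
  }

  @doc "Collects `<type>_<key>` options for one layer type, falling back to the defaults."
  def for_layer(type, opts) do
    Enum.map(Map.fetch!(@defaults, type), fn {key, default} ->
      # The set of keys is fixed by @defaults, so creating atoms here is bounded.
      prefixed = String.to_atom("#{type}_#{key}")
      {key, Keyword.get(opts, prefixed, default)}
    end)
  end
end

LayerOptsSketch.for_layer(:mamba, embed_dim: 287, mamba_state_size: 32)
# => [state_size: 32, expand_factor: 2]
```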

Returns

An Axon model outputting [batch, hidden_size].

build_pattern(pattern_name, num_layers, opts)

@spec build_pattern(atom(), pos_integer(), keyword()) :: Axon.t()

Build with a named pattern.

Convenience function combining pattern/2 and build/2.

HybridBuilder.build_pattern(:jamba_like, 6, embed_dim: 287)

param_count(pattern, opts)

@spec param_count(
  pattern(),
  keyword()
) :: non_neg_integer()

Estimate parameter count for a hybrid model.
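The docs do not give the formulas behind param_count/2. As a hedged illustration of how such an estimate can be assembled, the sketch sums caller-supplied per-layer counts over the pattern and adds each shared layer type only once. ParamCountSketch and its functions are hypothetical; the attention and FFN formulas are generic textbook assumptions (biased dense projections, 4x FFN expansion), and embeddings, input projections and normalization layers are ignored.

```elixir
defmodule ParamCountSketch do
  # Hypothetical sketch, not Edifice's param_count/2. Per-layer counts are
  # supplied by the caller; embeddings, projections and norms are ignored.

  @doc "Dense layer with bias: in * out weights plus out biases."
  def dense(in_dim, out_dim), do: in_dim * out_dim + out_dim

  @doc "Assumed attention block: Q, K, V and output projections, each hidden -> hidden."
  def attention(hidden), do: 4 * dense(hidden, hidden)

  @doc "Assumed FFN block with 4x expansion: hidden -> 4*hidden -> hidden."
  def ffn(hidden), do: dense(hidden, 4 * hidden) + dense(4 * hidden, hidden)

  @doc """
  Sums `per_layer[type]` over the pattern, then adds each shared layer type
  once, however many positions reuse it.
  """
  def estimate(pattern, per_layer, shared_layers \\ %{}) do
    own = pattern |> Enum.map(&Map.fetch!(per_layer, &1)) |> Enum.sum()

    shared =
      shared_layers
      |> Map.keys()
      |> Enum.map(&Map.fetch!(per_layer, &1))
      |> Enum.sum()

    own + shared
  end
end
```

With a hidden size of 256, attention(256) is 263,168 parameters under these assumptions, so a shared attention block reused at three positions adds 263,168 parameters rather than 789,504.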

pattern(atom, num_layers)

@spec pattern(atom(), pos_integer()) :: pattern()

Get a predefined layer pattern.

Available Patterns

  • :jamba_like - Interleaved Mamba + Attention
  • :zamba_like - All Mamba (use with shared_layers)
  • :mamba_gla - Mamba + Gated Linear Attention
  • :rwkv_attention - RWKV + Sparse Attention
  • :full_hybrid - Mix of all layer types
  • :ssm_stack - Pure SSM (Mamba only)
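The six-layer examples under Predefined Patterns are consistent with each named pattern repeating a short cycle of layer types. A sketch under that assumption follows (Edifice's pattern/2 may build longer patterns differently); PatternSketch is hypothetical, and :rwkv_attention is left out because its layout is not spelled out.

```elixir
defmodule PatternSketch do
  # Hypothetical sketch, not Edifice's pattern/2. Each named pattern is assumed
  # to repeat a base cycle; these cycles reproduce the documented 6-layer
  # examples. :rwkv_attention is omitted because its layout is not given.
  @cycles %{
    jamba_like: [:mamba, :attention],
    zamba_like: [:mamba],
    ssm_stack: [:mamba],
    mamba_gla: [:mamba, :mamba, :gla],
    full_hybrid: [:mamba, :attention, :gla, :rwkv]
  }

  def pattern(name, num_layers) when is_integer(num_layers) and num_layers > 0 do
    @cycles
    |> Map.fetch!(name)
    |> Stream.cycle()
    |> Enum.take(num_layers)
  end
end

PatternSketch.pattern(:full_hybrid, 6)
# => [:mamba, :attention, :gla, :rwkv, :mamba, :attention]
```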

visualize(pattern)

@spec visualize(pattern()) :: String.t()

Visualize a layer pattern as a string diagram.
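The diagram format produced by visualize/1 is not documented here. The hypothetical sketch below joins the abbreviations used in the Predefined Patterns comments (M, A, GLA, RWKV); the FFN and KAN labels are assumptions.

```elixir
defmodule VisualizeSketch do
  # Hypothetical sketch, not Edifice's visualize/1, whose real output format
  # is not documented here. FFN and KAN labels are assumed.
  @labels %{mamba: "M", attention: "A", gla: "GLA", rwkv: "RWKV", ffn: "FFN", kan: "KAN"}

  def visualize(pattern) do
    Enum.map_join(pattern, " -> ", &Map.fetch!(@labels, &1))
  end
end

VisualizeSketch.visualize([:mamba, :mamba, :attention, :mamba, :gla, :ffn])
# => "M -> M -> A -> M -> GLA -> FFN"
```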