Edifice.SSM.Zamba (Edifice v0.2.0)


Zamba: Mamba with Single Shared Attention layer.

Named after Zyphra's Zamba architecture, this module implements a more efficient hybrid than Jamba by using a single shared attention layer that is applied at regular intervals throughout the Mamba stack.

Key Difference from Jamba

| Aspect           | Jamba                  | Zamba                            |
| ---------------- | ---------------------- | -------------------------------- |
| Attention layers | Multiple (interleaved) | One (shared weights, reused)     |
| KV cache         | O(L x N_attn)          | O(L), ~10x reduction             |
| Parameters       | Higher                 | Lower                            |
| Pattern          | [M, M, A, M, M, A]     | [M, M, M, M, M, M] -> A (shared) |

Architecture Pattern

Input [batch, seq_len, embed_dim]
        |
        v
+-------------------------------------+
|  Mamba Block 1                      |
+-------------------------------------+
|  Mamba Block 2                      |
+-------------------------------------+
|  ...                                |
+-------------------------------------+
|  Mamba Block N                      |
+-------------------------------------+
        |
        v
+-------------------------------------+
|  Shared Attention (applied every K) |  <- same weights, reused
+-------------------------------------+
        |
        v
Output [batch, hidden_size]

Why Single Shared Attention Works

The insight from Zamba: attention layers primarily serve to propagate information globally, not to learn diverse patterns. A single layer with shared weights can achieve similar global information flow at a fraction of the parameter cost.

Usage

# Default Zamba (6 Mamba layers, 1 shared attention applied every 3)
model = Zamba.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 6,
  attention_every: 3
)

# Minimal attention variant (attention only at the end)
model = Zamba.build(
  embed_dim: 287,
  num_layers: 6,
  attention_every: 6  # Only applied after final layer
)
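
To run the model end to end, the standard Axon build/predict flow applies. A hedged sketch, assuming the default window_size of 60 (so seq_len defaults to 60) and a single-input model that accepts a bare tensor:

# Build the model as above.
model = Zamba.build(embed_dim: 287, hidden_size: 256, num_layers: 6)

# Initialize parameters from an input template, then run a forward pass.
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 60, 287}, :f32), %{})

input = Nx.broadcast(0.0, {1, 60, 287})
output = predict_fn.(params, input)
# output shape: {1, 256}, i.e. [batch, hidden_size] per the docs above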

Reference

  • Paper: "Zamba: A Compact 7B SSM Hybrid Model" (arXiv:2405.16712)
  • Key insight: Single shared attention achieves 10x KV cache reduction

Summary

Types

build_opt()

Options for build/1.

Functions

apply_shared_attention(input, opts)

Apply the single shared attention layer.

build(opts \\ [])

Build a Zamba model (Mamba + Single Shared Attention).

build_mamba_layer(input, opts)

Build a Mamba layer with residual connection.

compare_to_jamba(opts)

Compare Zamba vs Jamba parameter counts.

layer_pattern(opts \\ [])

Get the layer pattern for a given configuration.

output_size(opts \\ [])

Get the output size of a Zamba model.

param_count(opts)

Calculate approximate parameter count for a Zamba model.

Get recommended defaults for real-time sequence processing.
Types

build_opt()

@type build_opt() ::
  {:attention_every, pos_integer()}
  | {:conv_size, pos_integer()}
  | {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:head_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:state_size, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

apply_shared_attention(input, opts)

@spec apply_shared_attention(
  Axon.t(),
  keyword()
) :: Axon.t()

Apply the single shared attention layer.

The key insight: by reusing the same layer name ("shared_attention"), Axon resolves every application of the layer to the same parameters. This is what distinguishes Zamba from Jamba: one set of attention weights, applied multiple times.
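
The same naming trick can be seen in miniature with any Axon layer. A minimal sketch (illustrative names, not this module's internals): two dense layers that share the name "tied" resolve to one entry in the parameter map, so a single weight matrix is learned and applied at both positions in the graph.

input = Axon.input("x", shape: {nil, 16})

# Both dense layers are named "tied", so Axon stores and trains one set
# of dense parameters that is applied twice during the forward pass.
model =
  input
  |> Axon.dense(16, name: "tied")
  |> Axon.relu()
  |> Axon.dense(16, name: "tied")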

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Zamba model (Mamba + Single Shared Attention).

Options

Architecture:

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_layers - Number of Mamba layers (default: 6)
  • :attention_every - Apply shared attention every N Mamba layers (default: 3)

Mamba-specific:

  • :state_size - SSM state dimension (default: 16)
  • :expand_factor - Mamba expansion factor (default: 2)
  • :conv_size - Causal conv kernel size (default: 4)

Shared Attention:

  • :num_heads - Number of attention heads (default: 4)
  • :head_dim - Dimension per attention head (default: 64)
  • :window_size - Attention window size (default: 60)

General:

  • :dropout - Dropout rate (default: 0.1)
  • :seq_len - Fixed sequence length for JIT optimization (default: window_size)

Returns

An Axon model that outputs [batch, hidden_size] from the last position.

Example

model = Zamba.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 6,
  attention_every: 3  # Shared attention applied 2x total
)

build_mamba_layer(input, opts)

@spec build_mamba_layer(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a Mamba layer with residual connection.

Uses Pre-LayerNorm for stability.
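
Pre-LayerNorm means the normalization sits inside the residual branch, before the block's transformation. A hedged sketch of the shape of such a wrapper (hypothetical helper, not the library's internals):

# Hypothetical pre-norm residual wrapper: normalize, transform, then add
# the untouched input back as a skip connection.
defp pre_norm_residual(input, block_fn, name) do
  input
  |> Axon.layer_norm(name: name <> "_norm")
  |> block_fn.()
  |> Axon.add(input)
end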

compare_to_jamba(opts)

@spec compare_to_jamba(keyword()) :: map()

Compare Zamba vs Jamba parameter counts.

Shows the parameter savings from using shared attention.

Example

iex> Zamba.compare_to_jamba(embed_dim: 287, num_layers: 6)
%{
  zamba_params: 450_000,
  jamba_params: 600_000,
  savings: 150_000,
  savings_percent: 25.0
}

layer_pattern(opts \\ [])

@spec layer_pattern(keyword()) :: [atom()]

Get the layer pattern for a given configuration.

Returns a list describing each layer type for debugging/visualization. The shared attention is marked with "(shared)" to distinguish from Jamba.

Example

iex> Zamba.layer_pattern(num_layers: 6, attention_every: 3)
[:mamba, :mamba, :mamba_attention, :mamba, :mamba, :mamba_attention]

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a Zamba model.
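
Since build/1 outputs [batch, hidden_size], the output size presumably mirrors the :hidden_size option (default: 256). A hedged example, with the return value assumed:

Zamba.output_size(hidden_size: 256)
# => 256 (assumed: tracks :hidden_size)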

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for a Zamba model.

Note: Zamba has fewer parameters than Jamba because attention weights are shared (counted once, not per layer).
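
Illustrative arithmetic only, not the library's formula: the per-layer figures below are assumed round numbers chosen to reproduce the compare_to_jamba/1 example above.

# Assumed figures for illustration.
mamba_params = 50_000
attention_params = 150_000
num_layers = 6

# Jamba with attention_every: 3 instantiates div(6, 3) = 2 separate
# attention layers, each with its own weights.
jamba = num_layers * mamba_params + 2 * attention_params
# => 600_000

# Zamba applies attention just as often but counts its weights once.
zamba = num_layers * mamba_params + attention_params
# => 450_000

savings = jamba - zamba
# => 150_000 (25.0% of jamba), matching compare_to_jamba/1 above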