Zamba: Mamba with a Single Shared Attention Layer.
Named after Zyphra's Zamba architecture, this module implements a more efficient hybrid than Jamba by using a single shared attention layer that is applied at regular intervals throughout the Mamba stack.
Key Differences from Jamba
| Aspect | Jamba | Zamba |
|---|---|---|
| Attention layers | Multiple, each with its own weights | One, shared weights reused at every site |
| KV cache | O(L × N_attn) | O(L), roughly a 10x reduction |
| Parameters | Higher | Lower |
| Pattern | [M, M, A1, M, M, A2] (independent A1, A2) | [M, M, A*, M, M, A*] (A* = the same layer) |
Architecture Pattern
Input [batch, seq_len, embed_dim]
              │
              ▼
┌─────────────────────────────────────┐
│ Mamba Block 1                       │
├─────────────────────────────────────┤
│ Mamba Block 2                       │
├─────────────────────────────────────┤
│ Mamba Block 3                       │
├─────────────────────────────────────┤
│ Shared Attention                    │ ◄─ one set of weights,
├─────────────────────────────────────┤    reused at every site
│ Mamba Block 4                       │
├─────────────────────────────────────┤
│ Mamba Block 5                       │
├─────────────────────────────────────┤
│ Mamba Block 6                       │
├─────────────────────────────────────┤
│ Shared Attention                    │ ◄─ same weights as above
└──────────────┬──────────────────────┘
               │
               ▼
      [batch, hidden_size]

(shown for num_layers: 6, attention_every: 3)

Why Single Shared Attention Works
The insight from Zamba: attention layers primarily serve to propagate information globally, not to learn diverse patterns. A single layer with shared weights can achieve similar global information flow at a fraction of the parameter cost.
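The same idea is straightforward to express in Axon. Below is a minimal sketch, assuming Axon v0.6+ where Axon.block/1 returns a block whose parameters are shared across every call site; the dense layers are stand-ins for the real Mamba and windowed-attention computations, which this module implements internally:

```elixir
# Stand-ins for illustration only: a real Mamba block is a selective SSM and
# the shared layer is windowed multi-head attention, not plain dense layers.
mamba_block = fn x -> Axon.dense(x, 256, activation: :silu) end

# Axon.block/1 wraps a sub-model; every call to shared_attention.(...)
# reuses the SAME parameters instead of creating a fresh copy.
shared_attention = Axon.block(&Axon.dense(&1, 256))

model =
  Enum.reduce(1..6, Axon.input("frames", shape: {nil, 60, 287}), fn i, acc ->
    acc = mamba_block.(acc)
    # attention_every: 3 -> the shared block runs after layers 3 and 6
    if rem(i, 3) == 0, do: shared_attention.(acc), else: acc
  end)
```

Each call to mamba_block creates a fresh dense layer (separate weights per depth, mirroring independent Mamba layers), while both calls to shared_attention resolve to a single parameter set.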
Usage
# Default Zamba (6 Mamba layers, one shared attention layer applied every 3)
model = Zamba.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 6,
  attention_every: 3
)

# Minimal attention variant (attention only at the end)
model = Zamba.build(
  embed_dim: 287,
  num_layers: 6,
  attention_every: 6  # Only applied after the final layer
)

Reference
- Paper: "Zamba: A Compact 7B SSM Hybrid Model" (arXiv:2405.16712)
- Key insight: Single shared attention achieves 10x KV cache reduction
Summary

Functions

- Apply the single shared attention layer.
- Build a Zamba model (Mamba + Single Shared Attention).
- Build a Mamba layer with residual connection.
- Compare Zamba vs Jamba parameter counts.
- Get the layer pattern for a given configuration.
- Get the output size of a Zamba model.
- Calculate approximate parameter count for a Zamba model.
- Get recommended defaults for real-time sequence processing.
Types
@type build_opt() ::
        {:attention_every, pos_integer()}
        | {:conv_size, pos_integer()}
        | {:dropout, float()}
        | {:embed_dim, pos_integer()}
        | {:expand_factor, pos_integer()}
        | {:head_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:seq_len, pos_integer()}
        | {:state_size, pos_integer()}
        | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a Zamba model (Mamba + Single Shared Attention).
Options
Architecture:
- `:embed_dim` - Size of input embedding per frame (required)
- `:hidden_size` - Internal hidden dimension (default: 256)
- `:num_layers` - Number of Mamba layers (default: 6)
- `:attention_every` - Apply shared attention every N Mamba layers (default: 3)
Mamba-specific:
- `:state_size` - SSM state dimension (default: 16)
- `:expand_factor` - Mamba expansion factor (default: 2)
- `:conv_size` - Causal conv kernel size (default: 4)
Shared Attention:
- `:num_heads` - Number of attention heads (default: 4)
- `:head_dim` - Dimension per attention head (default: 64)
- `:window_size` - Attention window size (default: 60)
General:
- `:dropout` - Dropout rate (default: 0.1)
- `:seq_len` - Fixed sequence length for JIT optimization (default: `window_size`)
Returns
An Axon model that outputs [batch, hidden_size] from the last position.
Example
model = Zamba.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 6,
  attention_every: 3  # Shared attention applied 2x total
)
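Once built, the model runs through Axon's usual build/predict flow. A hedged sketch (exact initialization details vary across Axon versions):

```elixir
{init_fn, predict_fn} = Axon.build(model)

# Batch of 1, 60 frames, 287 features per frame
template = Nx.template({1, 60, 287}, :f32)
params = init_fn.(template, %{})

output = predict_fn.(params, Nx.broadcast(0.0, {1, 60, 287}))
# output has shape {1, 256}: hidden_size features from the last position
```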
Build a Mamba layer with residual connection.
Uses Pre-LayerNorm for stability.
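The pre-norm residual pattern itself is small enough to sketch directly. In the snippet below the dense layer is a placeholder for the Mamba computation, and the input is assumed to already have hidden_size features so the residual shapes line up:

```elixir
pre_norm_residual = fn x ->
  x
  |> Axon.layer_norm()                  # normalize BEFORE the block (Pre-LN)
  |> Axon.dense(256, activation: :silu) # placeholder for the Mamba op
  |> Axon.add(x)                        # residual connection around the block
end
```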
Compare Zamba vs Jamba parameter counts.
Shows the parameter savings from using shared attention.
Example
iex> Zamba.compare_to_jamba(embed_dim: 287, num_layers: 6)
%{
  zamba_params: 450_000,
  jamba_params: 600_000,
  savings: 150_000,
  savings_percent: 25.0
}
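The mechanism behind the savings is easy to check with back-of-the-envelope arithmetic. Assuming the usual four d × d projections (Q, K, V, output) per attention layer, d = num_heads * head_dim = 256, and two attention sites as in the 6-layer default:

```elixir
attn_params = fn d -> 4 * d * d end   # Q, K, V, and output projections

d = 4 * 64                            # num_heads * head_dim
sites = 2                             # attention runs after layers 3 and 6

jamba_attn = sites * attn_params.(d)  # independent weights at each site
zamba_attn = attn_params.(d)          # one shared set, counted once

IO.inspect({jamba_attn, zamba_attn})  # => {524288, 262144}
```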
Get the layer pattern for a given configuration.
Returns a list describing each layer type, for debugging and visualization. Positions where the shared attention is applied are marked :mamba_attention; unlike in Jamba, every such position reuses the same attention weights.
Example
iex> Zamba.layer_pattern(num_layers: 6, attention_every: 3)
[:mamba, :mamba, :mamba_attention, :mamba, :mamba, :mamba_attention]
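A plausible implementation of this helper, consistent with the example above (a sketch, not necessarily the module's actual code):

```elixir
layer_pattern = fn opts ->
  num_layers = Keyword.get(opts, :num_layers, 6)
  attention_every = Keyword.get(opts, :attention_every, 3)

  Enum.map(1..num_layers, fn i ->
    # every attention_every-th Mamba layer is followed by the shared attention
    if rem(i, attention_every) == 0, do: :mamba_attention, else: :mamba
  end)
end

layer_pattern.(num_layers: 6, attention_every: 3)
#=> [:mamba, :mamba, :mamba_attention, :mamba, :mamba, :mamba_attention]
```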
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a Zamba model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for a Zamba model.
Note: Zamba has fewer parameters than Jamba because attention weights are shared (counted once, not per layer).
@spec recommended_defaults() :: keyword()
Get recommended defaults for real-time sequence processing.
Optimized for:
- Real-time inference
- 1-second context window
- Minimal memory footprint (shared attention)
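For illustration, a plausible shape for the returned keyword list, assembled from the defaults documented above (the actual values are defined by the module and may differ; treating 60 frames as one second assumes a 60 Hz frame rate):

```elixir
# Hypothetical return value, built from this page's documented defaults.
[
  hidden_size: 256,
  num_layers: 6,
  attention_every: 3,
  state_size: 16,
  expand_factor: 2,
  conv_size: 4,
  num_heads: 4,
  head_dim: 64,
  window_size: 60,  # ~1 second of context at 60 frames per second
  dropout: 0.1
]
```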