Mamba: True Selective State Space Model with optimized parallel scan.
Implements the Mamba architecture from "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" (Gu & Dao, 2023).
Key Innovation: Parallel Associative Scan
The SSM recurrence h[t] = A h[t-1] + B x[t] seems sequential, but can be parallelized using associativity:
Define: (a, b) ⊗ (c, d) = (a*c, a*d + b), where (a, b) is the later step and (c, d) the earlier one
Then the scan:
h[0] = B[0] * x[0]
h[1] = A[1] * h[0] + B[1] * x[1]
h[2] = A[2] * h[1] + B[2] * x[2]
...
Can be computed in O(log L) parallel time using a prefix scan.
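To make this concrete, here is a small Python sketch (illustrative only, not the library's Elixir implementation) showing that combining (A[t], B[t]*x[t]) pairs with this operator in a log-depth scan reproduces the sequential recurrence; `combine`, `sequential_scan`, and `associative_scan` are names invented for the example.

```python
# Each timestep of h[t] = A[t]*h[t-1] + B[t]*x[t] becomes the pair
# (A[t], B[t]*x[t]); the operator combines a later pair (a, b) with an
# earlier pair (c, d) into (a*c, a*d + b).

def combine(later, earlier):
    a, b = later
    c, d = earlier
    return (a * c, a * d + b)

def sequential_scan(A, Bx):
    # Reference O(L) recurrence.
    h, out = 0.0, []
    for a, b in zip(A, Bx):
        h = a * h + b
        out.append(h)
    return out

def associative_scan(A, Bx):
    # Hillis-Steele inclusive scan: O(log L) parallel steps, each step
    # combining every element with the partial result 2^k positions back.
    elems = list(zip(A, Bx))
    step = 1
    while step < len(elems):
        elems = [
            combine(e, elems[t - step]) if t >= step else e
            for t, e in enumerate(elems)
        ]
        step *= 2
    return [b for _, b in elems]  # second component of each pair is h[t]

A = [0.9, 0.8, 0.7, 0.95]
Bx = [1.0, 0.5, -0.2, 0.3]
```

Both functions produce the same h[t] values; only the dependency depth differs.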
Selective Mechanism
Unlike linear time-invariant SSMs, Mamba makes Δ, B, and C input-dependent (A itself stays fixed, but the effective A_bar = exp(Δ A) varies per step through Δ):
- Δ (discretization step) controls how much to update state
- B (input matrix) projects input to state space
- C (output matrix) projects state to output
- These are computed from the input, enabling selective focus
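A minimal sketch of this selective parameterization in Python (the weight names W_dt, W_B, W_C and the shapes are assumptions for illustration, not the library's internals): Δ, B, and C are linear functions of the input, so each timestep chooses how strongly to write to and read from the state.

```python
import numpy as np

rng = np.random.default_rng(0)
L, D, N = 6, 4, 3                  # sequence length, channels, state size

x = rng.normal(size=(L, D))        # input sequence
W_dt = rng.normal(size=(D, D))     # illustrative projection for Δ
W_B = rng.normal(size=(D, N))      # illustrative projection for B
W_C = rng.normal(size=(D, N))      # illustrative projection for C

delta = np.log1p(np.exp(x @ W_dt))  # softplus keeps the step size positive
B = x @ W_B                         # (L, N): input-dependent input matrix
C = x @ W_C                         # (L, N): input-dependent output matrix
```

When Δ is near zero the state is mostly preserved (the token is skipped); a large Δ overwrites the state with the current input.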
Architecture
Input [batch, seq_len, embed_dim]
│
▼
┌─────────────────────────────────────┐
│ Mamba Block │
│ │
│ ┌──── Linear (expand) ────┐ │
│ │ │ │ │
│ │ DepthwiseConv + SiLU │ │
│ │ │ │ │
│ │ Parallel Scan SSM Linear+SiLU │
│ │ │ │ │
│ └───────── multiply ───────┘ │
│ │ │
│ Linear (project) │
└─────────────────────────────────────┘
│
▼ (repeat for num_layers)

Usage
# Build Mamba backbone
model = Mamba.build(
embed_dim: 287,
hidden_size: 256,
state_size: 16,
num_layers: 2,
expand_factor: 2
)

References
- Paper: https://arxiv.org/abs/2312.00752
- Original code: https://github.com/state-spaces/mamba
Summary
Functions
Build a Mamba model for sequence processing.
Build depthwise separable 1D convolution layer.
Build a single Mamba block with parallel scan SSM.
Build the Selective SSM with parallel associative scan.
Get the output size of a Mamba model.
Calculate approximate parameter count for a Mamba model.
Get recommended defaults for real-time sequence processing (60fps).
Types
@type build_opt() :: {:embed_dim, pos_integer()} | {:hidden_size, pos_integer()} | {:state_size, pos_integer()} | {:expand_factor, pos_integer()} | {:conv_size, pos_integer()} | {:num_layers, pos_integer()} | {:dropout, float()} | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a Mamba model for sequence processing.
Options
- :embed_dim - Size of input embedding per frame (required)
- :hidden_size - Internal hidden dimension D (default: 256)
- :state_size - SSM state dimension N (default: 16)
- :expand_factor - Expansion factor E for inner dim (default: 2)
- :conv_size - 1D convolution kernel size (default: 4)
- :num_layers - Number of Mamba blocks (default: 2)
- :dropout - Dropout rate (default: 0.0)
- :window_size - Expected sequence length for JIT optimization (default: 60)
Returns
An Axon model that processes sequences and outputs the last hidden state.
@spec build_depthwise_conv1d(Axon.t(), pos_integer(), pos_integer(), String.t()) :: Axon.t()
Build depthwise separable 1D convolution layer.
Build a single Mamba block with parallel scan SSM.
Options
- :hidden_size - Internal dimension D
- :state_size - SSM state dimension N
- :expand_factor - Expansion factor E
- :conv_size - Convolution kernel size
- :name - Layer name prefix
Build the Selective SSM with parallel associative scan.
This is the core of Mamba: an SSM where A, B, C, Δ are input-dependent, computed efficiently using parallel scan.
The discretized SSM equations:
- A_bar = exp(Δ * A)
- B_bar = Δ * B (first-order approximation of the zero-order hold)
- h[t] = A_bar h[t-1] + B_bar x[t]
- y[t] = C * h[t]
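A minimal single-channel Python sketch of these equations (an illustration of the math, not the library's scan kernel; `selective_ssm` and the random test shapes are assumptions):

```python
import numpy as np

def selective_ssm(x, A, delta, B, C):
    # Sequential form of the discretized selective SSM, with diagonal A
    # of shape (N,) and per-timestep delta[t], B[t], C[t].
    h = np.zeros(A.shape[0])
    y = np.empty_like(x)
    for t in range(len(x)):
        A_bar = np.exp(delta[t] * A)     # A_bar = exp(Δ A)       (ZOH)
        B_bar = delta[t] * B[t]          # B_bar = Δ B            (first order)
        h = A_bar * h + B_bar * x[t]     # h[t] = A_bar h[t-1] + B_bar x[t]
        y[t] = C[t] @ h                  # y[t] = C h[t]
    return y

rng = np.random.default_rng(0)
L, N = 8, 4
A = -np.exp(rng.normal(size=N))          # negative real parts keep h stable
x = rng.normal(size=L)
delta = np.exp(rng.normal(size=L)) * 0.1  # positive step sizes
B = rng.normal(size=(L, N))
C = rng.normal(size=(L, N))
y = selective_ssm(x, A, delta, B, C)
```

Since Δ > 0 and A < 0, every entry of A_bar lies in (0, 1), so the recurrence is a leaky, input-gated accumulator; this loop is exactly what the parallel scan computes in O(log L) depth.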
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a Mamba model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for a Mamba model.
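For intuition, a rough back-of-envelope count for a single block under the reference Mamba layout (input/output projections, depthwise conv, Δ/B/C projections, A, and the skip parameter D). This is an assumption about the layout: the library's param_count/1 may break things down differently, and dt_rank = ceil(d_model / 16) is the reference default, not necessarily this module's.

```python
def mamba_block_params(d_model, state_size=16, expand=2, conv_size=4):
    # Approximate parameters in one reference-style Mamba block.
    d_inner = expand * d_model
    dt_rank = -(-d_model // 16)                    # ceil(d_model / 16)
    total = d_model * 2 * d_inner                  # in_proj (x and gate branches)
    total += d_inner * conv_size + d_inner         # depthwise conv1d + bias
    total += d_inner * (dt_rank + 2 * state_size)  # x_proj -> (Δ, B, C)
    total += dt_rank * d_inner + d_inner           # dt_proj + bias
    total += d_inner * state_size                  # A (stored as a log)
    total += d_inner                               # skip parameter D
    total += d_inner * d_model                     # out_proj
    return total
```

With the defaults above, hidden_size: 256 gives roughly 0.44M parameters per block, so the projections, not the SSM state, dominate the count.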
@spec recommended_defaults() :: keyword()
Get recommended defaults for real-time sequence processing (60fps).