Shared components for all Mamba architecture variants.
This module contains the common building blocks used across Mamba variants:
- Default hyperparameters
- Model structure builders (input projection, layer stacking, last timestep)
- Block structure (normalization, projections, gating)
- Depthwise convolution
- SSM parameter projections
- SSM discretization
- Utility functions
Mamba Variants
All variants share the same architecture, differing only in the scan algorithm:
| Variant | Scan Algorithm | Notes |
|---|---|---|
| Mamba | Blelloch | Work-efficient: O(L) work, O(log L) depth |
| MambaHillisSteele | Hillis-Steele | O(L log L) work, more parallelism |
| MambaCumsum | Cumsum-based | Experimental log-space approach |
| MambaSSD | SSD chunked | Mamba-2's matmul approach |
See Also
Edifice.SSM.Mamba - Main Mamba implementation
Summary
Functions
Blelloch parallel scan (work-efficient: O(L) work, O(log L) depth).
Build the common Mamba block structure.
Build a depthwise separable 1D convolution layer.
Build the common Mamba model structure.
Build the SSM parameter projections (B, C, dt).
Compute the SSM output from hidden states.
Default convolution kernel size
Default dropout rate
Default expansion factor E
Default hidden dimension D
Default number of Mamba blocks
Default SSM state dimension N
Discretize the SSM parameters for the scan.
Maximum delta
Minimum delta for numerical stability
Get the output size of a Mamba model.
Calculate approximate parameter count for a Mamba model.
Get recommended defaults for real-time sequence processing (60fps).
Sequential scan for short sequences or fallback.
Functions
@spec blelloch_scan(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Blelloch parallel scan (work-efficient: O(L) work, O(log L) depth).
Uses Enum.reduce for the level loop, which lets XLA JIT-compile each scan level efficiently.
Parameters
- `a` - Decay factors, shape `[batch, seq_len, hidden_size, state_size]`
- `b` - Input contributions, shape `[batch, seq_len, hidden_size, state_size]`
Returns
Hidden states [batch, seq_len, hidden_size, state_size]
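Both parallel variants exploit the fact that the step `h -> a*h + b` composes associatively: applying `(a1, b1)` then `(a2, b2)` equals applying `(a2*a1, a2*b1 + b2)`. The following NumPy sketch is illustrative only (not this module's Elixir/Nx code) and uses Hillis-Steele-style doubling, which is simpler to show than Blelloch's up-sweep/down-sweep but produces the same hidden states:

```python
import numpy as np

def combine(a1, b1, a2, b2):
    # Compose h -> a2*(a1*h + b1) + b2 = (a2*a1)*h + (a2*b1 + b2)
    return a2 * a1, a2 * b1 + b2

def doubling_scan(a, b):
    # Inclusive scan of h[t] = a[t]*h[t-1] + b[t] over axis 0, with h[-1] = 0.
    a, b = a.astype(float).copy(), b.astype(float).copy()
    step = 1
    while step < a.shape[0]:
        a_prev = np.ones_like(a)    # identity transform pads positions < step
        b_prev = np.zeros_like(b)
        a_prev[step:] = a[:-step]
        b_prev[step:] = b[:-step]
        a, b = combine(a_prev, b_prev, a, b)  # fold in the segment ending t-step
        step *= 2
    return b
```

With all decay factors equal to 1 the recurrence degenerates to a cumulative sum, which makes the scan easy to sanity-check.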
Build the common Mamba block structure.
This handles everything except the SSM scan itself:
- Layer normalization
- Input projection (to 2x inner_size for x/z branches)
- X/Z branch splitting
- Depthwise convolution + SiLU on X branch
- SiLU gating on Z branch
- Gated multiplication
- Output projection
The caller provides an ssm_builder function that constructs the SSM layer.
Parameters
- `input` - Input Axon node
- `opts` - Block options (hidden_size, state_size, expand_factor, conv_size, name)
- `ssm_builder` - Function `(x_activated, ssm_opts) -> Axon.t()` that builds the SSM
Returns
An Axon node representing the block output.
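Setting aside the layer norm and the depthwise convolution, the x/z gating dataflow above can be sketched in NumPy (hypothetical helper names; the real module builds this as an Axon graph):

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def mamba_block_sketch(x, w_in, w_out, ssm_fn):
    # x: [batch, seq_len, hidden]; w_in: [hidden, 2*inner]; w_out: [inner, hidden].
    # Layer norm and the depthwise conv on the x branch are elided for brevity.
    inner = w_out.shape[0]
    xz = x @ w_in                              # input projection to 2*inner
    xb, z = xz[..., :inner], xz[..., inner:]   # split into x/z branches
    y = ssm_fn(silu(xb))                       # SiLU + SSM on the x branch
    y = y * silu(z)                            # SiLU gating, then multiply
    return y @ w_out                           # output projection back to hidden
```

Passing an identity function for `ssm_fn` shows how the block shell is independent of the scan variant plugged into it.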
@spec build_depthwise_conv1d(Axon.t(), pos_integer(), pos_integer(), String.t()) :: Axon.t()
Build a depthwise separable 1D convolution layer.
True Mamba uses a learned depthwise convolution, not mean pooling; this layer approximates that depthwise convolution behavior for SSM input processing.
Parameters
- `input` - Input Axon node, shape `[batch, seq_len, channels]`
- `channels` - Number of output channels
- `kernel_size` - Convolution kernel size
- `name` - Layer name prefix
Build the common Mamba model structure.
This handles:
- Input projection (if embed_dim != hidden_size)
- Layer stacking with residual connections and dropout
- Last timestep extraction
The caller provides a block_builder function that constructs each Mamba block.
Parameters
- `opts` - Model options (embed_dim, hidden_size, num_layers, dropout, etc.)
- `block_builder` - Function `(input, opts) -> Axon.t()` that builds one block
Returns
An Axon model that outputs [batch, hidden_size].
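The stacking logic amounts to a residual connection around each block followed by last-timestep pooling. A minimal NumPy sketch, with dropout and the optional input projection omitted:

```python
import numpy as np

def mamba_model_sketch(x, blocks):
    # x: [batch, seq_len, hidden]; blocks: fns mapping [B, L, H] -> [B, L, H]
    for block in blocks:
        x = x + block(x)   # residual connection around each Mamba block
    return x[:, -1, :]     # last-timestep extraction -> [batch, hidden]
```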
Build the SSM parameter projections (B, C, dt).
These are the "selective" parameters that make Mamba input-dependent:
- B: Input matrix, `[batch, seq_len, state_size]`
- C: Output matrix, `[batch, seq_len, state_size]`
- dt: Discretization step, `[batch, seq_len, hidden_size]`
Parameters
- `input` - Input Axon node, shape `[batch, seq_len, hidden_size]`
- `opts` - Options (hidden_size, state_size, dt_rank, name)
Returns
Tuple of {b_matrix, c_matrix, dt_proj} Axon nodes.
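In essence these are three linear maps from the hidden states, with dt squashed to a positive step size. A NumPy sketch under two assumptions not stated above: dt uses a softplus activation (the standard Mamba choice), and the low-rank dt bottleneck implied by `dt_rank` is folded into a single matrix:

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def ssm_projections_sketch(x, w_b, w_c, w_dt):
    # x: [batch, seq_len, hidden]
    B = x @ w_b              # [batch, seq_len, state] - input matrix
    C = x @ w_c              # [batch, seq_len, state] - output matrix
    dt = softplus(x @ w_dt)  # [batch, seq_len, hidden] - positive step size
    return B, C, dt
```

Because B, C, and dt are all functions of the input `x`, the resulting SSM is "selective": each timestep gets its own transition parameters.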
@spec compute_ssm_output(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Compute the SSM output from hidden states.
y[t] = C[t] * h[t]
Parameters
- `h` - Hidden states, shape `[batch, seq_len, hidden_size, state_size]`
- `c` - C matrix, shape `[batch, seq_len, state_size]`
Returns
Output tensor [batch, seq_len, hidden_size]
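The readout is a contraction over the state dimension at every timestep; in NumPy it is a single einsum (illustrative sketch, not the module's Nx code):

```python
import numpy as np

def compute_ssm_output_sketch(h, c):
    # h: [batch, seq_len, hidden, state]; c: [batch, seq_len, state]
    # y[t] = C[t] . h[t], contracting over the state dimension
    return np.einsum('blhn,bln->blh', h, c)
```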
@spec default_conv_size() :: pos_integer()
Default convolution kernel size
@spec default_dropout() :: float()
Default dropout rate
@spec default_expand_factor() :: pos_integer()
Default expansion factor E
@spec default_num_layers() :: pos_integer()
Default number of Mamba blocks
@spec default_state_size() :: pos_integer()
Default SSM state dimension N
@spec discretize_ssm(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), pos_integer()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Discretize the SSM parameters for the scan.
Converts continuous-time SSM to discrete-time:
- A_bar = exp(Δ * A)
- B_bar = Δ * B
- Bx = B_bar * x
Parameters
- `x` - Input tensor, shape `[batch, seq_len, hidden_size]`
- `b` - B matrix, shape `[batch, seq_len, state_size]`
- `dt` - Delta tensor, shape `[batch, seq_len, hidden_size]`
- `state_size` - SSM state dimension
Returns
Tuple of {a_bar, bx} where:
- `a_bar`: `[batch, seq_len, hidden_size, state_size]` - decay factors
- `bx`: `[batch, seq_len, hidden_size, state_size]` - input contributions
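The discretization formulas are mostly broadcasting. A NumPy sketch, with one assumption made explicit: the continuous-time matrix A is passed in directly here, whereas the module presumably derives it internally from `state_size`:

```python
import numpy as np

def discretize_sketch(x, b, dt, A):
    # x: [B, L, H]; b: [B, L, N]; dt: [B, L, H]; A: [H, N], negative for decay.
    a_bar = np.exp(dt[..., None] * A)                     # A_bar = exp(dt * A)
    bx = dt[..., None] * b[:, :, None, :] * x[..., None]  # Bx = (dt * B) * x
    return a_bar, bx
```

With negative A and positive dt, every entry of `a_bar` lands in (0, 1), so the scan's hidden state decays rather than explodes.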
@spec dt_max() :: float()
Maximum delta
@spec dt_min() :: float()
Minimum delta for numerical stability
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a Mamba model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for a Mamba model.
@spec recommended_defaults() :: keyword()
Get recommended defaults for real-time sequence processing (60fps).
@spec sequential_scan(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Sequential scan for short sequences or fallback.
Computes h[t] = a[t] * h[t-1] + b[t] for all t.
Parameters
- `a` - Decay factors, shape `[batch, seq_len, hidden_size, state_size]`
- `b` - Input contributions, shape `[batch, seq_len, hidden_size, state_size]`
Returns
Hidden states [batch, seq_len, hidden_size, state_size]
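The sequential fallback is a single loop over time. A NumPy sketch of the recurrence (illustrative only; the module's version runs inside Nx):

```python
import numpy as np

def sequential_scan_sketch(a, b):
    # a, b: [batch, seq_len, hidden, state]; initial state h[-1] = 0
    h = np.zeros_like(b)
    prev = np.zeros_like(b[:, 0])
    for t in range(b.shape[1]):
        prev = a[:, t] * prev + b[:, t]   # h[t] = a[t]*h[t-1] + b[t]
        h[:, t] = prev
    return h
```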