Mamba-3: Advanced Selective State Space Model with complex state dynamics.
Extends the Mamba architecture with three key innovations from "Mamba-3: Advancing State Space Models" for improved expressiveness and efficiency.
Key Innovations
1. Complex-Valued State Dynamics
State dimensions are paired and rotated by data-dependent angles (θ), similar to how RoPE encodes position. Rather than using complex-typed tensors, this is implemented with real-valued 2x2 rotation matrices applied to each pair of dimensions:
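$$
\begin{bmatrix} h_{2i} \\ h_{2i+1} \end{bmatrix}
\leftarrow
\begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix}
\begin{bmatrix} h_{2i} \\ h_{2i+1} \end{bmatrix} \cdot \mathrm{decay} + \mathrm{input}
$$

2. Generalized Trapezoidal Discretization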
Instead of Euler discretization, Mamba-3 uses a weighted blend of the current and previous inputs, controlled by a data-dependent λ:
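$$
h_t = \bar{A}\, h_{t-1} + \lambda\, \mathrm{dt}\, B_t x_t + (1-\lambda)\, \mathrm{dt}\, \bar{A}\, B_{t-1} x_{t-1}
$$

Setting λ = 1 recovers the Euler update $h_t = \bar{A}\, h_{t-1} + \mathrm{dt}\, B_t x_t$. Setting λ = 1/2 weights the two inputs equally, as in the classical trapezoidal rule. This blend reduces discretization error and improves long-range modeling.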
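The rotation in section 1 describes how Ā acts on each state pair: decay times R(θ). That means the first two innovations can be sketched together on a single pair. The following is plain Elixir with floats and tuples rather than Nx tensors. The module name, the per-step parameter map, and the zero "previous input" at the first step are assumptions for illustration, not the library's code.

```elixir
defmodule Mamba3PairSketch do
  # Hypothetical sketch, not the library's implementation: one 2-D state pair,
  # assuming A_bar acts on the pair as decay * R(theta).

  # Apply A_bar = decay * R(theta) to the pair {a, b}.
  def apply_a_bar({a, b}, theta, decay) do
    c = :math.cos(theta)
    s = :math.sin(theta)
    {decay * (c * a - s * b), decay * (s * a + c * b)}
  end

  # One generalized trapezoidal step:
  #   h_t = A_bar h_{t-1} + lambda*dt*B_t*x_t + (1 - lambda)*dt*A_bar*B_{t-1}*x_{t-1}
  # Inputs are {b, x} with b = {b0, b1} (this pair's entries of B) and x a scalar.
  def step(h_prev, {{bt0, bt1}, x_t}, {{bp0, bp1}, x_prev}, params) do
    %{theta: theta, decay: decay, lambda: lambda, dt: dt} = params
    {h0, h1} = apply_a_bar(h_prev, theta, decay)
    # A_bar applied to the previous input B_{t-1} * x_{t-1}
    {q0, q1} = apply_a_bar({bp0 * x_prev, bp1 * x_prev}, theta, decay)

    {h0 + lambda * dt * bt0 * x_t + (1 - lambda) * dt * q0,
     h1 + lambda * dt * bt1 * x_t + (1 - lambda) * dt * q1}
  end

  # Scan a list of %{b:, x:, theta:, decay:, lambda:, dt:} maps, starting from
  # h_0 = {0, 0} with a zero "previous input" at the first step.
  def scan(steps) do
    init = {{0.0, 0.0}, {{0.0, 0.0}, 0.0}, []}

    {_h, _prev, states} =
      Enum.reduce(steps, init, fn s, {h, prev, acc} ->
        cur = {s.b, s.x}
        h_new = step(h, cur, prev, s)
        {h_new, cur, [h_new | acc]}
      end)

    Enum.reverse(states)
  end
end
```

With `lambda: 1.0`, the `(1 - lambda)` term vanishes and `step/4` reduces to the Euler update. With `lambda: 0.5`, the current and previous inputs contribute equally.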
3. MIMO Rank-r Updates
Replaces the rank-1 outer product B * x^T with a rank-r product B_r @ X_r^T, increasing arithmetic intensity for better hardware utilization on modern GPUs/TPUs.
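The rank-r product equals the sum of r rank-1 outer products, one per matching column of B_r and X_r, evaluated as a single matrix multiply. That packing is roughly where the extra arithmetic intensity comes from. The sketch below checks this identity on small integers. It uses plain Elixir lists (rows as lists) rather than Nx tensors. The module name and the shapes (B_r as N x r, X_r as P x r) are illustrative assumptions.

```elixir
defmodule MimoSketch do
  # Hypothetical sketch: a rank-1 outer product versus a rank-r update
  # B_r @ X_r^T, using nested lists (each inner list is a row).

  # Rank-1 update: outer(b, x)[i][j] = b[i] * x[j]
  def outer(b, x), do: for(bi <- b, do: for(xj <- x, do: bi * xj))

  # Rank-r update: B_r is N x r, X_r is P x r; the result is N x P with
  # entry (i, j) = sum over k of B_r[i][k] * X_r[j][k].
  def rank_r(b_r, x_r) do
    for b_row <- b_r do
      for x_row <- x_r do
        Enum.zip_with(b_row, x_row, &(&1 * &2)) |> Enum.sum()
      end
    end
  end

  # Same result computed as a sum of r outer products of matching columns.
  def sum_of_outers(b_r, x_r) do
    Enum.zip_with(transpose(b_r), transpose(x_r), &outer/2)
    |> Enum.reduce(&add/2)
  end

  defp transpose(m), do: Enum.zip_with(m, & &1)

  # Element-wise sum of two equally shaped matrices.
  defp add(a, b), do: Enum.zip_with(a, b, fn ra, rb -> Enum.zip_with(ra, rb, &+/2) end)
end
```

For example, `MimoSketch.rank_r([[1, 2], [3, 4]], [[5, 6], [7, 8], [9, 10]])` returns `[[17, 23, 29], [39, 53, 67]]`, the same as `sum_of_outers/2` on those inputs. With r = 1 it collapses to `outer/2`.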
Architecture
Same gated block structure as Mamba, with the enhanced SSM core:
```
Input [batch, seq_len, embed_dim]
      │
      ▼
┌──────────────────────────────────────────────┐
│                Mamba-3 Block                 │
│  ┌──────── Linear (expand) ────────┐         │
│  │                                 │         │
│  ▼                                 │         │
│  DepthwiseConv + SiLU              │         │
│  │                                 │         │
│  ▼                                 ▼         │
│  Complex SSM + Trap.         Linear + SiLU   │
│  + MIMO rank-r                     │         │
│  │                                 │         │
│  └─────────── multiply ────────────┘         │
│                   │                          │
│                   ▼                          │
│            Linear (project)                  │
└──────────────────────────────────────────────┘
      │
      ▼  (repeat for num_layers)
```

Usage
```elixir
model = Mamba3.build(
  embed_dim: 287,
  hidden_size: 256,
  state_size: 16,
  num_layers: 2,
  rank: 4,
  complex: true
)
```

References
- "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" (Gu & Dao, 2023)
- "Transformers are SSMs: Generalized Models and Efficient Algorithms" (Dao & Gu, 2024)
Types
@type build_opt() :: {:complex, boolean()} | {:conv_size, pos_integer()} | {:dropout, float()} | {:embed_dim, pos_integer()} | {:expand_factor, pos_integer()} | {:hidden_size, pos_integer()} | {:num_layers, pos_integer()} | {:rank, pos_integer()} | {:state_size, pos_integer()} | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a Mamba-3 model for sequence processing.
Options
- `:embed_dim` - Size of input embedding per frame (required)
- `:hidden_size` - Internal hidden dimension D (default: 256)
- `:state_size` - SSM state dimension N (default: 16)
- `:expand_factor` - Expansion factor E for the inner dimension (default: 2)
- `:conv_size` - 1D convolution kernel size (default: 4)
- `:num_layers` - Number of Mamba-3 blocks (default: 2)
- `:dropout` - Dropout rate (default: 0.0)
- `:window_size` - Expected sequence length (default: 60)
- `:rank` - MIMO rank for input updates (default: 4)
- `:complex` - Enable complex-valued state dynamics (default: true)
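Some of these options interact. E and D together set the width of the expanded branch, and pairing state dimensions for complex dynamics presumably requires an even N. The hypothetical helper below is not part of this module's API. It merges caller options over the documented defaults and derives `inner_dim = expand_factor * hidden_size`. The even-N check is an assumption inferred from the pairing described under Key Innovations.

```elixir
defmodule Mamba3OptsSketch do
  # Hypothetical helper (not part of the library): resolve build/1-style
  # options against the documented defaults and derive the inner dimension.

  @defaults [
    hidden_size: 256,
    state_size: 16,
    expand_factor: 2,
    conv_size: 4,
    num_layers: 2,
    dropout: 0.0,
    window_size: 60,
    rank: 4,
    complex: true
  ]

  def resolve(opts) do
    # :embed_dim has no default; Keyword.fetch!/2 raises KeyError if it is absent.
    _ = Keyword.fetch!(opts, :embed_dim)
    merged = Keyword.merge(@defaults, opts)

    # Assumption: complex dynamics pair up state dimensions, so N must be even.
    if merged[:complex] and rem(merged[:state_size], 2) != 0 do
      raise ArgumentError, "state_size must be even when complex: true"
    end

    # Inner dimension of the expanded branch: E * D.
    Keyword.put(merged, :inner_dim, merged[:expand_factor] * merged[:hidden_size])
  end
end
```

With the defaults, `Mamba3OptsSketch.resolve(embed_dim: 287)` yields `inner_dim: 512`.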
Returns
An Axon model that processes sequences and outputs the last hidden state.
Build a single Mamba-3 block with enhanced SSM.
Build the Mamba-3 SSM with complex dynamics, trapezoidal discretization, and MIMO rank-r updates.
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a Mamba-3 model.
@spec recommended_defaults() :: keyword()
Get recommended defaults for Mamba-3.