Edifice.SSM.Mamba3 (Edifice v0.2.0)


Mamba-3: Advanced Selective State Space Model with complex state dynamics.

Extends the Mamba architecture with three key innovations from "Mamba-3: Advancing State Space Models" for improved expressiveness and efficiency.

Key Innovations

1. Complex-Valued State Dynamics

State dimensions are paired and rotated by data-dependent angles (theta), similar to how RoPE encodes position. The rotation is implemented with real-valued 2x2 rotation matrices on paired dimensions, so no complex tensors are required:

[h_{2i}  ]       [cos(θ_t)  -sin(θ_t)] [h_{2i}  ]
[h_{2i+1}]_t  =  [sin(θ_t)   cos(θ_t)] [h_{2i+1}]_{t-1}  * decay + input
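The update for a single state pair can be sketched in plain Elixir. This is a minimal illustration of the rotation formula above, not Edifice's implementation: the module name `RotationSketch`, the function names, and the use of scalar floats instead of Nx tensors are all assumptions made for the example.

```elixir
defmodule RotationSketch do
  # Rotate one state pair {h0, h1} by angle theta, scale the result by
  # `decay`, then add the per-dimension input {u0, u1}.
  # (Hypothetical helper; names are illustrative only.)
  def step({h0, h1}, theta, decay, {u0, u1}) do
    c = :math.cos(theta)
    s = :math.sin(theta)

    {decay * (c * h0 - s * h1) + u0, decay * (s * h0 + c * h1) + u1}
  end

  # Euclidean norm of a pair, used to check that rotation preserves length.
  def norm({a, b}), do: :math.sqrt(a * a + b * b)
end

# A quarter turn with no decay and no input maps {1, 0} to roughly {0, 1}.
{x, y} = RotationSketch.step({1.0, 0.0}, :math.pi() / 2, 1.0, {0.0, 0.0})
IO.inspect({Float.round(x, 6), Float.round(y, 6)})
```

Because the rotation itself preserves the length of the pair, only `decay` changes the magnitude of the carried-over state: with zero input, the norm after one step is `decay` times the norm before it.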

2. Generalized Trapezoidal Discretization

Instead of Euler discretization, uses a weighted blend of current and previous inputs controlled by a data-dependent lambda:

h_t = A_bar * h_{t-1} + λ * dt * B_t * x_t + (1-λ) * dt * A_bar * B_{t-1} * x_{t-1}

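Setting λ = 1 drops the previous-input term and recovers the Euler-style update h_t = A_bar * h_{t-1} + dt * B_t * x_t. Blending both inputs reduces discretization error relative to that single-point step, which is intended to improve long-range modeling.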
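The recurrence above can be checked on scalars. The sketch below is a plain-Elixir illustration under stated assumptions: `TrapezoidSketch`, the parameter map, and the sample values are hypothetical, and every quantity is a single float rather than a tensor.

```elixir
defmodule TrapezoidSketch do
  # One scalar step of the blended update:
  #   h_t = a_bar * h_prev
  #         + lambda * dt * b * x
  #         + (1 - lambda) * dt * a_bar * b_prev * x_prev
  # (Hypothetical helper; names are illustrative only.)
  def step(h_prev, p) do
    current = p.lambda * p.dt * p.b * p.x
    previous = (1.0 - p.lambda) * p.dt * p.a_bar * p.b_prev * p.x_prev
    p.a_bar * h_prev + current + previous
  end
end

params = %{a_bar: 0.5, dt: 0.1, b: 2.0, x: 3.0, b_prev: 1.0, x_prev: 4.0}

# lambda = 1.0 keeps only the current input (Euler-style step).
euler = TrapezoidSketch.step(1.0, Map.put(params, :lambda, 1.0))

# lambda = 0.5 weights the current and previous inputs equally.
blended = TrapezoidSketch.step(1.0, Map.put(params, :lambda, 0.5))

IO.inspect({Float.round(euler, 6), Float.round(blended, 6)})
```

With these sample values the two results are roughly 1.1 and 0.9, which shows how much the previous input shifts the state once λ moves away from 1.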

3. MIMO Rank-r Updates

Replaces the rank-1 outer product B * x^T with a rank-r product B_r @ X_r^T, increasing arithmetic intensity for better hardware utilization on modern GPUs/TPUs.
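To make the difference between the two updates concrete, the sketch below uses nested Elixir lists instead of Nx tensors. The module name `MimoSketch` and the shapes are assumptions for this example: `b_r` is taken to be N x r and `x_r` to be P x r, where N is the state size, P the number of channels, and r the rank.

```elixir
defmodule MimoSketch do
  # Rank-1 update: outer product of b (length N) and x (length P),
  # giving an N x P matrix.
  def outer(b, x) do
    for bi <- b, do: for(xj <- x, do: bi * xj)
  end

  # Rank-r update: b_r is N x r and x_r is P x r (assumed shapes).
  # The result is b_r @ x_r^T, an N x P matrix whose entries are dot
  # products over the rank dimension.
  def rank_r(b_r, x_r) do
    for b_row <- b_r do
      for x_row <- x_r do
        Enum.zip(b_row, x_row)
        |> Enum.map(fn {b, x} -> b * x end)
        |> Enum.sum()
      end
    end
  end
end

# With r = 1 the rank-r product reduces to the outer product.
IO.inspect(MimoSketch.outer([1.0, 2.0], [3.0, 4.0, 5.0]))
IO.inspect(MimoSketch.rank_r([[1.0], [2.0]], [[3.0], [4.0], [5.0]]))

# With r = 2 each entry sums two products.
IO.inspect(MimoSketch.rank_r([[1.0, 0.5], [2.0, -1.0]], [[3.0, 2.0], [4.0, 0.0], [5.0, -2.0]]))
```

The output shape is N x P in both cases. The rank-r form does r multiply-adds per output entry, which is where the higher arithmetic intensity comes from.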

Architecture

Same gated block structure as Mamba, with the enhanced SSM core:

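Input [batch, seq_len, embed_dim]
      │
      ▼
┌─────────────────────────────────────────┐
│              Mamba-3 Block              │
│   ┌────── Linear (expand) ──────┐       │
│   ▼                             ▼       │
│ DepthwiseConv + SiLU            │       │
│   ▼                             │       │
│ Complex SSM + Trap.        Linear+SiLU  │
│ + MIMO rank-r                   │       │
│   │                             │       │
│   └───────── multiply ──────────┘       │
│                  ▼                      │
│          Linear (project)               │
└─────────────────────────────────────────┘
      │
      ▼ (repeat for num_layers)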

Usage

model = Mamba3.build(
  embed_dim: 287,
  hidden_size: 256,
  state_size: 16,
  num_layers: 2,
  rank: 4,
  complex: true
)

References

  • "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" (Gu & Dao, 2023)
  • "Transformers are SSMs: Generalized Models and Efficient Algorithms" (Dao & Gu, 2024)

Summary

Types

build_opt()

Options for build/1.

Functions

build(opts \\ [])

Build a Mamba-3 model for sequence processing.

build_mamba3_block(input, opts \\ [])

Build a single Mamba-3 block with enhanced SSM.

build_mamba3_ssm(input, opts \\ [])

Build the Mamba-3 SSM with complex dynamics, trapezoidal discretization, and MIMO rank-r updates.

output_size(opts \\ [])

Get the output size of a Mamba-3 model.

Get recommended defaults for Mamba-3.

Types

build_opt()

@type build_opt() ::
  {:complex, boolean()}
  | {:conv_size, pos_integer()}
  | {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:rank, pos_integer()}
  | {:state_size, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Mamba-3 model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension D (default: 256)
  • :state_size - SSM state dimension N (default: 16)
  • :expand_factor - Expansion factor E for inner dim (default: 2)
  • :conv_size - 1D convolution kernel size (default: 4)
  • :num_layers - Number of Mamba-3 blocks (default: 2)
  • :dropout - Dropout rate (default: 0.0)
  • :window_size - Expected sequence length (default: 60)
  • :rank - MIMO rank for input updates (default: 4)
  • :complex - Enable complex-valued state dynamics (default: true)
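The way the documented defaults combine with caller-supplied options can be sketched with `Keyword.validate!/2` (Elixir 1.13 or later). This is an illustration only: `Mamba3OptsSketch` and `resolve/1` are hypothetical names, and Edifice may handle its options differently.

```elixir
defmodule Mamba3OptsSketch do
  # Defaults copied from the option list above. :embed_dim has no default
  # because it is documented as required.
  @defaults [
    hidden_size: 256,
    state_size: 16,
    expand_factor: 2,
    conv_size: 4,
    num_layers: 2,
    dropout: 0.0,
    window_size: 60,
    rank: 4,
    complex: true
  ]

  # Hypothetical helper: fills in defaults, raises ArgumentError on
  # unknown keys, and raises KeyError when :embed_dim is missing.
  def resolve(opts) do
    opts = Keyword.validate!(opts, [:embed_dim | @defaults])
    _ = Keyword.fetch!(opts, :embed_dim)
    opts
  end
end

# Only :embed_dim and :rank are given; everything else falls back to defaults.
resolved = Mamba3OptsSketch.resolve(embed_dim: 287, rank: 8)
IO.inspect(Keyword.take(resolved, [:embed_dim, :rank, :hidden_size]))
```

Validating up front means a misspelled key such as `:hidden_dim` fails immediately instead of being silently ignored.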

Returns

An Axon model that processes sequences and outputs the last hidden state.

build_mamba3_block(input, opts \\ [])

@spec build_mamba3_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single Mamba-3 block with enhanced SSM.

build_mamba3_ssm(input, opts \\ [])

@spec build_mamba3_ssm(
  Axon.t(),
  keyword()
) :: Axon.t()

Build the Mamba-3 SSM with complex dynamics, trapezoidal discretization, and MIMO rank-r updates.

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a Mamba-3 model.