Edifice.SSM.Mamba (Edifice v0.2.0)


Mamba: True Selective State Space Model with optimized parallel scan.

Implements the Mamba architecture from "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" (Gu & Dao, 2023).

Key Innovation: Parallel Associative Scan

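The SSM recurrence h[t] = A[t] * h[t-1] + B[t] * x[t] looks sequential, but it can be parallelized because composing its steps is associative. Represent each step as a pair (a, b) standing for the affine map h ↦ a*h + b:

Define: (a, b) ∘ (c, d) = (a*c, a*d + b), i.e. apply (c, d) first and then (a, b), since a*(c*h + d) + b = (a*c)*h + (a*d + b).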

Then the scan:
  h[0] = B[0] * x[0]
  h[1] = A[1] * h[0] + B[1] * x[1]
  h[2] = A[2] * h[1] + B[2] * x[2]
  ...

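This scan can be computed in O(log L) parallel steps, where L is the sequence length, by running a prefix scan with ∘ over the pairs (A[t], B[t] * x[t]) and reading off the second component of each prefix (taking h[-1] = 0).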
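The equivalence between the sequential recurrence and a logarithmic-depth scan can be checked with a small pure-Elixir sketch using the combine rule (a, b), (c, d) → (a*c, a*d + b). It is not Edifice's implementation (which works on Nx tensors); the module name `ScanSketch` is hypothetical, and the parallel version uses a Hillis-Steele style scan, one of several prefix-scan schemes. Each pass of the `for` comprehension updates every position independently, so on parallel hardware it is a single step, and there are ⌈log2 L⌉ passes.

```elixir
defmodule ScanSketch do
  # (a, b) ∘ (c, d) = (a*c, a*d + b): apply {c, d} first, then {a, b}.
  def combine({a, b}, {c, d}), do: {a * c, a * d + b}

  # Sequential reference: h[t] = a[t] * h[t-1] + b[t], starting from h[-1] = 0.
  def sequential(pairs) do
    Enum.scan(pairs, 0, fn {a, b}, h -> a * h + b end)
  end

  # Log-depth inclusive scan (Hillis-Steele style). After the pass with offset k,
  # position i holds the composition of steps max(0, i - 2k + 1) .. i.
  def parallel(pairs) do
    do_parallel(List.to_tuple(pairs), 1, length(pairs))
  end

  # Once offset >= n, position i holds the composition of steps 0..i;
  # applied to h[-1] = 0, its second component is h[i].
  defp do_parallel(elems, offset, n) when offset >= n do
    elems |> Tuple.to_list() |> Enum.map(fn {_a, b} -> b end)
  end

  defp do_parallel(elems, offset, n) do
    # Every position is updated independently of the others in this pass,
    # which is what makes the pass parallelizable.
    next =
      for i <- 0..(n - 1) do
        if i >= offset,
          do: combine(elem(elems, i), elem(elems, i - offset)),
          else: elem(elems, i)
      end

    do_parallel(List.to_tuple(next), offset * 2, n)
  end
end
```

With integer pairs the two versions agree exactly: for `[{2, 1}, {3, 4}, {1, 5}, {2, 0}, {5, 1}]` both return `[1, 7, 12, 24, 121]`. With floats they can differ in the last bits, because the parallel version groups the multiplications differently.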

Selective Mechanism

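Unlike linear time-invariant (LTI) SSMs, whose parameters are the same at every timestep, Mamba computes Δ, B, and C from the input. A itself is a learned parameter, but its discretization A_bar = exp(Δ * A) becomes input-dependent through Δ:

  • Δ (discretization step) controls how much the state is updated versus retained at each step
  • B (input matrix) projects the input into the state space
  • C (output matrix) projects the state to the output
  • Because these are computed from the input, the model can selectively keep or discard information per token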
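The effect of an input-dependent Δ can be sketched for a single channel with a scalar A. The weights below (`w = 2.0`, `b = -1.0`, `a = -1.0`) and the module name are hypothetical; in Mamba, Δ comes from a learned projection of the input followed by softplus, and A is kept negative so that A_bar stays between 0 and 1.

```elixir
defmodule SelectiveDeltaSketch do
  # softplus(z) = log(1 + e^z) is always positive, so Δ > 0.
  def softplus(z), do: :math.log(1.0 + :math.exp(z))

  # Hypothetical Δ projection for one channel: Δ(x) = softplus(w * x + b).
  def delta(x, w \\ 2.0, b \\ -1.0), do: softplus(w * x + b)

  # Discretized decay A_bar = exp(Δ(x) * A); with A < 0 it lies in (0, 1).
  def a_bar(x, a \\ -1.0), do: :math.exp(delta(x) * a)
end
```

A strongly positive input such as `x = 3.0` gives Δ ≈ 5.0 and A_bar ≈ 0.007, so the previous state is almost entirely replaced; `x = -3.0` gives Δ ≈ 0.0009 and A_bar ≈ 0.999, so the state is carried forward nearly unchanged.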

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
  Mamba Block
    Linear (expand)
    ├── DepthwiseConv + SiLU -> Parallel Scan SSM
    └── Linear + SiLU
    multiply (SSM branch * gated branch)
    Linear (project)
      |
      v
  (repeat for num_layers)
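The `multiply` node combines the two branches elementwise. A minimal sketch over plain lists, assuming the gate branch's linear output goes through SiLU before the product as the diagram indicates; `GateSketch` is a hypothetical name, and the real layer operates on tensors rather than lists.

```elixir
defmodule GateSketch do
  # SiLU (swish): x * sigmoid(x), written as x / (1 + e^-x).
  def silu(x), do: x / (1.0 + :math.exp(-x))

  # Elementwise product of the SSM branch output with SiLU(gate pre-activation).
  def gate(ssm_out, gate_pre) do
    Enum.zip_with(ssm_out, gate_pre, fn y, z -> y * silu(z) end)
  end
end
```

A strongly negative gate pre-activation drives the product toward zero (`silu(-10.0)` is about -0.00045), while a positive one passes the SSM output through, scaled by roughly the pre-activation itself.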

Usage

alias Edifice.SSM.Mamba

# Build a Mamba backbone
model = Mamba.build(
  embed_dim: 287,
  hidden_size: 256,
  state_size: 16,
  num_layers: 2,
  expand_factor: 2
)


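Summary

Types

build_opt()
Options for build/1.

Functions

build(opts \\ [])
Build a Mamba model for sequence processing.

build_depthwise_conv1d(input, channels, kernel_size, name)
Build a depthwise separable 1D convolution layer.

build_mamba_block(input, opts \\ [])
Build a single Mamba block with a parallel-scan SSM.

build_selective_ssm_parallel(input, opts \\ [])
Build the selective SSM with a parallel associative scan.

output_size(opts \\ [])
Get the output size of a Mamba model.

param_count(opts)
Calculate the approximate parameter count for a Mamba model.

Get recommended defaults for real-time sequence processing (60fps).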

Types

build_opt()

@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:state_size, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:conv_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Mamba model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension D (default: 256)
  • :state_size - SSM state dimension N (default: 16)
  • :expand_factor - Expansion factor E for inner dim (default: 2)
  • :conv_size - 1D convolution kernel size (default: 4)
  • :num_layers - Number of Mamba blocks (default: 2)
  • :dropout - Dropout rate (default: 0.0)
  • :window_size - Expected sequence length for JIT optimization (default: 60)

Returns

An Axon model that processes a sequence and outputs the hidden state at the final timestep.

build_depthwise_conv1d(input, channels, kernel_size, name)

@spec build_depthwise_conv1d(Axon.t(), pos_integer(), pos_integer(), String.t()) ::
  Axon.t()

Build a depthwise separable 1D convolution layer.
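What "depthwise" means here can be sketched on plain lists: each channel is convolved with its own kernel, with no mixing across channels. This sketch covers only the depthwise part (the pointwise mixing of a separable convolution is omitted), and the causal left-padding is an assumption based on the Mamba paper, not confirmed for Edifice; the module name is hypothetical.

```elixir
defmodule DepthwiseConvSketch do
  @doc """
  Causal depthwise 1D convolution.
  `channels` is a list of per-channel sequences and `kernels` holds one kernel
  per channel. The output at time t uses only inputs t - k + 1 .. t, because
  each sequence is left-padded with k - 1 zeros.
  """
  def conv(channels, kernels) do
    Enum.zip_with(channels, kernels, &conv_channel/2)
  end

  defp conv_channel(xs, kernel) do
    k = length(kernel)
    padded = List.duplicate(0.0, k - 1) ++ xs

    padded
    # Sliding windows of length k with stride 1: one window per output step.
    |> Enum.chunk_every(k, 1, :discard)
    # Cross-correlation: kernel[0] multiplies the oldest element of the window.
    |> Enum.map(fn window -> window |> Enum.zip_with(kernel, &(&1 * &2)) |> Enum.sum() end)
  end
end
```

For example, with kernel `[0.0, 1.0]` the first channel passes through unchanged, and with `[1.0, 1.0]` the second channel becomes a causal two-step sum: `conv([[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]])` returns `[[1.0, 2.0, 3.0], [1.0, 2.0, 2.0]]`.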

build_mamba_block(input, opts \\ [])

@spec build_mamba_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single Mamba block with a parallel-scan SSM.

Options

  • :hidden_size - Internal dimension D
  • :state_size - SSM state dimension N
  • :expand_factor - Expansion factor E
  • :conv_size - Convolution kernel size
  • :name - Layer name prefix

build_selective_ssm_parallel(input, opts \\ [])

@spec build_selective_ssm_parallel(
  Axon.t(),
  keyword()
) :: Axon.t()

Build the selective SSM with a parallel associative scan.

This is the core of Mamba: an SSM whose Δ, B, and C (and therefore the discretized A_bar) are input-dependent, computed efficiently with a parallel scan.

The discretized SSM equations:

  • A_bar = exp(Δ * A)
  • B_bar = Δ * B
  • h[t] = A_bar h[t-1] + B_bar x[t]
  • y[t] = C * h[t]
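The discretized equations above can be traced step by step with a sequential reference for one input channel and a diagonal A, which is how Mamba parameterizes it. This is an illustrative sketch (the module name `SelectiveScanSketch` is hypothetical), not the parallel scan Edifice runs; a parallel implementation should reproduce its outputs up to floating-point rounding.

```elixir
defmodule SelectiveScanSketch do
  @doc """
  Sequential selective scan for one channel with a diagonal N-dimensional state.

    * `xs`       - inputs x[t]
    * `deltas`   - step sizes Δ[t], one per timestep
    * `a`        - diagonal of A (N floats)
    * `bs`, `cs` - B[t] and C[t], one N-element list per timestep
  """
  def run(xs, deltas, a, bs, cs) do
    h0 = List.duplicate(0.0, length(a))

    [xs, deltas, bs, cs]
    |> Enum.zip()
    |> Enum.map_reduce(h0, fn {x, delta, b, c}, h ->
      # A_bar = exp(Δ * A), B_bar = Δ * B, h[t] = A_bar h[t-1] + B_bar x[t]
      h_next =
        Enum.zip_with([a, b, h], fn [a_n, b_n, h_n] ->
          :math.exp(delta * a_n) * h_n + delta * b_n * x
        end)

      # y[t] = C[t] · h[t]
      y = c |> Enum.zip_with(h_next, &(&1 * &2)) |> Enum.sum()
      {y, h_next}
    end)
    |> elem(0)
  end
end
```

With A = [0.0], Δ = 1 and B = C = [1.0] at every step, A_bar = 1 and the recurrence reduces to a running sum: `SelectiveScanSketch.run([1.0, 2.0, 3.0], [1.0, 1.0, 1.0], [0.0], [[1.0], [1.0], [1.0]], [[1.0], [1.0], [1.0]])` returns `[1.0, 3.0, 6.0]`.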

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a Mamba model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate the approximate parameter count for a Mamba model.