Edifice.SSM.MambaSSD (Edifice v0.2.0)


Mamba variant using State Space Duality (SSD) algorithm from Mamba-2.

SSD Algorithm

The key insight: SSM computation can be decomposed into matrix multiplications that leverage tensor cores (10-20x faster than scalar operations).

Algorithm Steps

  1. Split into chunks: Divide sequence into chunks of size C
  2. Intra-chunk (matmul): Compute outputs within each chunk using dense matmul
    • This uses tensor cores!
    • O(C²) work per chunk, but highly parallel
  3. Inter-chunk (scan): Small sequential scan over chunk boundaries
    • Only L/C elements to scan
  4. Combine: Merge chunk outputs with inter-chunk states
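The four steps above can be sketched in plain Python for a scalar SSM (all names and the chunking layout here are illustrative, not the module's API; the real implementation operates on batched Nx tensors, and the intra-chunk triangular sum becomes a dense tensor-core matmul):

```python
# Illustrative sketch of the chunked SSD decomposition for a scalar SSM
#   h[t] = a[t] * h[t-1] + b[t] * x[t],   y[t] = c[t] * h[t]

def ssm_scan(a, b, c, x):
    """Reference O(L) sequential scan."""
    h, ys = 0.0, []
    for at, bt, ct, xt in zip(a, b, c, x):
        h = at * h + bt * xt
        ys.append(ct * h)
    return ys

def ssm_ssd(a, b, c, x, chunk=4):
    """Chunked SSD: dense intra-chunk work + tiny inter-chunk scan."""
    L_seq = len(x)
    ys = [0.0] * L_seq
    h = 0.0  # state carried across chunk boundaries
    for s in range(0, L_seq, chunk):           # step 1: split into chunks
        e = min(s + chunk, L_seq)
        # Step 2, intra-chunk: y[i] = sum_j c[i] * prod(a[j+1..i]) * b[j] * x[j].
        # This triangular sum is what becomes an O(C^2) dense matmul per chunk.
        for i in range(s, e):
            acc = 0.0
            for j in range(s, i + 1):
                decay = 1.0
                for k in range(j + 1, i + 1):
                    decay *= a[k]
                acc += c[i] * decay * b[j] * x[j]
            state_decay = 1.0
            for k in range(s, i + 1):
                state_decay *= a[k]
            ys[i] = acc + c[i] * state_decay * h  # step 4: combine with carried state
        # Step 3, inter-chunk: fold the whole chunk into the carried state
        # in one update -- only L/C of these updates are sequential.
        chunk_decay, contrib = 1.0, 0.0
        for t in range(s, e):
            contrib = a[t] * contrib + b[t] * x[t]
            chunk_decay *= a[t]
        h = chunk_decay * h + contrib
    return ys
```

Both paths compute identical outputs; SSD only reorganizes the work so the bulk of it is dense and parallel.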

Complexity

  • Intra-chunk: O(L/C × C²) = O(L × C) work, but tensor core accelerated
  • Inter-chunk: O(L/C) sequential work (tiny)
  • Total: O(L × C) work versus O(L) for a pure scan, but much faster in wall-clock time because the intra-chunk matmuls run on tensor cores
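A quick back-of-envelope count makes the split concrete (the L and C values below are arbitrary, chosen only for illustration):

```python
# Work split for hypothetical sizes L = 4096, C = 64.
L_seq, C = 4096, 64
chunks = L_seq // C          # 64 chunks
intra = chunks * C * C       # O(L * C) multiply-accumulates, all parallel matmuls
inter = chunks               # sequential steps across chunk boundaries
print(intra, inter)          # 262144 parallel ops vs 64 sequential steps
```

The sequential portion shrinks by a factor of C, which is why the chunked form trains so much faster despite doing O(C) times more raw arithmetic.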

Training Mode

When training_mode: true is set, the SSD algorithm uses a matrix-multiplication formulation optimized for tensor cores:

y = (L ⊙ (C @ B^T)) @ x + cumsum(A) @ h_prev

Where L is a lower-triangular causal mask (distinct from the sequence length L above) and ⊙ is elementwise multiplication. This formulation:

  • Uses dense matmuls for tensor core utilization
  • Computes all positions in parallel within each chunk
  • Is significantly faster for batched training

For inference, use training_mode: false (default) which uses efficient scans with O(1) memory per step.
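The masked-matmul term y = (L ⊙ (C @ B^T)) @ x can be sketched in plain Python for a single chunk with zero initial state (a hedged illustration assuming scalar per-step decay, as in Mamba-2's scalar-identity A; function names are hypothetical, and the actual module batches this over chunks with Nx dense matmuls):

```python
# Sketch of y = (L ⊙ (C @ B^T)) @ x with the SSM decay folded into the mask.
# A: per-step scalar decays, B and C: per-step vectors of size N, x: inputs.

def masked_matmul_ssm(A, B, C, x):
    L_seq, N = len(x), len(B[0])
    y = [0.0] * L_seq
    for i in range(L_seq):
        for j in range(i + 1):                  # lower-triangular mask L
            decay = 1.0
            for k in range(j + 1, i + 1):       # decay prod(A[j+1..i])
                decay *= A[k]
            dot = sum(C[i][n] * B[j][n] for n in range(N))  # (C @ B^T)[i, j]
            y[i] += decay * dot * x[j]
    return y

def scan_reference(A, B, C, x):
    """Sequential recurrence the matmul form must reproduce."""
    N = len(B[0])
    h = [0.0] * N
    ys = []
    for t in range(len(x)):
        h = [A[t] * h[n] + B[t][n] * x[t] for n in range(N)]
        ys.append(sum(C[t][n] * h[n] for n in range(N)))
    return ys
```

Every position i is computed independently of the others, which is what makes the formulation fully parallel within a chunk.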

Current Performance

Note: The XLA implementation has limitations compared to fused CUDA kernels. For production training, consider using a custom Triton kernel.

Usage

# Training (matmul formulation)
model = MambaSSD.build(embed_dim: 287, hidden_size: 256, training_mode: true)

# Inference (scan formulation)
model = MambaSSD.build(embed_dim: 287, hidden_size: 256, training_mode: false)

Summary

Types

Options for build/1.

Functions

Build an SSD Mamba model.

Get recommended defaults for real-time sequence processing (60fps).

Get training-optimized defaults.

Types

build_opt()

@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:state_size, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:conv_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}
  | {:chunk_size, pos_integer()}
  | {:training_mode, boolean()}
  | {:structured_mask, boolean()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build an SSD Mamba model.

Options

  • :training_mode - If true, uses matmul formulation for tensor cores (default: false)
  • :chunk_size - Size of chunks for SSD algorithm (default: 16)
  • :structured_mask - If true, uses structured semi-separable mask that combines causal masking with SSM decay: M[i,j] = prod(a[k], k=j+1..i) for i >= j. This replaces the simple lower-triangular mask. (default: false)
  • All common Mamba options (see Edifice.SSM.Common)
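The structured semi-separable mask described in the :structured_mask option can be sketched as follows (a minimal illustration assuming a list of per-step decay values a; the helper name is hypothetical, not part of the module's API):

```python
def semiseparable_mask(a):
    # M[i][j] = prod(a[k] for k in j+1..i) when i >= j, else 0 --
    # causal masking and SSM decay combined in a single matrix.
    L_seq = len(a)
    M = [[0.0] * L_seq for _ in range(L_seq)]
    for i in range(L_seq):
        for j in range(i + 1):
            p = 1.0
            for k in range(j + 1, i + 1):
                p *= a[k]
            M[i][j] = p
    return M
```

With a[k] = 1 for all k the mask reduces to the plain lower-triangular matrix of ones, i.e. the simple mask used when structured_mask: false.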

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

See Edifice.SSM.Common.output_size/1.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

See Edifice.SSM.Common.param_count/1.

training_defaults()

@spec training_defaults() :: keyword()

Get training-optimized defaults.

Uses matmul formulation for better tensor core utilization.