Edifice.SSM.MambaCumsum (Edifice v0.2.0)

Mamba variant for experimenting with alternative scan algorithms.

Currently uses Blelloch scan (same as regular Mamba). This module exists to test alternative approaches like:

  • Hillis-Steele scan: O(L log L) work but more parallelism per level
  • SSD algorithm: Mamba-2's chunked matmul approach for tensor cores
  • Chunked scan: Process in chunks with inter-chunk recurrence
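The trade-off in the first bullet can be illustrated on plain Elixir lists. The sketch below is not the module's tensor code: it assumes a scalar recurrence `h_t = a_t * h_(t-1) + b_t` with each step stored as an `{a, b}` pair, and the module name `HillisSteeleSketch` and its functions are hypothetical. Each pass combines every element with the one `d` positions to its left, so there are about log2(L) passes that each touch all L elements.

```elixir
defmodule HillisSteeleSketch do
  # Associative combine for the recurrence h_t = a_t * h_(t-1) + b_t.
  # Applying {a1, b1} and then {a2, b2} equals applying {a1 * a2, a2 * b1 + b2}.
  def combine({a1, b1}, {a2, b2}), do: {a1 * a2, a2 * b1 + b2}

  # Sequential reference: one left-to-right pass, O(L) work and O(L) depth.
  def sequential(elems), do: Enum.scan(elems, fn x, acc -> combine(acc, x) end)

  # Hillis-Steele inclusive scan: O(L log L) work, O(log L) passes.
  def scan(elems), do: pass(elems, 1, length(elems))

  defp pass(elems, d, n) when d >= n, do: elems

  defp pass(elems, d, n) do
    # Shift the list right by d; the first d positions have no left partner.
    shifted = List.duplicate(nil, d) ++ Enum.take(elems, n - d)

    next =
      Enum.zip_with(shifted, elems, fn
        nil, x -> x
        left, x -> combine(left, x)
      end)

    pass(next, d * 2, n)
  end
end

elems = [{2, 1}, {3, 1}, {2, 5}, {1, 4}, {3, 2}]
# The second tuple element is the hidden state h_t when h_0 = 0.
hidden = elems |> HillisSteeleSketch.scan() |> Enum.map(&elem(&1, 1))
```

Within a pass, every `combine` call is independent of the others, which is the extra per-level parallelism the bullet refers to. On a list this buys nothing, but on an accelerator each pass can map onto a single elementwise operation.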

Current Status

The cumsum-based approach (a log-space reformulation of the scan) currently performs poorly under XLA: XLA's cumulative_sum kernel is slower than the pad/slice/multiply pattern used by the Blelloch scan for this tensor structure.
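For readers unfamiliar with the idea, one common way to express a linear recurrence through cumulative sums is sketched below on plain lists. This is an illustration under stated assumptions, not the module's implementation: it assumes a scalar recurrence `h_t = a_t * h_(t-1) + b_t` with `h_0 = 0` and every `a_t > 0`, and the module name `LogSpaceScanSketch` and its functions are hypothetical.

```elixir
defmodule LogSpaceScanSketch do
  # With S_t = log(a_1) + ... + log(a_t), the recurrence unrolls to
  #   h_t = exp(S_t) * sum over i <= t of (b_i * exp(-S_i)),
  # so the whole scan reduces to two cumulative sums.
  def scan(a, b) do
    # First cumulative sum: running sum of log(a_t).
    log_cum = a |> Enum.map(&:math.log/1) |> Enum.scan(&+/2)

    # Rescale each input by exp(-S_i), then take the second cumulative sum.
    inner =
      b
      |> Enum.zip_with(log_cum, fn b_i, s_i -> b_i * :math.exp(-s_i) end)
      |> Enum.scan(&+/2)

    # Undo the rescaling to recover h_t.
    Enum.zip_with(inner, log_cum, fn c_t, s_t -> c_t * :math.exp(s_t) end)
  end

  # Direct sequential evaluation of the same recurrence, for comparison.
  def reference(a, b) do
    {hs, _last} =
      a
      |> Enum.zip(b)
      |> Enum.map_reduce(0.0, fn {a_t, b_t}, h_prev ->
        h_t = a_t * h_prev + b_t
        {h_t, h_t}
      end)

    hs
  end
end
```

In this naive form, `exp(-S_i)` grows without bound when the `a_t` values stay below 1, so practical versions need some stabilisation for long sequences.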

Usage

alias Edifice.SSM.MambaCumsum

model = MambaCumsum.build(embed_dim: 287, hidden_size: 256)

Types

build_opt()

@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:state_size, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:conv_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a MambaCumsum model for sequence processing.

Same API as Mamba.build/1.

build_selective_ssm(input, opts \\ [])

@spec build_selective_ssm(
  Axon.t(),
  keyword()
) :: Axon.t()

Build the selective SSM with a configurable scan algorithm.

This is where we can swap in different scan implementations:

  • :blelloch (default) - Work-efficient O(L) work, O(log L) depth
  • :cumsum_transposed - Log-space reformulation with transposed cumsum
  • :cumsum_logspace - Log-space reformulation on original axis ordering
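The O(L) work and O(log L) depth quoted for the default option can be seen in a recursive pairwise scan. The sketch below works on plain lists and shares those bounds, but it is not the pad/slice/multiply tensor code the module uses. It assumes a scalar recurrence `h_t = a_t * h_(t-1) + b_t` stored as `{a, b}` pairs, and the module name `WorkEfficientScanSketch` and its functions are hypothetical.

```elixir
defmodule WorkEfficientScanSketch do
  # Associative combine for h_t = a_t * h_(t-1) + b_t (left step first).
  def combine({a1, b1}, {a2, b2}), do: {a1 * a2, a2 * b1 + b2}

  # Inclusive scan with O(L) total work and O(log L) recursion depth.
  def scan([]), do: []
  def scan([x]), do: [x]

  def scan(elems) do
    # Up-sweep: combine adjacent pairs, halving the problem size.
    pairs =
      elems
      |> Enum.chunk_every(2, 2, :discard)
      |> Enum.map(fn [left, right] -> combine(left, right) end)

    # Recurse on the half-length list; results are the prefixes at odd indices.
    scanned = scan(pairs)

    # Down-sweep: each even index (after the first) combines the preceding
    # odd-index prefix with its own element.
    evens = Enum.take_every(elems, 2)

    even_out =
      Enum.zip_with([nil | scanned], evens, fn
        nil, x -> x
        prefix, x -> combine(prefix, x)
      end)

    interleave(even_out, scanned)
  end

  # Merge even-index and odd-index results back into sequence order.
  defp interleave([e | es], [o | os]), do: [e, o | interleave(es, os)]
  defp interleave(es, []), do: es
end
```

Each recursion level performs a number of `combine` calls proportional to its input length, and that length halves at every level, so total work stays linear in L.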

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

See Edifice.SSM.Common.output_size/1.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

See Edifice.SSM.Common.param_count/1.