Mamba variant for experimenting with alternative scan algorithms.
It currently uses the Blelloch scan (the same as regular Mamba). This module exists to test alternative approaches such as:
- Hillis-Steele scan: O(L log L) work but more parallelism per level
- SSD (state space duality) algorithm: Mamba-2's chunked matmul approach, designed to run on tensor cores
- Chunked scan: Process in chunks with inter-chunk recurrence
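Every approach in this list computes the same recurrence. Assuming the diagonal, per-channel form of the selective scan, it reduces to a scalar recurrence h_t = a_t * h_{t-1} + b_t with h_0 = 0. That recurrence can be written as a scan over pairs {a_t, b_t} with an associative combine. The sketch below is a minimal illustration, not this module's implementation. It uses plain Elixir lists instead of Nx tensors, and `HillisSteeleSketch` and its functions are hypothetical names. It runs a Hillis-Steele scan and checks it against a sequential loop:

```elixir
defmodule HillisSteeleSketch do
  # Associative combine for h_t = a_t * h_{t-1} + b_t:
  # {a1, b1} covers an earlier segment, {a2, b2} the segment right after it.
  def combine({a1, b1}, {a2, b2}), do: {a1 * a2, a2 * b1 + b2}

  # Sequential reference loop with h_0 = 0.
  def sequential(as, bs) do
    {hs, _last} =
      as
      |> Enum.zip(bs)
      |> Enum.map_reduce(0, fn {a, b}, h ->
        h = a * h + b
        {h, h}
      end)

    hs
  end

  # Hillis-Steele inclusive scan: ceil(log2 L) passes, each combining
  # element i with element i - d, so O(L log L) work in total.
  def hillis_steele(as, bs) do
    elems = as |> Enum.zip(bs) |> List.to_tuple()

    elems
    |> sweep(1, tuple_size(elems))
    |> Tuple.to_list()
    |> Enum.map(fn {_a, h} -> h end)
  end

  defp sweep(elems, d, n) when d >= n, do: elems

  defp sweep(elems, d, n) do
    # Every index in a pass is independent, so each pass is fully parallel.
    next =
      for i <- 0..(n - 1) do
        if i >= d, do: combine(elem(elems, i - d), elem(elems, i)), else: elem(elems, i)
      end

    sweep(List.to_tuple(next), d * 2, n)
  end
end
```

Each pass rewrites every element, which is where the extra parallelism per level comes from. Repeating that full-width work over log L levels is what makes the total O(L log L).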
Current Status
The cumsum-based approach (log-space reformulation) performs poorly in XLA: XLA's cumulative_sum kernel is slower than the pad/slice/multiply pattern used by the Blelloch scan for this tensor structure.
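For reference, the log-space reformulation rests on one identity. With P_t = a_1 * ... * a_t = exp(cumsum(log a)_t), the recurrence h_t = a_t * h_{t-1} + b_t (h_0 = 0) unrolls to h_t = P_t * sum_{s <= t} b_s / P_s. Two cumulative sums and some elementwise exp/log therefore replace the scan. The sketch below is a minimal illustration only. It assumes scalar inputs with every a_t > 0, as holds when a_t comes from an exponential discretization. It uses plain lists rather than Nx, and `LogSpaceScanSketch` is a hypothetical name:

```elixir
defmodule LogSpaceScanSketch do
  # Inclusive prefix sum: the plain-list analogue of a cumulative_sum kernel.
  defp cumsum(xs), do: Enum.scan(xs, &(&1 + &2))

  # h_t = P_t * sum_{s <= t} b_s / P_s with P_t = exp(cumsum(log a)_t).
  # Requires every a_t > 0. exp(-log_p) grows like 1 / P_s, so long
  # sequences with a_t < 1 can overflow in floating point.
  def scan(as, bs) do
    log_p = as |> Enum.map(&:math.log/1) |> cumsum()

    weighted =
      bs
      |> Enum.zip_with(log_p, fn b, lp -> b * :math.exp(-lp) end)
      |> cumsum()

    Enum.zip_with(log_p, weighted, fn lp, w -> :math.exp(lp) * w end)
  end
end
```

For a = [0.5, 0.5, 0.5] and b = [1.0, 1.0, 1.0], this gives approximately [1.0, 1.5, 1.75], matching the sequential recurrence.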
Usage
model = MambaCumsum.build(embed_dim: 287, hidden_size: 256)
Types
@type build_opt() :: {:embed_dim, pos_integer()} | {:hidden_size, pos_integer()} | {:state_size, pos_integer()} | {:expand_factor, pos_integer()} | {:conv_size, pos_integer()} | {:num_layers, pos_integer()} | {:dropout, float()} | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a MambaCumsum model for sequence processing.
Same API as Mamba.build/1.
Build the SSM with configurable scan algorithm.
Alternative scan implementations can be swapped in here:
- :blelloch (default) - Work-efficient scan with O(L) work and O(log L) depth
- :cumsum_transposed - Log-space reformulation with transposed cumsum
- :cumsum_logspace - Log-space reformulation on original axis ordering
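The O(L) work and O(log L) depth of the :blelloch option come from halving the problem at each level. Below is a minimal sketch of how a work-efficient scan reaches those bounds. It is written over plain Elixir lists rather than Nx tensors and assumes the per-channel scalar recurrence h_t = a_t * h_{t-1} + b_t with h_0 = 0. It uses the recursive even/odd form (pad, pair, recurse, expand) of Blelloch's up-sweep/down-sweep instead of an explicit tree. `BlellochScanSketch` and its functions are hypothetical names, not this module's code:

```elixir
defmodule BlellochScanSketch do
  # Identity for combine/2: {1, 0} leaves any {a, b} unchanged.
  @identity {1, 0}

  # Associative combine for h_t = a_t * h_{t-1} + b_t (earlier segment first).
  def combine({a1, b1}, {a2, b2}), do: {a1 * a2, a2 * b1 + b2}

  # Work-efficient inclusive scan: each level halves the problem, so
  # O(L) work and O(log L) depth.
  def inclusive_scan([]), do: []
  def inclusive_scan([x]), do: [x]

  def inclusive_scan(xs) do
    n = length(xs)
    # Pad odd-length input with the identity so every element has a partner.
    padded = if rem(n, 2) == 1, do: xs ++ [@identity], else: xs
    pairs = Enum.chunk_every(padded, 2)
    # Up-sweep: combine each (even, odd) pair.
    reduced = Enum.map(pairs, fn [even, odd] -> combine(even, odd) end)
    # Scan the half-length problem recursively.
    scanned = inclusive_scan(reduced)
    # Down-sweep: an odd slot takes its pair's prefix; an even slot extends
    # the previous pair's prefix (the identity for the first pair).
    prev = [@identity | Enum.drop(scanned, -1)]

    [pairs, scanned, prev]
    |> Enum.zip()
    |> Enum.flat_map(fn {[even, _odd], s, p} -> [combine(p, even), s] end)
    |> Enum.take(n)
  end

  # Returns h_1..h_L for scalar inputs, with h_0 = 0.
  def run(as, bs) do
    as |> Enum.zip(bs) |> inclusive_scan() |> Enum.map(fn {_a, h} -> h end)
  end
end
```

Any correct scan over the same associative combine yields the same h_t. Trying another algorithm therefore only means replacing inclusive_scan/1 while keeping combine/2 fixed.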
@spec output_size(keyword()) :: non_neg_integer()
@spec param_count(keyword()) :: non_neg_integer()
@spec recommended_defaults() :: keyword()