GatedSSM: Simplified gated temporal network inspired by state space models.
NOTE: This is NOT a true Mamba implementation. It uses a simplified gating
mechanism instead of the parallel associative scan that makes Mamba efficient.
For true Mamba, see Edifice.SSM.Mamba.
This module provides competitive results and is numerically stable. Use it when you want a lightweight temporal model that's simpler than true Mamba.
How It Differs From True Mamba
| Aspect | True Mamba | GatedSSM |
|---|---|---|
| Core algorithm | Parallel associative scan | Gated multiplication |
| Recurrence | h(t) = Ah(t-1) + Bx | Sigmoid gating approximation |
| Convolution | Learned depthwise separable | Mean pooling + projection |
| Complexity | O(L) parallel | O(L) sequential approximation |
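The contrast in the table can be sketched in scalar form. The module and function names below are hypothetical, purely for illustration; the real layers operate on tensors with learned parameters. A true SSM threads a hidden state through the sequence, while the gated approximation has no recurrent state and instead modulates each input through a sigmoid gate:

```elixir
# Scalar sketch (illustrative only, not the library's implementation).
defmodule RecurrenceSketch do
  # True SSM recurrence: h(t) = a*h(t-1) + b*x(t), y(t) = c*h(t)
  def ssm_scan(xs, a, b, c) do
    {ys, _h} =
      Enum.map_reduce(xs, 0.0, fn x, h ->
        h = a * h + b * x
        {c * h, h}
      end)

    ys
  end

  # Gated approximation: no hidden state is carried between steps;
  # each output mixes the input with a context value via a sigmoid gate.
  def gated(xs, context) do
    Enum.map(xs, fn x -> x * sigmoid(context) end)
  end

  defp sigmoid(z), do: 1.0 / (1.0 + :math.exp(-z))
end
```

Note that `ssm_scan` is inherently order-dependent (each output depends on all previous inputs), while `gated` treats timesteps independently given the context; this is the trade-off the table summarizes.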
Architecture
Input [batch, seq_len, embed_dim]
│
▼
┌─────────────────────────────────────┐
│ GatedSSM Block │
│ │
│ ┌──── Linear (expand) ────┐ │
│ │ │ │ │
│ │ MeanPool + SiLU │ │
│ │ │ │ │
│ │ Gated Context Linear+SiLU │
│ │ │ │ │
│ └───────── multiply ───────┘ │
│ │ │
│ Linear (project) │
└─────────────────────────────────────┘
│
▼ (repeat for num_layers)
│
▼
[batch, seq_len, embed_dim] -> last timestep -> [batch, embed_dim]
Usage
# Build GatedSSM backbone
model = GatedSSM.build(
embed_dim: 256,
hidden_size: 256,
state_size: 16,
num_layers: 2,
expand_factor: 2
)
When To Use
- Lightweight temporal processing without full Mamba complexity
- Stable training (no NaN issues observed)
- When true Mamba isn't available or needed
Summary
Functions
Build a GatedSSM model for sequence processing.
Build a causal 1D convolution layer.
Build a GatedSSM model with gradient checkpointing for memory-efficient training.
Build a single GatedSSM block.
Build the Selective State Space Model (S6).
Initialize hidden state for incremental inference.
Get the output size of a GatedSSM model.
Calculate the approximate parameter count for a GatedSSM model.
Get recommended defaults for real-time sequence processing (60fps).
Perform a single incremental step with cached state.
Types
@type build_opt() ::
  {:conv_size, pos_integer()}
  | {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:state_size, pos_integer()}
  | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a GatedSSM model for sequence processing.
Options
- :embed_dim - Size of input embedding per frame (required)
- :hidden_size - Internal hidden dimension D (default: 256)
- :state_size - SSM state dimension N (default: 16)
- :expand_factor - Expansion factor E for inner dim (default: 2)
- :conv_size - 1D convolution kernel size (default: 4)
- :num_layers - Number of Mamba blocks (default: 2)
- :dropout - Dropout rate (default: 0.0)
- :window_size - Expected sequence length for JIT optimization (default: 60)
Returns
An Axon model that processes sequences and outputs the last hidden state.
@spec build_causal_conv1d(Axon.t(), pos_integer(), pos_integer(), String.t()) :: Axon.t()
Build a causal 1D convolution layer.
Applies convolution only over past timesteps (causal padding). Uses a simplified approach with sliding window mean + learned projection.
Build a GatedSSM model with gradient checkpointing for memory-efficient training.
Same as build/1 but applies gradient checkpointing to each block,
reducing memory usage at the cost of ~30% more compute.
Memory Savings
For a 3-layer model with window_size=60 and batch_size=256:
- Without checkpointing: ~2.5GB activation memory
- With checkpointing: ~0.8GB activation memory
When to Use
- Training on GPU with limited VRAM
- Using large batch sizes or long sequences
- When you're hitting OOM during training
Options
Same as build/1, plus:
- :checkpoint_every - Checkpoint every N layers (default: 1)
Build a single GatedSSM block.
Each block consists of:
- Two parallel branches after the input projection
- One branch: Conv1D -> SiLU -> Selective SSM
- Other branch: Linear -> SiLU (gating)
- Multiply the branch outputs -> project back
Options
- :hidden_size - Internal dimension D
- :state_size - SSM state dimension N
- :expand_factor - Expansion factor E
- :conv_size - Convolution kernel size
- :name - Layer name prefix
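The multiplicative gating between the two branches can be illustrated in scalar form (a hypothetical sketch with made-up names; the real block applies learned linear layers to tensors):

```elixir
# Scalar sketch of SiLU gating (illustrative only).
defmodule GatingSketch do
  # SiLU (a.k.a. swish): x * sigmoid(x). Near-linear for large positive
  # inputs, near-zero for large negative inputs.
  def silu(x), do: x * (1.0 / (1.0 + :math.exp(-x)))

  # Branch outputs are combined by elementwise multiplication: a gate
  # branch near zero suppresses the main branch's contribution.
  def combine(main, gate), do: main * silu(gate)
end
```

This is why the gate branch acts as a soft, input-dependent on/off switch for the main branch's features.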
Build the Selective State Space Model (S6).
This is the core of Mamba: an SSM where the A, B, C parameters are computed from the input, making it "selective".
The SSM equations:
- h(t) = exp(delta A) h(t-1) + delta B x(t)
- y(t) = C * h(t)
Where delta, B, C are input-dependent projections.
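The discretized update above can be worked through with scalar values (a hedged sketch; in the real layer, delta, B, and C are input-dependent projections and A is a learned matrix, so every quantity below would be a tensor):

```elixir
# Scalar sketch of one discretized SSM step (illustrative only).
defmodule S6Sketch do
  # h(t) = exp(delta * a) * h(t-1) + delta * b * x(t)
  # y(t) = c * h(t)
  def step(h_prev, x, delta, a, b, c) do
    h = :math.exp(delta * a) * h_prev + delta * b * x
    {c * h, h}
  end

  # Run over a sequence, threading the hidden state through each step.
  def scan(xs, delta, a, b, c) do
    {ys, _h} =
      Enum.map_reduce(xs, 0.0, fn x, h ->
        step(h, x, delta, a, b, c)
      end)

    ys
  end
end
```

With a negative `a`, the `exp(delta * a)` factor is a decay on the previous state, which is what keeps the recurrence stable.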
Initialize hidden state for incremental inference.
Returns a map containing the cached state for each layer. For each layer, we cache:
- :h - The SSM hidden state [batch, state_size]
- :conv_buffer - Buffer for causal convolution [batch, conv_size-1, inner_size]
Options
- :batch_size - Batch size (default: 1)
- :hidden_size - Hidden dimension D (default: 256)
- :state_size - SSM state dimension N (default: 16)
- :expand_factor - Expansion factor E (default: 2)
- :conv_size - Convolution kernel size (default: 4)
- :num_layers - Number of Mamba blocks (default: 2)
Example
cache = GatedSSM.init_cache(batch_size: 1, hidden_size: 256)
{output, new_cache} = GatedSSM.step(x_single_frame, params, cache, opts)
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a GatedSSM model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate the approximate parameter count for a GatedSSM model.
@spec recommended_defaults() :: keyword()
Get recommended defaults for real-time sequence processing (60fps).
@spec step(Nx.Tensor.t(), map(), map(), keyword()) :: {Nx.Tensor.t(), map()}
Perform a single incremental step with cached state.
Takes a single frame input and the current cache, returns the output and updated cache. This enables O(1) inference per frame instead of O(window_size).
Arguments
- x - Single frame input [batch, hidden_size] or [batch, 1, hidden_size]
- params - Model parameters (from trained model)
- cache - Cache from init_cache/1 or a previous step/4 call
Returns
{output, new_cache} where:
- output - [batch, hidden_size] tensor
- new_cache - Updated cache for the next step
Example
cache = GatedSSM.init_cache(hidden_size: 256)
{out1, cache} = GatedSSM.step(frame1, params, cache)
{out2, cache} = GatedSSM.step(frame2, params, cache)
# out2 is equivalent to running [frame1, frame2] through full model
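The incremental-vs-full equivalence can be demonstrated with a scalar recurrence (a hypothetical sketch; the real step/4 also maintains a convolution buffer in the cache, not just the hidden state):

```elixir
# Scalar sketch of incremental vs. full-sequence processing
# (illustrative only).
defmodule IncrementalSketch do
  # Full pass: fold the whole sequence, return the final state.
  def full(xs, a, b) do
    Enum.reduce(xs, 0.0, fn x, h -> a * h + b * x end)
  end

  # Incremental: one frame at a time, carrying only the cached state,
  # so each step is O(1) regardless of how many frames came before.
  def step(h_cache, x, a, b), do: a * h_cache + b * x
end
```

Feeding frames one by one through `step/4`, threading the returned state back in as the cache, reproduces exactly what `full/3` computes over the whole sequence.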