Edifice.SSM.GatedSSM (Edifice v0.2.0)


GatedSSM: Simplified gated temporal network inspired by state space models.

NOTE: This is NOT a true Mamba implementation. It uses a simplified gating mechanism instead of the parallel associative scan that makes Mamba efficient. For true Mamba, see Edifice.SSM.Mamba.

This module provides competitive results and is numerically stable. Use it when you want a lightweight temporal model that's simpler than true Mamba.

How It Differs From True Mamba

Aspect           True Mamba                     GatedSSM
Core algorithm   Parallel associative scan      Gated multiplication
Recurrence       h(t) = A*h(t-1) + B*x(t)       Sigmoid gating approximation
Convolution      Learned depthwise separable    Mean pooling + projection
Complexity       O(L) parallel                  O(L) sequential approximation
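
To make the gated-multiplication idea concrete, here is a minimal Nx sketch (illustrative only; the real block wraps these operations in learned projections):

# Gated-context approximation: a mean-pooled summary of the sequence is mixed
# back in via an input-dependent sigmoid gate, instead of running the
# recurrent/associative scan of true Mamba.
x = Nx.iota({1, 4, 8}, type: :f32)                # [batch, seq_len, dim]
context = Nx.mean(x, axes: [1], keep_axes: true)  # [batch, 1, dim] summary
gate = Nx.sigmoid(x)                              # per-timestep gate
y = Nx.multiply(gate, context)                    # broadcasts over seq_len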

Architecture

Input [batch, seq_len, embed_dim]
        |
        v
GatedSSM Block (repeated num_layers times):

    Linear (expand)
        |
        +----------------------+
        |                      |
  MeanPool + SiLU        Linear + SiLU
        |                      |
  Gated Context                |
        |                      |
        +------ multiply ------+
                   |
           Linear (project)
                   |
                   v
[batch, seq_len, embed_dim] -> last timestep -> [batch, embed_dim]
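
A hypothetical Axon sketch of one block's wiring (layer names, the MeanPool implementation, and branch details are illustrative, not the module's code):

# One GatedSSM-style block built from stock Axon layers (sketch only).
defmodule GatedSSMBlockSketch do
  def block(input, embed_dim, expand_factor) do
    inner = embed_dim * expand_factor

    expanded = Axon.dense(input, inner, name: "expand")

    # Context branch: pooled summary of the sequence, SiLU-activated
    context =
      expanded
      |> Axon.nx(&Nx.mean(&1, axes: [1], keep_axes: true), name: "mean_pool")
      |> Axon.activation(:silu, name: "context_act")

    # Gate branch: per-timestep Linear + SiLU
    gate =
      expanded
      |> Axon.dense(inner, name: "gate")
      |> Axon.activation(:silu, name: "gate_act")

    # Element-wise multiply (context broadcasts over seq_len), then project back
    Axon.multiply(gate, context)
    |> Axon.dense(embed_dim, name: "project")
  end
end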

Usage

# Build GatedSSM backbone
alias Edifice.SSM.GatedSSM

model = GatedSSM.build(
  embed_dim: 256,
  hidden_size: 256,
  state_size: 16,
  num_layers: 2,
  expand_factor: 2
)
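
The result is an ordinary Axon graph; a minimal sketch of initializing and running it (shapes assume a batch of 1 and a 60-frame window):

# Initialize parameters and run a forward pass with Axon.
{init_fn, predict_fn} = Axon.build(model)
template = Nx.template({1, 60, 256}, :f32)   # [batch, seq_len, embed_dim]
params = init_fn.(template, %{})
input = Nx.broadcast(0.0, {1, 60, 256})
output = predict_fn.(params, input)          # last-timestep features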

When To Use

  • Lightweight temporal processing without full Mamba complexity
  • Stable training (no NaN issues observed)
  • When true Mamba isn't available or needed

Summary

Types

Options for build/1.

Functions

Build a GatedSSM model for sequence processing.

Build a causal 1D convolution layer.

Build a GatedSSM model with gradient checkpointing for memory-efficient training.

Build a single Mamba block.

Build the Selective State Space Model (S6).

Initialize hidden state for incremental inference.

Get the output size of a GatedSSM model.

Calculate approximate parameter count for a GatedSSM model.

Get recommended defaults for real-time sequence processing (60fps).

Perform a single incremental step with cached state.

Types

build_opt()

@type build_opt() ::
  {:conv_size, pos_integer()}
  | {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:state_size, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a GatedSSM model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension D (default: 256)
  • :state_size - SSM state dimension N (default: 16)
  • :expand_factor - Expansion factor E for inner dim (default: 2)
  • :conv_size - 1D convolution kernel size (default: 4)
  • :num_layers - Number of Mamba blocks (default: 2)
  • :dropout - Dropout rate (default: 0.0)
  • :window_size - Expected sequence length for JIT optimization (default: 60)

Returns

An Axon model that processes sequences and outputs the last hidden state.
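
Because the result is a plain Axon.t(), it composes with further layers; for example, a hypothetical classification head:

model =
  GatedSSM.build(embed_dim: 128, num_layers: 2)
  |> Axon.dense(10, activation: :softmax)   # hypothetical 10-class head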

build_causal_conv1d(input, channels, kernel_size, name)

@spec build_causal_conv1d(Axon.t(), pos_integer(), pos_integer(), String.t()) ::
  Axon.t()

Build a causal 1D convolution layer.

Applies convolution only over past timesteps (causal padding). Uses a simplified approach with sliding window mean + learned projection.
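
For intuition, a causal sliding-window mean can be written in Nx by left-padding the time axis, roughly like this (the learned projection that follows it in the real layer is omitted):

# Causal window mean over the time axis (kernel_size = 4): each timestep
# averages itself and the 3 preceding timesteps, never future ones.
x = Nx.iota({1, 6, 2}, type: :f32)   # [batch, seq_len, channels]
kernel_size = 4

Nx.window_mean(x, {1, kernel_size, 1},
  padding: [{0, 0}, {kernel_size - 1, 0}, {0, 0}]
)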

build_checkpointed(opts \\ [])

@spec build_checkpointed(keyword()) :: Axon.t()

Build a GatedSSM model with gradient checkpointing for memory-efficient training.

Same as build/1 but applies gradient checkpointing to each block, reducing memory usage at the cost of ~30% more compute.

Memory Savings

For a 3-layer model with window_size=60, batch_size=256:

  • Without checkpointing: ~2.5GB activation memory
  • With checkpointing: ~0.8GB activation memory

When to Use

  • Training on GPU with limited VRAM
  • Using large batch sizes or long sequences
  • When you're hitting OOM during training

Options

Same as build/1, plus:

  • :checkpoint_every - Checkpoint every N layers (default: 1)
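
Usage mirrors build/1 (the option values here are illustrative):

model =
  GatedSSM.build_checkpointed(
    embed_dim: 256,
    num_layers: 3,
    checkpoint_every: 1
  )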

build_mamba_block(input, opts \\ [])

@spec build_mamba_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single Mamba block.

The Mamba block consists of:

  1. Two parallel branches after input projection
  2. One branch: Conv1D -> SiLU -> Selective SSM
  3. Other branch: Linear -> SiLU (gating)
  4. Multiply outputs -> Project back

Options

  • :hidden_size - Internal dimension D
  • :state_size - SSM state dimension N
  • :expand_factor - Expansion factor E
  • :conv_size - Convolution kernel size
  • :name - Layer name prefix
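
A sketch of wiring a single block onto an input node (the input name and option values are illustrative):

input = Axon.input("embeddings", shape: {nil, 60, 256})

block =
  GatedSSM.build_mamba_block(input,
    hidden_size: 256,
    state_size: 16,
    expand_factor: 2,
    conv_size: 4,
    name: "gated_ssm_block_0"
  )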

build_selective_ssm(input, opts \\ [])

@spec build_selective_ssm(
  Axon.t(),
  keyword()
) :: Axon.t()

Build the Selective State Space Model (S6).

This is the core of Mamba: an SSM where the A, B, C parameters are computed from the input, making it "selective".

The SSM equations:

  • h(t) = exp(delta * A) * h(t-1) + delta * B * x(t)
  • y(t) = C * h(t)

Where delta, B, C are input-dependent projections.
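
As a toy, scalar illustration of that recurrence (delta, b, c fixed here for simplicity; in the layer they are vectorized and input-dependent):

# Scalar toy of the update: h(t) = exp(delta*a)*h(t-1) + delta*b*x(t), y(t) = c*h(t)
a = -1.0
{delta, b, c} = {0.1, 1.0, 1.0}
xs = [0.5, 1.0, -0.3]

{ys, _h_final} =
  Enum.map_reduce(xs, 0.0, fn x, h ->
    h = :math.exp(delta * a) * h + delta * b * x
    {c * h, h}
  end)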

init_cache(opts \\ [])

@spec init_cache(keyword()) :: map()

Initialize hidden state for incremental inference.

Returns a map containing the cached state for each layer. For each layer, we cache:

  • :h - The SSM hidden state [batch, state_size]
  • :conv_buffer - Buffer for causal convolution [batch, conv_size-1, inner_size]

Options

  • :batch_size - Batch size (default: 1)
  • :hidden_size - Hidden dimension D (default: 256)
  • :state_size - SSM state dimension N (default: 16)
  • :expand_factor - Expansion factor E (default: 2)
  • :conv_size - Convolution kernel size (default: 4)
  • :num_layers - Number of Mamba blocks (default: 2)

Example

cache = GatedSSM.init_cache(batch_size: 1, hidden_size: 256)
{output, new_cache} = GatedSSM.step(x_single_frame, params, cache, opts)

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a GatedSSM model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for a GatedSSM model.

step(x, params, cache, opts \\ [])

@spec step(Nx.Tensor.t(), map(), map(), keyword()) :: {Nx.Tensor.t(), map()}

Perform a single incremental step with cached state.

Takes a single frame input and the current cache, returns the output and updated cache. This enables O(1) inference per frame instead of O(window_size).

Arguments

  • x - Single frame input [batch, hidden_size] or [batch, 1, hidden_size]
  • params - Model parameters (from trained model)
  • cache - Cache from init_cache/1 or previous step/4 call

Returns

{output, new_cache} where:

  • output - [batch, hidden_size] tensor
  • new_cache - Updated cache for next step

Example

cache = GatedSSM.init_cache(hidden_size: 256)
{out1, cache} = GatedSSM.step(frame1, params, cache)
{out2, cache} = GatedSSM.step(frame2, params, cache)
# out2 is equivalent to running [frame1, frame2] through the full model