Edifice.SSM.Common (Edifice v0.2.0)

Shared components for all Mamba architecture variants.

This module contains the common building blocks used across Mamba variants:

  • Default hyperparameters
  • Model structure builders (input projection, layer stacking, last timestep)
  • Block structure (normalization, projections, gating)
  • Depthwise convolution
  • SSM parameter projections
  • SSM discretization
  • Utility functions

Mamba Variants

All variants share the same architecture, differing only in the scan algorithm:

| Variant           | Scan Algorithm | Notes                                       |
| ----------------- | -------------- | ------------------------------------------- |
| Mamba             | Blelloch       | Work-efficient: O(L) work, O(log L) depth   |
| MambaHillisSteele | Hillis-Steele  | O(L log L) work, more parallelism           |
| MambaCumsum       | Cumsum-based   | Experimental log-space approach             |
| MambaSSD          | SSD chunked    | Mamba-2's matmul approach                   |

Summary

Functions

  • blelloch_scan/2 - Blelloch parallel scan (work-efficient: O(L) work, O(log L) depth).
  • build_block/3 - Build the common Mamba block structure.
  • build_depthwise_conv1d/4 - Build a depthwise separable 1D convolution layer.
  • build_model/2 - Build the common Mamba model structure.
  • build_ssm_projections/2 - Build the SSM parameter projections (B, C, dt).
  • compute_ssm_output/2 - Compute the SSM output from hidden states.
  • default_conv_size/0 - Default convolution kernel size.
  • default_dropout/0 - Default dropout rate.
  • default_expand_factor/0 - Default expansion factor E.
  • default_hidden_size/0 - Default hidden dimension D.
  • default_num_layers/0 - Default number of Mamba blocks.
  • default_state_size/0 - Default SSM state dimension N.
  • discretize_ssm/4 - Discretize the SSM parameters for the scan.
  • dt_max/0 - Maximum delta.
  • dt_min/0 - Minimum delta for numerical stability.
  • output_size/1 - Get the output size of a Mamba model.
  • param_count/1 - Calculate approximate parameter count for a Mamba model.
  • Get recommended defaults for real-time sequence processing (60fps).
  • sequential_scan/2 - Sequential scan for short sequences or fallback.

Functions

blelloch_scan(a, b)

@spec blelloch_scan(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Blelloch parallel scan (work-efficient: O(L) total work, O(log L) depth).

Uses Enum.reduce for the level loop, which lets XLA JIT each level efficiently.

Parameters

  • a - Decay factors [batch, seq_len, hidden_size, state_size]
  • b - Input contributions [batch, seq_len, hidden_size, state_size]

Returns

Hidden states [batch, seq_len, hidden_size, state_size]
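The Blelloch scan applies because the per-step update h -> a * h + b composes associatively: two consecutive steps (a1, b1) then (a2, b2) collapse into the single step (a1 * a2, a2 * b1 + b2). A minimal scalar check of that operator in plain Elixir (no Nx; illustrative only, not this module's internal code):

```elixir
# Combine step (a1, b1) followed by step (a2, b2) into one equivalent step.
combine = fn {a1, b1}, {a2, b2} -> {a1 * a2, a2 * b1 + b2} end

p1 = {0.5, 1.0}
p2 = {2.0, 3.0}
p3 = {0.25, 4.0}

left = combine.(combine.(p1, p2), p3)
right = combine.(p1, combine.(p2, p3))
# left == right == {0.25, 5.25} - associativity lets the scan tree
# regroup the L steps into O(log L) levels of pairwise combines.
```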

build_block(input, opts, ssm_builder)

@spec build_block(Axon.t(), keyword(), (Axon.t(), keyword() -> Axon.t())) :: Axon.t()

Build the common Mamba block structure.

This handles everything except the SSM scan itself:

  • Layer normalization
  • Input projection (to 2x inner_size for x/z branches)
  • X/Z branch splitting
  • Depthwise convolution + SiLU on X branch
  • SiLU gating on Z branch
  • Gated multiplication
  • Output projection

The caller provides an ssm_builder function that constructs the SSM layer.

Parameters

  • input - Input Axon node
  • opts - Block options (hidden_size, state_size, expand_factor, conv_size, name)
  • ssm_builder - Function (x_activated, ssm_opts) -> Axon.t() that builds SSM

Returns

An Axon node representing the block output.

build_depthwise_conv1d(input, channels, kernel_size, name)

@spec build_depthwise_conv1d(Axon.t(), pos_integer(), pos_integer(), String.t()) ::
  Axon.t()

Build a depthwise separable 1D convolution layer.

True Mamba uses a learned depthwise convolution rather than mean pooling; this layer approximates that depthwise-convolution behavior for SSM input processing.

Parameters

  • input - Input Axon node [batch, seq_len, channels]
  • channels - Number of output channels
  • kernel_size - Convolution kernel size
  • name - Layer name prefix

build_model(opts, block_builder)

@spec build_model(
  keyword(),
  (Axon.t(), keyword() -> Axon.t())
) :: Axon.t()

Build the common Mamba model structure.

This handles:

  • Input projection (if embed_dim != hidden_size)
  • Layer stacking with residual connections and dropout
  • Last timestep extraction

The caller provides a block_builder function that constructs each Mamba block.

Parameters

  • opts - Model options (embed_dim, hidden_size, num_layers, dropout, etc.)
  • block_builder - Function (input, opts) -> Axon.t() that builds one block

Returns

An Axon model that outputs [batch, hidden_size].
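As a wiring sketch of how the two builder callbacks fit together (hypothetical option values; `build_my_scan_layer/2` is a stand-in for a variant's real scan layer, not a function in this module):

```elixir
alias Edifice.SSM.Common

# Hypothetical SSM builder: a real variant would construct an Axon layer
# that runs its scan (e.g. via blelloch_scan/2) over the projected params.
ssm_builder = fn x_activated, ssm_opts ->
  build_my_scan_layer(x_activated, ssm_opts) # stand-in, not a real function
end

# Each block delegates the non-SSM structure to build_block/3.
block_builder = fn input, opts ->
  Common.build_block(input, opts, ssm_builder)
end

model =
  Common.build_model(
    [
      embed_dim: 32,
      hidden_size: Common.default_hidden_size(),
      num_layers: Common.default_num_layers(),
      dropout: Common.default_dropout()
    ],
    block_builder
  )
# `model` is an Axon graph whose output shape is [batch, hidden_size].
```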

build_ssm_projections(input, opts)

@spec build_ssm_projections(
  Axon.t(),
  keyword()
) :: {Axon.t(), Axon.t(), Axon.t()}

Build the SSM parameter projections (B, C, dt).

These are the "selective" parameters that make Mamba input-dependent:

  • B: Input matrix [batch, seq_len, state_size]
  • C: Output matrix [batch, seq_len, state_size]
  • dt: Discretization step [batch, seq_len, hidden_size]

Parameters

  • input - Input Axon node [batch, seq_len, hidden_size]
  • opts - Options (hidden_size, state_size, dt_rank, name)

Returns

Tuple of {b_matrix, c_matrix, dt_proj} Axon nodes.

compute_ssm_output(h, c)

@spec compute_ssm_output(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Compute the SSM output from hidden states.

y[t] = C[t] · h[t], contracting over the state dimension.

Parameters

  • h - Hidden states [batch, seq_len, hidden_size, state_size]
  • c - C matrix [batch, seq_len, state_size]

Returns

Output tensor [batch, seq_len, hidden_size]
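For a single timestep and hidden channel, the output is just a dot product of the C row with the length-N state vector. A scalar sketch in plain Elixir (illustrative; the real function operates on the full batched tensors):

```elixir
# One timestep, one hidden channel, state_size = 3:
# y = sum over n of c[n] * h[n].
h = [0.5, -1.0, 2.0]  # hidden state vector
c = [1.0, 0.0, 0.5]   # C row for this timestep

y =
  Enum.zip(c, h)
  |> Enum.map(fn {ci, hi} -> ci * hi end)
  |> Enum.sum()
# y == 1.5
```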

default_conv_size()

@spec default_conv_size() :: pos_integer()

Default convolution kernel size

default_dropout()

@spec default_dropout() :: float()

Default dropout rate

default_expand_factor()

@spec default_expand_factor() :: pos_integer()

Default expansion factor E

default_hidden_size()

@spec default_hidden_size() :: pos_integer()

Default hidden dimension D

default_num_layers()

@spec default_num_layers() :: pos_integer()

Default number of Mamba blocks

default_state_size()

@spec default_state_size() :: pos_integer()

Default SSM state dimension N

discretize_ssm(x, b, dt, state_size)

@spec discretize_ssm(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), pos_integer()) ::
  {Nx.Tensor.t(), Nx.Tensor.t()}

Discretize the SSM parameters for the scan.

Converts continuous-time SSM to discrete-time:

  • A_bar = exp(Δ * A)
  • B_bar = Δ * B
  • Bx = B_bar * x

Parameters

  • x - Input tensor [batch, seq_len, hidden_size]
  • b - B matrix [batch, seq_len, state_size]
  • dt - Delta tensor [batch, seq_len, hidden_size]
  • state_size - SSM state dimension

Returns

Tuple of {a_bar, bx} where:

  • a_bar: [batch, seq_len, hidden_size, state_size] - decay factors
  • bx: [batch, seq_len, hidden_size, state_size] - input contributions
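Per element, the discretization above reduces to scalar arithmetic. A plain-Elixir sketch of one element and one scan step (the value of A is assumed here for illustration; in Mamba, A is a learned negative parameter, which the formulas above take as given):

```elixir
a = -1.0  # continuous-time decay (assumed value; learned in practice)
dt = 0.1  # delta from the dt projection
b = 0.5   # B for this timestep
x = 2.0   # input for this timestep

a_bar = :math.exp(dt * a) # discrete decay factor; in (0, 1) when a < 0
bx = dt * b * x           # discretized input contribution

# One step of the scan recurrence: h = a_bar * h_prev + bx
h = a_bar * 0.0 + bx
```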

dt_max()

@spec dt_max() :: float()

Maximum delta

dt_min()

@spec dt_min() :: float()

Minimum delta for numerical stability

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a Mamba model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for a Mamba model.

sequential_scan(a, b)

@spec sequential_scan(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Sequential scan for short sequences or fallback.

Computes h[t] = a[t] * h[t-1] + b[t] for all t.

Parameters

  • a - Decay factors [batch, seq_len, hidden_size, state_size]
  • b - Input contributions [batch, seq_len, hidden_size, state_size]

Returns

Hidden states [batch, seq_len, hidden_size, state_size]
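In scalar form, this recurrence can be sketched with Enum.scan in plain Elixir (illustrative only; the real function runs over [batch, seq_len, hidden_size, state_size] tensors):

```elixir
a = [0.5, 0.5, 0.25] # decay factors, one per timestep
b = [1.0, 2.0, 4.0]  # input contributions, one per timestep

# h[t] = a[t] * h[t-1] + b[t], with h[-1] = 0
h =
  Enum.zip(a, b)
  |> Enum.scan(0.0, fn {at, bt}, h_prev -> at * h_prev + bt end)
# h == [1.0, 2.5, 4.625]
```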