Edifice.SSM.SSTransformer (Edifice v0.2.0)


State Space Transformer: parallel SSM and attention paths with a learned gate in each block.

Every block runs a selective state space model (SSM) path and a multi-head causal attention path in parallel, then fuses their outputs with a learned sigmoid gate. The gate lets each layer balance recurrent, local processing (SSM) against global attention.

Architecture

Input [batch, seq_len, embed_dim]
      |
Per block:
  Pre-norm -> SSM path (selective scan with gating)
           -> Attention path (multi-head causal)
           -> gate * ssm_out + (1-gate) * attn_out
           -> FFN + residual
      |
Final norm -> last timestep -> [batch, hidden_size]
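
The fusion step in the diagram, gate * ssm_out + (1-gate) * attn_out, can be illustrated with a small dependency-free sketch. It is not the Edifice implementation: the module name GateFusionSketch is hypothetical, the gate logits are passed in directly rather than learned, and an element-wise gate over plain lists of floats is assumed for readability.

```elixir
defmodule GateFusionSketch do
  # Hypothetical helper module for illustration; not part of Edifice.

  # Logistic sigmoid, mapping any real number into (0, 1).
  def sigmoid(x), do: 1.0 / (1.0 + :math.exp(-x))

  # Fuse two equally sized feature vectors element-wise:
  # gate * ssm_out + (1 - gate) * attn_out, where gate = sigmoid(logit).
  # Assumption: one gate logit per feature.
  def fuse(gate_logits, ssm_out, attn_out) do
    [gate_logits, ssm_out, attn_out]
    |> Enum.zip()
    |> Enum.map(fn {logit, s, a} ->
      g = sigmoid(logit)
      g * s + (1.0 - g) * a
    end)
  end
end

# A zero logit gives gate = 0.5, so the result is the average of both paths.
GateFusionSketch.fuse([0.0, 0.0], [1.0, 2.0], [3.0, 4.0])
# => [2.0, 3.0]
```

A large positive logit pushes the gate toward 1 and the output toward the SSM path; a large negative logit favors the attention path.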

Usage

# Alias the module so the short name SSTransformer resolves.
alias Edifice.SSM.SSTransformer

model = SSTransformer.build(
  embed_dim: 256,
  hidden_size: 256,
  state_size: 16,
  num_layers: 6,
  num_heads: 4
)

References

  • Dao & Gu, "Transformers are SSMs" (2024) — Mamba-2
  • NVIDIA, "Hymba: A Hybrid-head Architecture" (2024) — parallel gating

Summary

Types

build_opt()

Options for build/1.

Functions

build(opts \\ [])

Build a State Space Transformer model.

output_size(opts \\ [])

Get the output size of the model.

Get recommended defaults.

Types

build_opt()

@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:state_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:head_dim, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:conv_size, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a State Space Transformer model.

Options

  • :embed_dim - Input embedding dimension (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :state_size - SSM state dimension (default: 16)
  • :num_layers - Number of hybrid blocks (default: 6)
  • :num_heads - Number of attention heads (default: 4)
  • :head_dim - Dimension per attention head (default: 64)
  • :expand_factor - SSM expansion factor (default: 2)
  • :conv_size - Causal convolution kernel size (default: 4)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Expected sequence length (default: 60)

Returns

An Axon model outputting [batch, hidden_size].
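
As a rough illustration of how the documented defaults could combine with caller-supplied options, the sketch below merges a keyword list over the default values listed above and insists on :embed_dim, the only option marked as required. The module BuildOptsSketch and its resolve/1 function are hypothetical; how build/1 actually resolves its options is not shown in this documentation.

```elixir
defmodule BuildOptsSketch do
  # Hypothetical helper for illustration; not part of Edifice.
  # Default values are copied from the option list above.
  @defaults [
    hidden_size: 256,
    state_size: 16,
    num_layers: 6,
    num_heads: 4,
    head_dim: 64,
    expand_factor: 2,
    conv_size: 4,
    dropout: 0.1,
    window_size: 60
  ]

  # Merge caller options over the defaults. :embed_dim has no default,
  # so Keyword.fetch!/2 raises KeyError when it is missing.
  def resolve(opts) do
    _ = Keyword.fetch!(opts, :embed_dim)
    Keyword.merge(@defaults, opts)
  end
end

opts = BuildOptsSketch.resolve(embed_dim: 128, num_layers: 2)
opts[:num_layers]   # => 2 (caller value wins)
opts[:hidden_size]  # => 256 (documented default)
```

Keyword.merge/2 lets values from the second list override the first, which is why the caller's :num_layers replaces the default of 6.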

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output size of the model.