State Space Transformer — parallel SSM + attention with learned gating per block.
Combines a selective state space model (SSM) path with a multi-head causal attention path in every block, fused via a learned sigmoid gate. This allows the model to dynamically balance local/recurrent processing (SSM) with global attention at each layer.
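The gate is a learned sigmoid over a projection of the block input, so per feature element the fusion reduces to simple scalar arithmetic. A minimal sketch in plain Elixir (names and values here are illustrative, not the module's internals):

```elixir
# Sigmoid-gated fusion of the SSM and attention paths, per element.
# gate_logit would come from a learned linear projection in the real block.
sigmoid = fn x -> 1.0 / (1.0 + :math.exp(-x)) end

fuse = fn gate_logit, ssm_out, attn_out ->
  g = sigmoid.(gate_logit)
  g * ssm_out + (1.0 - g) * attn_out
end

fuse.(0.0, 0.5, 2.0)
# gate = sigmoid(0.0) = 0.5, so the paths are averaged:
# 0.5 * 0.5 + 0.5 * 2.0 = 1.25
```

A large positive logit drives the gate toward 1 and the block toward the SSM path; a large negative logit favors the attention path.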
Architecture
Input [batch, seq_len, embed_dim]
        |
Per block:
  Pre-norm -> SSM path (selective scan with gating)
           -> Attention path (multi-head causal)
  -> gate * ssm_out + (1 - gate) * attn_out
  -> FFN + residual
        |
Final norm -> last timestep -> [batch, hidden_size]

Usage
model = SSTransformer.build(
  embed_dim: 256,
  hidden_size: 256,
  state_size: 16,
  num_layers: 6,
  num_heads: 4
)

References
- Dao & Gu, "Transformers are SSMs" (2024) — Mamba-2
- NVIDIA, "Hymba: A Hybrid-head Architecture" (2024) — parallel gating
Summary
Functions
Build a State Space Transformer model.
Get the output size of the model.
Get recommended defaults.
Types
@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:state_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:head_dim, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:conv_size, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a State Space Transformer model.
Options
- :embed_dim - Input embedding dimension (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :state_size - SSM state dimension (default: 16)
- :num_layers - Number of hybrid blocks (default: 6)
- :num_heads - Number of attention heads (default: 4)
- :head_dim - Dimension per attention head (default: 64)
- :expand_factor - SSM expansion factor (default: 2)
- :conv_size - Causal convolution kernel size (default: 4)
- :dropout - Dropout rate (default: 0.1)
- :window_size - Expected sequence length (default: 60)
Returns
An Axon model outputting [batch, hidden_size].
@spec output_size(keyword()) :: pos_integer()
Get the output size of the model.
@spec recommended_defaults() :: keyword()
Get recommended defaults.
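Since recommended_defaults/0 returns a keyword list, it pairs naturally with Keyword.merge/2, letting callers override only what they need. A sketch of that pattern (the defaults list below is an illustrative stand-in for the actual return value):

```elixir
# Stand-in for SSTransformer.recommended_defaults/0 (illustrative values only).
defaults = [hidden_size: 256, state_size: 16, num_layers: 6, num_heads: 4]

# Later keys win in Keyword.merge/2, so explicit options replace defaults.
opts = Keyword.merge(defaults, embed_dim: 256, num_layers: 8)

opts[:num_layers]  # overridden to 8
opts[:hidden_size] # kept from defaults as 256
```

The merged list can then be passed straight to build/1.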