Edifice.SSM.S5 (Edifice v0.2.0)


S5: Simplified State Space Sequence model.

S5 uses a single multi-input, multi-output (MIMO) state space model instead of the many independent single-input, single-output (SISO) systems used in Mamba. This yields a simpler architecture while maintaining strong performance.

Key Innovation: MIMO SSM

Instead of many parallel single-input, single-output SSMs (as in Mamba), S5 uses one large MIMO SSM:

Mamba: D separate SSMs, each with state size N
S5: 1 large SSM with D*N combined state
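
To make the layout difference concrete, here is a minimal pure-Elixir sketch of one discrete-time update in each style. It assumes real-valued, already-discretized diagonal dynamics and tiny hand-picked numbers. The module SSMLayouts and its functions are hypothetical illustrations, not part of Edifice. In the SISO bank each channel's state sees only its own input, while in the MIMO form a dense input matrix lets every input channel drive every state component.

```elixir
defmodule SSMLayouts do
  # Dot product of two equal-length lists of floats.
  defp dot(a, b), do: Enum.zip_with(a, b, &(&1 * &2)) |> Enum.sum()

  # SISO bank (Mamba-style layout): D independent SSMs, one per input channel.
  #   states, a_bar, b_bar: lists of D vectors, each of length N
  #   u: input vector of length D
  # Channel d's state is driven only by its own scalar input u_d.
  def siso_bank_step(states, a_bar, b_bar, u) do
    [states, a_bar, b_bar, u]
    |> Enum.zip()
    |> Enum.map(fn {x, a, b, u_d} ->
      Enum.zip_with([x, a, b], fn [xi, ai, bi] -> ai * xi + bi * u_d end)
    end)
  end

  # MIMO (S5-style layout): one shared state of size P.
  #   x, a_bar: vectors of length P (a_bar is the diagonal of the transition)
  #   b_bar: P x D matrix as a list of P rows; u: input vector of length D
  # Every state component reads every input channel through its row of b_bar.
  def mimo_step(x, a_bar, b_bar, u) do
    Enum.zip_with([x, a_bar, b_bar], fn [xi, ai, row] -> ai * xi + dot(row, u) end)
  end
end

# Impulse on input channel 0 only, with D = 2 channels.
siso = SSMLayouts.siso_bank_step([[0.0], [0.0]], [[0.5], [0.5]], [[1.0], [1.0]], [1.0, 0.0])
mimo = SSMLayouts.mimo_step([0.0, 0.0], [0.5, 0.5], [[1.0, 0.0], [0.3, 1.0]], [1.0, 0.0])
IO.inspect(siso, label: "SISO bank")
IO.inspect(mimo, label: "MIMO")
```

With the impulse on channel 0 only, the SISO bank leaves channel 1's state at zero, while the MIMO state picks up a contribution (0.3) in its second component through the off-diagonal entry of the input matrix.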

Architecture

Input [batch, seq_len, embed_dim]
    |
    v
S5 Block (repeated num_layers times)
    Encoder: linear projection
        |
        v
    MIMO SSM (diagonal A for efficiency)
        x'(t) = Ax(t) + Bu(t)
        y(t)  = Cx(t) + Du(t)
        |
        v
    Decoder: linear projection
    |
    v
[batch, hidden_size]
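
The equations in the diagram are written in continuous time. The S5 paper applies them to sequences through zero-order-hold (ZOH) discretization with a learned step size Δ; that Edifice uses the same scheme is an assumption here, not something this page states. With the diagonal A written as Λ = diag(λ_1, ..., λ_P), where P is the state size:

```latex
\bar{\Lambda} = e^{\Lambda \Delta}, \qquad
\bar{B} = \Lambda^{-1}\left(\bar{\Lambda} - I\right) B

x_k = \bar{\Lambda}\, x_{k-1} + \bar{B}\, u_k, \qquad
y_k = C\, x_k + D\, u_k
```

Since Λ is diagonal, the exponential is elementwise (each entry of Λ̄ is e^{λ_i Δ}), so no dense matrix exponential is needed. D here is the feedthrough term from the diagram, not the hidden dimension.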

Complexity

Aspect        Value
Training      O(L log L) via FFT, or O(L) via scan (L = sequence length)
Inference     O(1) per step, independent of sequence length
Parameters    Fewer than Mamba
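
For a time-invariant SSM like S5, the two training modes compute the same outputs. Unrolling the recurrence x_k = a_bar * x_{k-1} + b_bar * u_k gives a convolution with kernel K_j = c * a_bar^j * b_bar, which an FFT can evaluate in O(L log L) for a sequence of length L, while the recurrence itself is a single O(L) pass. The sketch below checks this for one state component with real, already-discretized scalar parameters and no feedthrough term. ScanVsConv is a hypothetical name, and the convolution is evaluated directly in O(L^2) for readability rather than with an FFT.

```elixir
defmodule ScanVsConv do
  # Recurrent mode: x_k = a_bar * x_{k-1} + b_bar * u_k and y_k = c * x_k.
  # A single pass over the sequence: O(L) time with O(1) carried state.
  def scan(us, a_bar, b_bar, c) do
    {ys, _final_x} =
      Enum.map_reduce(us, 0.0, fn u, x ->
        x_new = a_bar * x + b_bar * u
        {c * x_new, x_new}
      end)

    ys
  end

  # Convolution mode: y_k = sum over j in 0..k of K_j * u_{k-j}, K_j = c * a_bar^j * b_bar.
  # Evaluated directly here in O(L^2); an FFT-based convolution brings this to O(L log L).
  def conv([], _a_bar, _b_bar, _c), do: []

  def conv(us, a_bar, b_bar, c) do
    len = length(us)
    kernel = List.to_tuple(for j <- 0..(len - 1), do: c * :math.pow(a_bar, j) * b_bar)
    u = List.to_tuple(us)

    for k <- 0..(len - 1) do
      Enum.sum(for j <- 0..k, do: elem(kernel, j) * elem(u, k - j))
    end
  end
end

us = [1.0, 0.0, 0.0, 2.0]
IO.inspect(ScanVsConv.scan(us, 0.5, 1.0, 2.0), label: "scan")
IO.inspect(ScanVsConv.conv(us, 0.5, 1.0, 2.0), label: "conv")
# Both print [2.0, 1.0, 0.5, 4.25]
```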

Key Difference from Mamba

Aspect              S5               Mamba
SSM structure       One MIMO SSM     Many SISO SSMs
Input-dependence    Fixed A, B, C    Selective (input-dependent)
Complexity          Simpler          More complex
Gating              Optional         SiLU gating
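
The input-dependence row is the central difference. In an S5-style layer the discretized transition is computed from learned parameters once and reused at every timestep; in a selective layer the step size Δ, and therefore the transition, is recomputed from each input. This scalar sketch isolates that one effect. The weight w_delta, the softplus parameterization of Δ, and the module name Selectivity are illustrative assumptions, and the sketch ignores that Mamba also makes B and C input-dependent.

```elixir
defmodule Selectivity do
  # softplus(z) = log(1 + e^z): a smooth function that keeps the step size positive.
  defp softplus(z), do: :math.log(1.0 + :math.exp(z))

  # S5-style (time-invariant): one learned step size delta, so the discretized
  # transition a_bar = exp(lambda * delta) is identical at every timestep.
  def fixed_transitions(us, lambda, delta) do
    a_bar = :math.exp(lambda * delta)
    Enum.map(us, fn _u -> a_bar end)
  end

  # Selective (Mamba-style, simplified): delta is computed from the current input
  # through a hypothetical scalar weight w_delta, so a_bar varies per timestep.
  def selective_transitions(us, lambda, w_delta) do
    Enum.map(us, fn u -> :math.exp(lambda * softplus(w_delta * u)) end)
  end
end

us = [0.0, 5.0]
IO.inspect(Selectivity.fixed_transitions(us, -1.0, 0.1), label: "fixed")
IO.inspect(Selectivity.selective_transitions(us, -1.0, 1.0), label: "selective")
```

With λ < 0, a larger input yields a larger Δ and a smaller transition value, so the selective layer discards more of its previous state on that step, while the fixed layer applies the same decay everywhere.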

Usage

alias Edifice.SSM.S5

model = S5.build(
  embed_dim: 287,
  hidden_size: 256,
  state_size: 64,
  num_layers: 4
)

Use Case

S5 is useful for ablation studies to understand what Mamba's added complexity (selective mechanism, gating) contributes.

Reference

  • Paper: "Simplified State Space Layers for Sequence Modeling" (ICLR 2023)
  • arXiv: 2208.04933

Summary

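Types

  • build_opt() - Options for build/1.

Functions

  • build(opts \\ []) - Build an S5 model for sequence processing.
  • build_ffn(input, opts) - Build the Feed-Forward Network layer.
  • build_mimo_ssm(input, opts) - Build the MIMO SSM layer.
  • build_s5_block(input, opts) - Build a single S5 block.
  • init_cache(opts \\ []) - Initialize hidden state for O(1) incremental inference.
  • output_size(opts \\ []) - Get the output size of an S5 model.
  • param_count(opts) - Calculate approximate parameter count for an S5 model.
  • Get recommended defaults for real-time sequence processing (60fps).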

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:state_size, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build an S5 model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension D (default: 256)
  • :state_size - SSM state dimension N (default: 64)
  • :num_layers - Number of S5 blocks (default: 4)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Expected sequence length for JIT optimization (default: 60)

Returns

An Axon model that outputs [batch, hidden_size], taken from the last sequence position.
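
As a rough illustration of how the options above combine, this sketch merges caller options over the documented defaults, rejects a missing :embed_dim, and shows what "the last position" means for one batch element. S5OptionsSketch is hypothetical, and Edifice may resolve options differently internally.

```elixir
defmodule S5OptionsSketch do
  # Defaults copied from the build/1 option list above.
  @defaults [hidden_size: 256, state_size: 64, num_layers: 4, dropout: 0.1, window_size: 60]

  # Caller options override the defaults; :embed_dim has no default, so it must be present
  # (Keyword.fetch!/2 raises KeyError otherwise).
  def resolve(opts) do
    Keyword.fetch!(opts, :embed_dim)
    Keyword.merge(@defaults, opts)
  end

  # For one batch element, per-timestep outputs [h_1, ..., h_L] (each of length hidden_size)
  # reduce to the final timestep h_L, matching the [batch, hidden_size] output shape.
  def last_position(per_step_outputs), do: List.last(per_step_outputs)
end

opts = S5OptionsSketch.resolve(embed_dim: 287, num_layers: 2)
IO.inspect(Keyword.take(opts, [:embed_dim, :hidden_size, :num_layers]))
```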

build_ffn(input, opts)

@spec build_ffn(
  Axon.t(),
  keyword()
) :: Axon.t()

Build the Feed-Forward Network layer.

build_mimo_ssm(input, opts)

@spec build_mimo_ssm(
  Axon.t(),
  keyword()
) :: Axon.t()

Build the MIMO SSM layer.

Key components:

  1. Encoder projection
  2. State space model (diagonal A for efficiency)
  3. Decoder projection
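
The three components above can be traced end to end for a single sequence. This minimal pure-Elixir sketch assumes real-valued, already-discretized diagonal dynamics and a diagonal feedthrough term. The parameter map keys (enc, a_bar, b_bar, c, d, dec) and the module name MimoSSMSketch are hypothetical, and Edifice's actual parameterization and discretization may differ.

```elixir
defmodule MimoSSMSketch do
  defp dot(a, b), do: Enum.zip_with(a, b, &(&1 * &2)) |> Enum.sum()
  defp matvec(m, v), do: Enum.map(m, &dot(&1, v))

  # p is a map of hypothetical parameters (matrices are lists of rows):
  #   enc: D x D_in    a_bar: P (diagonal)    b_bar: P x D
  #   c: D x P         d: D (diagonal feedthrough)    dec: D_out x D
  # us is a list of input vectors (length D_in), one per timestep.
  def run(us, p) do
    x0 = List.duplicate(0.0, length(p.a_bar))

    {ys, _final_x} =
      Enum.map_reduce(us, x0, fn u_raw, x ->
        # 1. Encoder projection
        u = matvec(p.enc, u_raw)

        # 2. Diagonal MIMO SSM: x_k = a_bar .* x_{k-1} + B_bar u_k, y_k = C x_k + d .* u_k
        bu = matvec(p.b_bar, u)
        x_new = Enum.zip_with([p.a_bar, x, bu], fn [a, xi, b] -> a * xi + b end)
        du = Enum.zip_with(p.d, u, &(&1 * &2))
        y = Enum.zip_with(matvec(p.c, x_new), du, &(&1 + &2))

        # 3. Decoder projection
        {matvec(p.dec, y), x_new}
      end)

    ys
  end
end

params = %{
  enc: [[2.0]],
  a_bar: [0.5, 0.25],
  b_bar: [[1.0], [1.0]],
  c: [[1.0, 1.0]],
  d: [0.0],
  dec: [[1.0]]
}

IO.inspect(MimoSSMSketch.run([[1.0], [0.0]], params))
# => [[4.0], [1.5]]
```

Because the shared state is diagonal, the recurrence itself is elementwise; the only dense products are the encoder, input, output, and decoder projections.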

build_s5_block(input, opts)

@spec build_s5_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single S5 block.

Each block has:

  1. MIMO SSM layer
  2. Feed-forward network
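
The block structure, together with the repetition over num_layers, amounts to function composition. This sketch assumes the SSM layer mixes information across timesteps while the feed-forward network acts on each position independently (typical for such blocks, but not stated here); residual connections and normalization, if Edifice uses them, are omitted. S5StackSketch and the stand-in functions are hypothetical.

```elixir
defmodule S5StackSketch do
  # One block: a sequence-mixing layer (stand-in for the MIMO SSM), then a
  # feed-forward function applied to each position independently.
  def block(seq, ssm_fn, ffn_fn), do: seq |> ssm_fn.() |> Enum.map(ffn_fn)

  # Repeat the block num_layers times, feeding each output into the next block.
  # (Real layers would each have their own parameters; one pair of functions is reused here.)
  def stack(seq, num_layers, ssm_fn, ffn_fn) do
    Enum.reduce(1..num_layers//1, seq, fn _layer, acc -> block(acc, ssm_fn, ffn_fn) end)
  end
end

# Hypothetical stand-ins: a leaky causal accumulator mixes across time,
# and the "FFN" just doubles each position.
ssm_fn = fn seq ->
  {out, _x} =
    Enum.map_reduce(seq, 0.0, fn u, x ->
      x_new = 0.5 * x + u
      {x_new, x_new}
    end)

  out
end

ffn_fn = fn v -> 2.0 * v end

IO.inspect(S5StackSketch.stack([1.0, 0.0], 2, ssm_fn, ffn_fn))
```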

init_cache(opts \\ [])

@spec init_cache(keyword()) :: map()

Initialize hidden state for O(1) incremental inference.
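
Incremental inference only needs the SSM state carried between calls, so each new frame costs the same no matter how many frames came before. A minimal sketch of that pattern, assuming a cache map with a single state vector under a hypothetical :x key and scalar input frames; the map actually returned by init_cache/1 may hold different keys (for example one state per layer), and StreamingSketch is not part of Edifice.

```elixir
defmodule StreamingSketch do
  # Hypothetical cache layout: a single zero-initialized state vector of size p.
  def init_cache(p), do: %{x: List.duplicate(0.0, p)}

  # Consume one scalar input frame u and return {output, updated_cache}.
  # Work is O(p) per frame, independent of how many frames were seen before.
  def step(%{x: x} = cache, u, a_bar, b_bar, c) do
    x_new = Enum.zip_with([a_bar, x, b_bar], fn [a, xi, b] -> a * xi + b * u end)
    y = Enum.zip_with(c, x_new, &(&1 * &2)) |> Enum.sum()
    {y, %{cache | x: x_new}}
  end
end

# Feed frames one at a time, threading the cache through each call.
{ys, _cache} =
  Enum.map_reduce([1.0, 0.0, 0.0], StreamingSketch.init_cache(1), fn u, cache ->
    StreamingSketch.step(cache, u, [0.5], [1.0], [1.0])
  end)

IO.inspect(ys)
# => [1.0, 0.5, 0.25]
```

Feeding frames one at a time reproduces what a full-sequence scan over the same frames would output, which is what lets a model trained on whole windows be served frame by frame.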

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of an S5 model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for an S5 model.