Edifice.SSM.S4 (Edifice v0.2.0)


S4: Structured State Spaces for Sequences.

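Implements the S4 architecture from "Efficiently Modeling Long Sequences with Structured State Spaces" (Gu et al., ICLR 2022). S4 introduced HiPPO-initialized structured state matrices for stable long-range sequence modeling. The original paper uses a normal-plus-low-rank parameterization of A; this module uses a diagonal state matrix, which is a simplification of that parameterization.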

Key Innovation: HiPPO Initialization

The state matrix A is initialized using the HiPPO framework. HiPPO matrices make the state an optimal online approximation of the input history, expressed as a projection onto orthogonal polynomials. Restricting A to a diagonal makes each state dimension evolve independently, which enables efficient parallel computation:

Continuous SSM:
  h'(t) = A h(t) + B u(t)
  y(t)  = C h(t) + D u(t)

Discretization (ZOH, step size dt):
  A_bar = exp(dt * A)
  B_bar = (A_bar - I) * A^{-1} * B   (elementwise (A_bar - 1) / A * B for diagonal A; approximately dt * B when |dt * A| is small)
  h[t]  = A_bar * h[t-1] + B_bar * u[t]
  y[t]  = C * h[t] + D * u[t]
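The recurrence above can be sketched for a single input channel in plain Elixir (no Nx). This is a minimal illustration, not Edifice's implementation: the module `DiagSSMSketch` and its functions are hypothetical, and the initialization A_n = -(n + 1) (the diagonal of the HiPPO-LegS matrix, as in the S4D-Real variant) is an assumption, since these docs do not say which diagonal HiPPO variant the module uses.

```elixir
defmodule DiagSSMSketch do
  # Assumed init: A_n = -(n + 1), the diagonal of the HiPPO-LegS matrix
  # (S4D-Real style). Not confirmed to be what Edifice uses.
  def init_a(state_size), do: for(n <- 0..(state_size - 1), do: -(n + 1.0))

  # Exact zero-order hold for a diagonal A, one state dimension at a time:
  #   a_bar = exp(dt * a), b_bar = (a_bar - 1) / a * b
  def discretize(a, b, dt) do
    Enum.zip_with(a, b, fn a_n, b_n ->
      a_bar = :math.exp(dt * a_n)
      {a_bar, (a_bar - 1.0) / a_n * b_n}
    end)
  end

  # Sequential scan over a scalar input sequence, starting from h = 0:
  #   h[t] = a_bar * h[t-1] + b_bar * u[t]
  #   y[t] = sum_n c_n * h_n[t] + d * u[t]
  def scan(us, discretized, c, d) do
    h0 = List.duplicate(0.0, length(c))

    {ys, _h_final} =
      Enum.map_reduce(us, h0, fn u, h ->
        h_next =
          Enum.zip_with(h, discretized, fn h_n, {a_bar, b_bar} -> a_bar * h_n + b_bar * u end)

        y = c |> Enum.zip_with(h_next, &(&1 * &2)) |> Enum.sum()
        {y + d * u, h_next}
      end)

    ys
  end
end
```

For example, `DiagSSMSketch.scan(List.duplicate(1.0, 200), DiagSSMSketch.discretize([-1.0], [1.0], 0.1), [1.0], 0.0)` converges to 1.0. With a constant input, exact ZOH keeps each state's steady-state value at -B/A, while the dt * B shortcut converges to dt / (1 - exp(dt * A)), about 5% higher for A = -1 and dt = 0.1.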

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+-----------------------+
| Input Projection      |
+-----------------------+
      |
      v
+-----------------------+
| S4 Block x num_layers |
|  LayerNorm            |
|  SSM (HiPPO A)        |
|  Dropout + Residual   |
|  FFN Block            |
+-----------------------+
      |
      v
+-----------------------+
| Final LayerNorm       |
+-----------------------+
      |
      v
[batch, hidden_size]    (last timestep)
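The block layout in the diagram can be sketched on a sequence represented as a list of plain Elixir feature vectors. This is a structural illustration under stated assumptions, not Edifice's code. The SSM and FFN sublayers are passed in as functions. LayerNorm has no learned scale or shift. Dropout is omitted, as at inference time. The input projection is assumed to have been applied already. Giving the FFN its own pre-norm and residual connection is a guess, because the diagram only lists "FFN Block". The module `S4BlockSketch` is hypothetical.

```elixir
defmodule S4BlockSketch do
  # LayerNorm of one feature vector, without learned scale/shift (simplification).
  def layer_norm(v, eps \\ 1.0e-5) do
    n = length(v)
    mean = Enum.sum(v) / n
    var = Enum.sum(Enum.map(v, fn x -> (x - mean) * (x - mean) end)) / n
    Enum.map(v, fn x -> (x - mean) / :math.sqrt(var + eps) end)
  end

  # Elementwise sum of two sequences of vectors (residual connection).
  defp add_seq(xs, ys), do: Enum.zip_with(xs, ys, fn x, y -> Enum.zip_with(x, y, &+/2) end)

  # One block: pre-norm SSM sublayer plus residual, then an (assumed)
  # pre-norm FFN sublayer plus residual. ssm_fn sees the whole sequence
  # (it mixes across time); ffn_fn is applied to each timestep independently.
  def block(seq, ssm_fn, ffn_fn) do
    h = add_seq(seq, ssm_fn.(Enum.map(seq, &layer_norm/1)))
    add_seq(h, Enum.map(h, fn v -> ffn_fn.(layer_norm(v)) end))
  end

  # Stack the blocks, keep only the last timestep, then apply the final LayerNorm.
  # LayerNorm is per-timestep, so normalizing before or after selecting is equivalent.
  def forward(seq, blocks) do
    blocks
    |> Enum.reduce(seq, fn {ssm_fn, ffn_fn}, acc -> block(acc, ssm_fn, ffn_fn) end)
    |> List.last()
    |> layer_norm()
  end
end
```

Passing the sublayers as functions keeps the sketch focused on the wiring (normalization, residuals, last-timestep readout) rather than on any particular SSM kernel.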

Usage

alias Edifice.SSM.S4

model = S4.build(
  embed_dim: 287,
  hidden_size: 256,
  state_size: 64,
  num_layers: 4
)

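Reference

Albert Gu, Karan Goel, Christopher Ré. "Efficiently Modeling Long Sequences with Structured State Spaces." ICLR 2022.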

Summary

Types

build_opt()
Options for build/1.

Functions

build(opts \\ [])
Build an S4 model for sequence processing.

build_s4_block(input, opts)
Build a single S4 block (LayerNorm -> SSM -> dropout -> residual + FFN).

output_size(opts \\ [])
Get the output size of an S4 model.

param_count(opts)
Calculate approximate parameter count for an S4 model.

Get recommended defaults for real-time sequence processing.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:state_size, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build an S4 model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :state_size - SSM state dimension N (default: 64)
  • :num_layers - Number of S4 blocks (default: 4)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Expected sequence length (default: 60)
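A minimal sketch of how the documented defaults could combine with caller options, assuming a `Keyword.merge`-style resolution. `S4OptsSketch.resolve/1` is hypothetical and not part of Edifice. The error Edifice raises for a missing `:embed_dim` is not documented here.

```elixir
defmodule S4OptsSketch do
  # Documented defaults for build/1; :embed_dim has no default (required).
  @defaults [hidden_size: 256, state_size: 64, num_layers: 4, dropout: 0.1, window_size: 60]

  # Hypothetical helper: validate the required option, then let caller
  # options override the defaults.
  def resolve(opts) do
    # Keyword.fetch!/2 raises KeyError when :embed_dim is missing.
    _ = Keyword.fetch!(opts, :embed_dim)
    Keyword.merge(@defaults, opts)
  end
end
```

For instance, `S4OptsSketch.resolve(embed_dim: 287, num_layers: 2)` keeps `hidden_size: 256` from the defaults but uses the caller's `num_layers: 2`.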

Returns

An Axon model that outputs [batch, hidden_size] from the last position.

build_s4_block(input, opts)

@spec build_s4_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single S4 block (LayerNorm -> SSM -> dropout -> residual + FFN).

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of an S4 model, i.e. the feature dimension of its [batch, hidden_size] output.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for an S4 model.