Edifice.SSM.GSS (Edifice v0.2.0)


GSS: Gated State Space Model.

Implements the Gated State Space model from "Long Range Language Modeling via Gated State Spaces" (Mehta et al., 2023). GSS simplifies S4 by using fixed (learned but not input-dependent) A, B, C matrices combined with multiplicative gating for non-linearity.

Key Innovation: Fixed SSM + Multiplicative Gating

Unlike Mamba (where B, C, dt are input-dependent), GSS uses:

  • Fixed diagonal A, B, C matrices (learned via Axon.param)
  • Gating for input-dependent non-linearity: gate = sigmoid(W_g * x + b_g)
  • Result: simpler than Mamba, more expressive than vanilla S4

Equations

# SSM with fixed parameters:
h_t = A * h_{t-1} + B * x_t      # A, B are learned parameters (not input-dependent)
y_t = C * h_t                      # C is a learned parameter

# Gating for non-linearity:
gate_t = sigmoid(W_g * x_t + b_g)
output_t = gate_t * y_t
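
The equations above can be traced by hand with a minimal sketch in plain Elixir. It assumes a single input channel, a zero initial state, and diagonal A, B, C stored as equal-length lists; the module `GSSSketch` and its function names are hypothetical and are not part of Edifice.

```elixir
defmodule GSSSketch do
  # Hypothetical helper module for illustration; not part of Edifice.

  # Logistic sigmoid of a scalar.
  def sigmoid(z), do: 1.0 / (1.0 + :math.exp(-z))

  # Runs the gated recurrence over a list of scalar inputs `xs`.
  # `a`, `b`, `c` are equal-length lists holding the diagonal SSM parameters;
  # `w_g` and `b_g` are the scalar gate weight and bias.
  def run(xs, a, b, c, w_g, b_g) do
    # Assumption: the state starts at zero.
    h0 = Enum.map(a, fn _ -> 0.0 end)

    {outputs, _final_state} =
      Enum.map_reduce(xs, h0, fn x, h ->
        # h_t = A * h_{t-1} + B * x_t (elementwise, because A is diagonal)
        h_next =
          [a, h, b]
          |> Enum.zip()
          |> Enum.map(fn {a_i, h_i, b_i} -> a_i * h_i + b_i * x end)

        # y_t = C * h_t (sum over the state dimension)
        y =
          c
          |> Enum.zip(h_next)
          |> Enum.map(fn {c_i, h_i} -> c_i * h_i end)
          |> Enum.sum()

        # output_t = sigmoid(W_g * x_t + b_g) * y_t
        {sigmoid(w_g * x + b_g) * y, h_next}
      end)

    outputs
  end
end

# A one-dimensional state with A = 0.5 and a gate fixed at sigmoid(0) = 0.5:
GSSSketch.run([1.0, 0.0, 0.0], [0.5], [1.0], [1.0], 0.0, 0.0)
# => [0.5, 0.25, 0.125]
```

Only the gate depends on the current input; the state update uses the same A and B at every step.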

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+-------------------------------------+
|         GSS Block                    |
|  LayerNorm -> [SSM path, Gate path]  |
|    SSM: linear -> scan(A,B) -> C*h   |
|    Gate: linear -> sigmoid           |
|  output = SSM * Gate                 |
|  -> project -> residual              |
|  LayerNorm -> FFN -> residual        |
+-------------------------------------+
      | (repeat for num_layers)
      v
Output [batch, hidden_size]

Compared to Other SSMs

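| Model | A, B, C           | Gating         | Complexity |
|-------|-------------------|----------------|------------|
| S4    | Fixed (HiPPO)     | None           | O(L log L) |
| GSS   | Fixed (learned)   | Multiplicative | O(L)       |
| Mamba | Input-dependent   | SiLU           | O(L)       |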
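
The Complexity column reflects two ways of evaluating an SSM whose parameters do not depend on the input. Such a recurrence unrolls into a convolution with kernel K_k = C * A^k * B; this is the form that FFT-based methods evaluate in O(L log L), while a sequential scan costs O(L). The sketch below checks that the two views agree. It assumes a scalar state and a zero initial state, and the module `SSMViews` is hypothetical, not part of Edifice.

```elixir
defmodule SSMViews do
  # Hypothetical helper module for illustration; not part of Edifice.
  # Scalar state for brevity.

  # Recurrent view: one state update per step, O(L) in total.
  def recurrent(xs, a, b, c) do
    {ys, _final_state} =
      Enum.map_reduce(xs, 0.0, fn x, h ->
        h_next = a * h + b * x
        {c * h_next, h_next}
      end)

    ys
  end

  # Convolutional view: y_t = sum over k of (c * a^k * b) * x_{t-k}.
  # Written naively here, which is O(L^2); an FFT would bring the
  # same computation down to O(L log L).
  def convolutional(xs, a, b, c) do
    len = length(xs)
    kernel = for k <- 0..(len - 1)//1, do: c * :math.pow(a, k) * b

    for t <- 0..(len - 1)//1 do
      for k <- 0..t//1, reduce: 0.0 do
        acc -> acc + Enum.at(kernel, k) * Enum.at(xs, t - k)
      end
    end
  end
end

xs = [1.0, 2.0, 3.0]
SSMViews.recurrent(xs, 0.5, 1.0, 2.0)
# => [2.0, 5.0, 8.5]
SSMViews.convolutional(xs, 0.5, 1.0, 2.0)
# => [2.0, 5.0, 8.5]
```

The equivalence relies on A, B, and C being the same at every timestep, so it does not carry over to models with input-dependent parameters.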

Usage

alias Edifice.SSM.GSS

model = GSS.build(
  embed_dim: 287,
  hidden_size: 256,
  state_size: 16,
  num_layers: 4
)

References

  • Mehta et al., "Long Range Language Modeling via Gated State Spaces" (2023)

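Summary

Types

  • build_opt() - Options for build/1.

Functions

  • build(opts \\ []) - Build a GSS model for sequence processing.
  • default_dropout() - Default dropout rate
  • default_hidden_size() - Default hidden dimension
  • default_num_layers() - Default number of layers
  • default_state_size() - Default SSM state dimension
  • output_size(opts \\ []) - Get the output size of a GSS model.
  • Get recommended defaults.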

Types

build_opt()

@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:state_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a GSS model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :state_size - SSM state dimension (default: 16)
  • :num_layers - Number of GSS blocks (default: 4)
  • :dropout - Dropout rate (default: 0.0)
  • :window_size - Expected sequence length (default: 60)
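
As a rough illustration of how a required option and defaulted options like these can be resolved, the sketch below merges caller options over the documented defaults. The module `GSSOptsSketch` and its `resolve/1` function are hypothetical; whether Edifice handles its options this way internally is an assumption.

```elixir
defmodule GSSOptsSketch do
  # Hypothetical helper for illustration; not part of Edifice.
  # Default values are copied from the option list above.
  @defaults [hidden_size: 256, state_size: 16, num_layers: 4, dropout: 0.0, window_size: 60]

  def resolve(opts) do
    # :embed_dim has no default, so a missing key raises KeyError.
    _embed_dim = Keyword.fetch!(opts, :embed_dim)

    # Caller-supplied values override the defaults.
    Keyword.merge(@defaults, opts)
  end
end

opts = GSSOptsSketch.resolve(embed_dim: 287, num_layers: 2)
Keyword.get(opts, :num_layers)   # => 2
Keyword.get(opts, :hidden_size)  # => 256
```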

Returns

An Axon model that takes input of shape [batch, seq_len, embed_dim] and outputs the last hidden state, of shape [batch, hidden_size].

default_dropout()

@spec default_dropout() :: float()

Default dropout rate (0.0).

default_hidden_size()

@spec default_hidden_size() :: pos_integer()

Default hidden dimension (256).

default_num_layers()

@spec default_num_layers() :: pos_integer()

Default number of layers (4).

default_state_size()

@spec default_state_size() :: pos_integer()

Default SSM state dimension (16).

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a GSS model.