Edifice.SSM.H3 (Edifice v0.2.0)


H3: Hungry Hungry Hippos.

Implements the H3 architecture from "Hungry Hungry Hippos: Towards Language Modeling with State Space Models" (Fu et al., ICLR 2023). H3 combines two SSM layers with a short convolution and multiplicative gating to close the gap between SSMs and Transformers on language modeling.

Key Innovation: Two-SSM + Short Conv

H3 interleaves two types of SSMs with multiplicative interaction:

Branch 1 (Shift SSM): Captures local dependencies via a shift-matrix SSM
Branch 2 (Diag SSM):  Captures broader patterns via diagonal SSM
Short Conv:           Models very local (1-4 token) patterns

Output = ShortConv(ShiftSSM(x) * DiagSSM(x))
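
To make the composition above concrete, here is a minimal single-channel sketch in plain Elixir lists (no Nx or Axon). The module `H3Sketch`, its function names, and all parameter values are hypothetical. It assumes the shift SSM unrolls to a causal weighted sum over the most recent inputs and that the diagonal SSM is an elementwise linear recurrence; it is not the library's implementation.

```elixir
defmodule H3Sketch do
  # Hypothetical illustration only; not Edifice's implementation.

  # Causal FIR filter: y_t = sum_k w_k * u_{t-k}, with inputs before t = 0 taken as zero.
  # Assumption: a shift SSM (A = shift matrix, B = e_1) unrolls to this form with w = C.
  def causal_fir(u, w) do
    pad = List.duplicate(0.0, length(w) - 1)

    (pad ++ u)
    |> Enum.chunk_every(length(w), 1, :discard)
    |> Enum.map(fn window ->
      window
      |> Enum.reverse()
      |> Enum.zip(w)
      |> Enum.reduce(0.0, fn {x, w_k}, acc -> acc + x * w_k end)
    end)
  end

  # Diagonal SSM: h_t = a .* h_{t-1} + b * u_t, y_t = dot(c, h_t).
  def diag_ssm(u, a, b, c) do
    h0 = List.duplicate(0.0, length(a))

    {ys, _final_state} =
      Enum.map_reduce(u, h0, fn u_t, h ->
        h_next =
          Enum.zip([a, b, h])
          |> Enum.map(fn {a_i, b_i, h_i} -> a_i * h_i + b_i * u_t end)

        y_t =
          h_next
          |> Enum.zip(c)
          |> Enum.reduce(0.0, fn {h_i, c_i}, acc -> acc + h_i * c_i end)

        {y_t, h_next}
      end)

    ys
  end

  # Output = ShortConv(ShiftSSM(x) * DiagSSM(x))
  def h3(x, shift_c, {a, b, c}, conv_w) do
    shift = causal_fir(x, shift_c)
    diag = diag_ssm(x, a, b, c)
    gated = Enum.zip_with(shift, diag, fn s, d -> s * d end)
    causal_fir(gated, conv_w)
  end
end

# Impulse input, a 2-tap shift branch, a 1-state diagonal branch, identity short conv.
H3Sketch.h3([1.0, 0.0, 0.0], [0.5, 0.25], {[0.5], [1.0], [1.0]}, [1.0])
# => [0.5, 0.125, 0.0]
```

The two branches see the same input and are multiplied position by position before the short convolution, which is the gating interaction the formula describes.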

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+-----------------------+
| Input Projection      |
+-----------------------+
      |
      v
+-----------------------+
| H3 Block x N          |
|  +-- ShiftSSM(x) --+  |
|  |                  |  |
|  +-- DiagSSM(x) ---+  |
|  |                  |  |
|  +--- multiply ----+  |
|  |                     |
|  ShortConv + OutProj   |
|  Residual + FFN        |
+-----------------------+
      |
      v
[batch, hidden_size]    (last timestep)
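
The final step in the diagram drops the sequence axis by keeping only the last timestep. A minimal sketch of that reduction, using nested lists in place of a [batch, seq_len, hidden_size] tensor; `LastStep` is a hypothetical helper, not an Edifice function:

```elixir
defmodule LastStep do
  # Hypothetical helper: each batch entry is a list of timesteps,
  # and each timestep is a list of hidden features.
  def last_timestep(batch), do: Enum.map(batch, &List.last/1)
end

# batch = 2, seq_len = 3, hidden_size = 2
hidden = [
  [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
  [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
]

LastStep.last_timestep(hidden)
# => [[5.0, 6.0], [11.0, 12.0]]
```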

Usage

alias Edifice.SSM.H3

model = H3.build(
  embed_dim: 287,
  hidden_size: 256,
  state_size: 64,
  conv_size: 4,
  num_layers: 4
)

Reference

Summary

Types

Options for build/1.

Functions

Build an H3 model for sequence processing.

Build a single H3 block: two SSMs multiplied + short conv + FFN.

Get the output size of an H3 model.

Calculate approximate parameter count for an H3 model.

Get recommended defaults.

Types

build_opt()

@type build_opt() ::
  {:conv_size, pos_integer()}
  | {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:state_size, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build an H3 model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :state_size - SSM state dimension N (default: 64)
  • :conv_size - Short convolution kernel size (default: 4)
  • :num_layers - Number of H3 blocks (default: 4)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Expected sequence length (default: 60)

Returns

An Axon model that outputs [batch, hidden_size] from the last position.
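
As a rough illustration of how the documented defaults and the required :embed_dim could be resolved, here is a sketch in plain Elixir. `H3OptsSketch.resolve/1` is hypothetical and only mirrors the option list above; the library's own option handling may differ.

```elixir
defmodule H3OptsSketch do
  # Hypothetical helper, not part of Edifice.
  # Defaults copied from the documented option list.
  @defaults [
    hidden_size: 256,
    state_size: 64,
    conv_size: 4,
    num_layers: 4,
    dropout: 0.1,
    window_size: 60
  ]

  def resolve(opts) do
    # :embed_dim is documented as required, so raise KeyError when it is missing.
    _embed_dim = Keyword.fetch!(opts, :embed_dim)
    # Caller-supplied values override the defaults.
    Keyword.merge(@defaults, opts)
  end
end

resolved = H3OptsSketch.resolve(embed_dim: 287, num_layers: 2)
resolved[:num_layers]   # => 2 (overridden)
resolved[:hidden_size]  # => 256 (default)
```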

build_h3_block(input, opts)

@spec build_h3_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single H3 block: two SSMs multiplied + short conv + FFN.

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of an H3 model, i.e. the size of the last axis of the [batch, hidden_size] output described under build/1.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for an H3 model.
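
The formula behind this estimate is not given here. For intuition only, the sketch below shows one way such a count could be assembled from the documented dimensions. Every layer shape in it is an assumption (per-channel SSM parameters, a depthwise short convolution, a 4x FFN expansion), and `H3ParamSketch.estimate/1` is a hypothetical function, so its result need not match param_count/1.

```elixir
defmodule H3ParamSketch do
  # Hypothetical estimate; every layer shape below is an assumption,
  # not read from Edifice's source.
  def estimate(opts) do
    e = Keyword.fetch!(opts, :embed_dim)
    h = Keyword.get(opts, :hidden_size, 256)
    n = Keyword.get(opts, :state_size, 64)
    k = Keyword.get(opts, :conv_size, 4)
    layers = Keyword.get(opts, :num_layers, 4)

    # Input projection: dense weight + bias.
    input_proj = e * h + h
    # Assumed: one C vector of size N per channel for the shift SSM.
    shift_ssm = h * n
    # Assumed: A, B and C vectors of size N per channel for the diagonal SSM.
    diag_ssm = 3 * h * n
    # Assumed: depthwise kernel of width conv_size.
    short_conv = h * k
    # Output projection: dense weight + bias.
    out_proj = h * h + h
    # Assumed: two dense layers with a 4x expansion, plus their biases.
    ffn = 2 * h * (4 * h) + 4 * h + h

    input_proj + layers * (shift_ssm + diag_ssm + short_conv + out_proj + ffn)
  end
end

H3ParamSketch.estimate(embed_dim: 287)
```

Under these assumptions the FFN dominates each block, so the total scales roughly with num_layers times hidden_size squared.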