H3: Hungry Hungry Hippos.
Implements the H3 architecture from "Hungry Hungry Hippos: Towards Language Modeling with State Space Models" (Fu et al., ICLR 2023). H3 combines two SSM layers with a short convolution and multiplicative gating to close the gap between SSMs and Transformers on language modeling.
Key Innovation: Two-SSM + Short Conv
H3 interleaves two types of SSMs with multiplicative interaction:
Branch 1 (Shift SSM): Captures local dependencies via a shift-matrix SSM, which carries previous tokens forward in its state
Branch 2 (Diag SSM): Captures broader patterns via diagonal SSM
Short Conv: Models very local (1-4 token) patterns
Output = ShortConv(ShiftSSM(x) * DiagSSM(x))
Architecture
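To make the composition Output = ShortConv(ShiftSSM(x) * DiagSSM(x)) concrete, here is a toy sketch in plain Elixir that works on a single scalar channel. It is not the module's implementation: the module name H3Sketch and all of its functions are hypothetical, the shift branch is reduced to "see the previous token", and the diagonal branch is reduced to one state with a fixed decay a.

```elixir
defmodule H3Sketch do
  @moduledoc "Toy single-channel illustration of the H3 composition (hypothetical names)."

  # Shift SSM stand-in: each position sees the previous token, zero-padded at t = 0.
  def shift(xs), do: [0 | Enum.drop(xs, -1)]

  # Diagonal SSM stand-in with one state: h_t = a * h_{t-1} + x_t, y_t = h_t.
  def diag_ssm(xs, a) do
    {ys, _last_state} =
      Enum.map_reduce(xs, 0, fn x, h ->
        h_next = a * h + x
        {h_next, h_next}
      end)

    ys
  end

  # Causal short convolution: y_t = sum over k of kernel[k] * x_{t-k},
  # with zeros standing in for positions before the sequence starts.
  def short_conv(xs, kernel) do
    k = length(kernel)
    padded = List.duplicate(0, k - 1) ++ xs

    padded
    |> Enum.chunk_every(k, 1, :discard)
    |> Enum.map(fn window ->
      # window is [x_{t-k+1}, ..., x_t]; reverse it so kernel[0] weights x_t.
      window
      |> Enum.reverse()
      |> Enum.zip(kernel)
      |> Enum.map(fn {x, w} -> x * w end)
      |> Enum.sum()
    end)
  end

  # Output = ShortConv(ShiftSSM(x) * DiagSSM(x)), elementwise product of the branches.
  def forward(xs, a, kernel) do
    gated = Enum.zip_with(shift(xs), diag_ssm(xs, a), fn s, d -> s * d end)
    short_conv(gated, kernel)
  end
end
```

For example, H3Sketch.forward([1, 2, 3, 4], 1, [1, 1]) multiplies the shifted input [0, 1, 2, 3] by the running sum [1, 3, 6, 10] and then sums adjacent pairs, giving [0, 3, 15, 42]. In the real model each branch operates on projected, multi-channel tensors with learned parameters.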
Input [batch, seq_len, embed_dim]
            |
            v
+-----------------------+
|   Input Projection    |
+-----------------------+
            |
            v
+-----------------------+
|    H3 Block x N       |
|  +-- ShiftSSM(x) --+  |
|  |                 |  |
|  +-- DiagSSM(x) ---+  |
|  |                 |  |
|  +--- multiply ----+  |
|           |           |
|  ShortConv + OutProj  |
|  Residual + FFN       |
+-----------------------+
            |
            v
[batch, hidden_size] (last timestep)
Usage
model = H3.build(
  embed_dim: 287,
  hidden_size: 256,
  state_size: 64,
  conv_size: 4,
  num_layers: 4
)
Reference
- Paper: "Hungry Hungry Hippos: Towards Language Modeling with State Space Models"
- arXiv: https://arxiv.org/abs/2212.14052
Types
@type build_opt() ::
        {:conv_size, pos_integer()}
        | {:dropout, float()}
        | {:embed_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:seq_len, pos_integer()}
        | {:state_size, pos_integer()}
        | {:window_size, pos_integer()}
Options for build/1.
Functions
Build an H3 model for sequence processing.
Options
- :embed_dim - Size of input embedding per frame (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :state_size - SSM state dimension N (default: 64)
- :conv_size - Short convolution kernel size (default: 4)
- :num_layers - Number of H3 blocks (default: 4)
- :dropout - Dropout rate (default: 0.1)
- :window_size - Expected sequence length (default: 60)
Returns
An Axon model that outputs [batch, hidden_size] from the last position.
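As an illustration of how the documented options and defaults could be resolved, the following plain-Elixir sketch merges caller options over the defaults listed above and insists on :embed_dim. The module H3OptsSketch and its functions are hypothetical and are not the library's code. The output_size/1 helper assumes that the output feature size equals :hidden_size, which follows from the documented [batch, hidden_size] output shape.

```elixir
defmodule H3OptsSketch do
  @moduledoc "Hypothetical option handling mirroring the documented defaults."

  # Defaults copied from the Options list; :embed_dim is absent because it is required.
  @defaults [
    hidden_size: 256,
    state_size: 64,
    conv_size: 4,
    num_layers: 4,
    dropout: 0.1,
    window_size: 60
  ]

  # Returns the full option list, with caller-supplied values taking precedence.
  def resolve(opts) do
    # Keyword.fetch!/2 raises KeyError when the required key is missing.
    _embed_dim = Keyword.fetch!(opts, :embed_dim)
    Keyword.merge(@defaults, opts)
  end

  # Assumption: the model emits [batch, hidden_size], so the output size
  # is whatever :hidden_size resolves to.
  def output_size(opts), do: Keyword.get(opts, :hidden_size, @defaults[:hidden_size])
end
```

With these assumptions, H3OptsSketch.resolve(embed_dim: 287, num_layers: 2) keeps the two supplied values and fills in the remaining defaults, while calling it without :embed_dim raises a KeyError.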
Build a single H3 block: two SSMs multiplied + short conv + FFN.
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of an H3 model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for an H3 model.
@spec recommended_defaults() :: keyword()
Get recommended defaults.