Edifice.SSM.Hymba (Edifice v0.2.0)


Hymba: Hybrid-head Architecture with Parallel Mamba + Attention.

Implements the Hymba architecture from "Hymba: A Hybrid-head Architecture for Small Language Models" (NVIDIA, 2024). Unlike sequential hybrid models (Jamba, Zamba), Hymba runs Mamba and attention in parallel within each block, with learnable gated fusion.

Key Innovations

  1. Parallel Mamba + Attention: Both paths process the same input simultaneously, and outputs are combined via a learnable gate (see the sketch after this list): output = gate * mamba_out + (1 - gate) * attn_out

  2. Learnable Meta Tokens: K learnable vectors prepended to K/V in the attention path. These serve as "summarizers" that compress global context, reducing the effective attention complexity while maintaining long-range access.

  3. Cross-layer meta token propagation: Meta token states are updated across layers, accumulating information throughout the network.
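
A minimal Nx sketch of the fusion step (module and parameter names are illustrative, not Edifice internals; the gate is a learned parameter squashed into (0, 1) by a sigmoid):

defmodule FusionSketch do
  import Nx.Defn

  # Both paths see the same input; a learned gate blends their outputs.
  # gate_logit is a trained parameter; Nx.sigmoid keeps the gate in (0, 1).
  defn fuse(mamba_out, attn_out, gate_logit) do
    gate = Nx.sigmoid(gate_logit)
    gate * mamba_out + (1 - gate) * attn_out
  end
end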

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+--------------------------------------+
|             Hymba Block              |
|                                      |
|  +---------+      +---------------+  |
|  |  Mamba  |      |   Attention   |  |
|  |  (SSM)  |      | + Meta Tokens |  |
|  +----+----+      +-------+-------+  |
|       |                   |          |
|       v                   v          |
|  gate * mamba + (1 - gate) * attn    |
|                 |                    |
|                 v                    |
|           residual + FFN             |
+--------------------------------------+
      | (repeat for num_layers)
      v
Output [batch, hidden_size]
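
The attention path's meta tokens can be pictured with a short eager-Nx sketch (module, function, and shapes here are illustrative assumptions, not Edifice's API): a learnable [num_meta_tokens, dim] tensor is broadcast over the batch and prepended to the keys, and identically to the values, so every query can attend to the meta tokens.

defmodule MetaTokenSketch do
  # kv: keys (or values) for one head, shape [batch, seq_len, dim]
  # meta: learnable meta-token states, shape [num_meta_tokens, dim]
  def prepend_meta(kv, meta) do
    {batch, _seq, dim} = Nx.shape(kv)
    {m, _dim} = Nx.shape(meta)

    # Tile the meta tokens across the batch, then prepend along the
    # sequence axis; attention then runs over m + seq_len positions.
    tiled = Nx.broadcast(Nx.new_axis(meta, 0), {batch, m, dim})
    Nx.concatenate([tiled, kv], axis: 1)
  end
end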

Compared to Other Hybrids

Model   Mamba + Attention   Pattern
-----   -----------------   -----------------
Jamba   Alternating         Sequential layers
Zamba   Shared attention    Interleaved
Hymba   Parallel heads      Within each block

Usage

alias Edifice.SSM.Hymba

model = Hymba.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 4,
  num_meta_tokens: 4
)
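
The result is an ordinary Axon graph, so the standard Axon build/init/predict workflow applies. A sketch assuming the defaults documented under build/1 (window_size 60, so inputs are [batch, 60, 287]):

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 60, 287}, :f32), %{})
output = predict_fn.(params, Nx.broadcast(0.0, {1, 60, 287}))
# output: last hidden state, shape {1, 256}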

References

  • Hymba: A Hybrid-head Architecture for Small Language Models (NVIDIA, 2024)


Types

build_opt()

@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:state_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_meta_tokens, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Hymba model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :state_size - SSM state dimension (default: 16)
  • :num_layers - Number of Hymba blocks (default: 4)
  • :num_heads - Number of attention heads (default: 4)
  • :num_meta_tokens - Learnable meta tokens for attention (default: 4)
  • :dropout - Dropout rate (default: 0.0)
  • :window_size - Expected sequence length (default: 60)

Returns

An Axon model that processes sequences and outputs the last hidden state.
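
A minimal call needs only the required option; everything else falls back to the defaults above:

model = Edifice.SSM.Hymba.build(embed_dim: 128)
# 4 Hymba blocks, hidden size 256, 4 heads, 4 meta tokens, no dropout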

default_dropout()

@spec default_dropout() :: float()

Default dropout rate (0.0).

default_hidden_size()

@spec default_hidden_size() :: pos_integer()

Default hidden dimension (256).

default_num_heads()

@spec default_num_heads() :: pos_integer()

Default number of attention heads (4).

default_num_layers()

@spec default_num_layers() :: pos_integer()

Default number of layers (4).

default_num_meta_tokens()

@spec default_num_meta_tokens() :: pos_integer()

Default number of learnable meta tokens (4).

default_state_size()

@spec default_state_size() :: pos_integer()

Default SSM state dimension (16).

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a Hymba model.
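
Assuming the reported size tracks the :hidden_size option (the model outputs its last hidden state at that width), the example values below follow:

Edifice.SSM.Hymba.output_size([])                 # => 256 (default)
Edifice.SSM.Hymba.output_size(hidden_size: 512)   # => 512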