Hymba: Hybrid-head Architecture with Parallel Mamba + Attention.
Implements the Hymba architecture from "Hymba: A Hybrid-head Architecture for Small Language Models" (NVIDIA, 2024). Unlike sequential hybrid models (Jamba, Zamba), Hymba runs Mamba and attention in parallel within each block, with learnable gated fusion.
Key Innovations
Parallel Mamba + Attention: Both paths process the same input simultaneously, and outputs are combined via a learnable gate (see the sketch after this list):

output = gate * mamba_out + (1 - gate) * attn_out

Learnable Meta Tokens: K learnable vectors prepended to K/V in the attention path. These serve as "summarizers" that compress global context, reducing the effective attention complexity while maintaining long-range access (see the attention sketch after the architecture diagram).
Cross-layer meta token propagation: Meta token states are updated across layers, accumulating information throughout the network.
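As a concrete illustration, here is a minimal Axon sketch of the gated fusion, assuming `mamba_out` and `attn_out` are Axon nodes with matching shapes. The per-channel sigmoid gate and the `GatedFusion.fuse/3` name are illustrative assumptions, not this module's internals:

defmodule GatedFusion do
  # Combine the Mamba and attention paths with a learnable gate.
  # `hidden_size` must equal the last dimension of both inputs.
  def fuse(%Axon{} = mamba_out, %Axon{} = attn_out, hidden_size) do
    # One learnable gate value per hidden channel; zero-init gives
    # sigmoid(0) = 0.5, i.e. an even mix of both paths at the start.
    gate = Axon.param("gate", fn _ -> {hidden_size} end, initializer: :zeros)

    Axon.layer(
      fn mamba, attn, gate, _opts ->
        g = Nx.sigmoid(gate)
        # output = g * mamba + (1 - g) * attn
        Nx.add(Nx.multiply(g, mamba), Nx.multiply(Nx.subtract(1.0, g), attn))
      end,
      [mamba_out, attn_out, gate],
      name: "gated_fusion",
      op_name: :gated_fusion
    )
  end
end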
Architecture
Input [batch, seq_len, embed_dim]
|
v
+-------------------------------------+
| Hymba Block |
| |
| +--------+ +------------------+ |
| | Mamba | | Attention | |
| | (SSM) | | + Meta Tokens | |
| +----+----+ +--------+--------+ |
| | | |
| v v |
| gate * mamba + (1-gate) * attn |
| | |
| v |
| residual + FFN |
+-------------------------------------+
| (repeat for num_layers)
v
Output [batch, hidden_size]
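As a hedged, single-head rendering of the attention path in the diagram, the sketch below prepends the learnable meta tokens to K and V before scaled dot-product attention. Module and function names are illustrative, and the real implementation is multi-head:

defmodule MetaTokenAttention do
  # Sketch: single-head attention with learnable meta tokens
  # prepended to keys and values. q/k/v: {batch, seq_len, dim};
  # meta_k/meta_v: {num_meta, dim}.
  def attend(q, k, v, meta_k, meta_v) do
    batch = Nx.axis_size(q, 0)
    dim = Nx.axis_size(q, 2)
    num_meta = Nx.axis_size(meta_k, 0)

    # Broadcast meta tokens across the batch and prepend them, so
    # every query can also attend to the "summarizer" slots.
    keys = Nx.concatenate([Nx.broadcast(meta_k, {batch, num_meta, dim}), k], axis: 1)
    values = Nx.concatenate([Nx.broadcast(meta_v, {batch, num_meta, dim}), v], axis: 1)

    # Scaled dot-product attention over seq_len + num_meta positions.
    scores = Nx.divide(Nx.dot(q, [2], [0], keys, [2], [0]), Nx.sqrt(dim))
    maxes = Nx.reduce_max(scores, axes: [-1], keep_axes: true)
    weights = Nx.exp(Nx.subtract(scores, maxes))
    weights = Nx.divide(weights, Nx.sum(weights, axes: [-1], keep_axes: true))

    # Output keeps the original query length: {batch, seq_len, dim}.
    Nx.dot(weights, [2], [0], values, [1], [0])
  end
end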
Compared to Other Hybrids

| Model | Mamba + Attention | Pattern |
|---|---|---|
| Jamba | Alternating | Sequential layers |
| Zamba | Shared attention | Interleaved |
| Hymba | Parallel heads | Within each block |
Usage
model =
  Hymba.build(
    embed_dim: 287,
    hidden_size: 256,
    num_layers: 4,
    num_meta_tokens: 4
  )
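Assuming standard Axon conventions, the built model can then be initialized and run. The template shape below follows the default window_size of 60 and the example embed_dim of 287:

# Compile into init/predict functions (standard Axon workflow).
{init_fn, predict_fn} = Axon.build(model)

# {batch, window_size, embed_dim}
template = Nx.template({1, 60, 287}, :f32)
params = init_fn.(template, %{})

out = predict_fn.(params, Nx.iota({1, 60, 287}, type: :f32))
# out has shape {1, 256}: the last hidden state per sequence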
References

- Dong et al., "Hymba: A Hybrid-head Architecture for Small Language Models" (NVIDIA, 2024)
- https://arxiv.org/abs/2411.13676
Summary

Functions

- build/1 - Build a Hymba model for sequence processing.
- default_dropout/0 - Default dropout rate.
- default_hidden_size/0 - Default hidden dimension.
- default_num_heads/0 - Default number of attention heads.
- default_num_layers/0 - Default number of layers.
- default_num_meta_tokens/0 - Default number of learnable meta tokens.
- default_state_size/0 - Default SSM state dimension.
- output_size/1 - Get the output size of a Hymba model.
- recommended_defaults/0 - Get recommended defaults.
Types
@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:state_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_meta_tokens, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a Hymba model for sequence processing.
Options
:embed_dim - Size of input embedding per frame (required)
:hidden_size - Internal hidden dimension (default: 256)
:state_size - SSM state dimension (default: 16)
:num_layers - Number of Hymba blocks (default: 4)
:num_heads - Number of attention heads (default: 4)
:num_meta_tokens - Learnable meta tokens for attention (default: 4)
:dropout - Dropout rate (default: 0.0)
:window_size - Expected sequence length (default: 60)
Returns
An Axon model that processes sequences and outputs the last hidden state.
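For example (hedged, with illustrative values), overriding several defaults at build time:

model =
  Hymba.build(
    embed_dim: 128,        # required: per-frame feature size
    hidden_size: 512,      # wider internal dimension
    state_size: 32,        # larger SSM state
    num_meta_tokens: 8,    # more summarizer slots for attention
    dropout: 0.1
  )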
@spec default_dropout() :: float()
Default dropout rate.
@spec default_num_heads() :: pos_integer()
Default number of attention heads.
@spec default_num_layers() :: pos_integer()
Default number of layers.
@spec default_num_meta_tokens() :: pos_integer()
Default number of learnable meta tokens.
@spec default_state_size() :: pos_integer()
Default SSM state dimension.
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a Hymba model.
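Since the model outputs [batch, hidden_size], this presumably mirrors the :hidden_size option; a hedged sketch:

Hymba.output_size(hidden_size: 512)
# => 512 (assumed: tracks :hidden_size)

Hymba.output_size([])
# => 256 (assumed: falls back to the default hidden dimension)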
@spec recommended_defaults() :: keyword()
Get recommended defaults.
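A hedged usage sketch: merge the recommended defaults with task-specific options before building. Keyword.merge/2 favors the right-hand list, so explicit options win:

opts = Keyword.merge(Hymba.recommended_defaults(), embed_dim: 287)
model = Hymba.build(opts)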