SoundStorm: Efficient Parallel Audio Generation via Masked Prediction
SoundStorm generates audio non-autoregressively by iteratively refining masked neural codec tokens (e.g. EnCodec). Instead of predicting one token at a time left-to-right, SoundStorm masks out low-confidence tokens and predicts all of them in parallel, repeating for T refinement steps. A Conformer backbone operates on the flattened codec token sequence.
Motivation
Autoregressive audio generation (WaveNet, AudioLM) is slow because each token must be generated conditioned on all previous tokens. SoundStorm achieves a ~100x speedup by predicting entire codebook layers in parallel. The coarse-to-fine hierarchy produced by residual vector quantization in neural codecs (codebook 0 = coarse structure, codebooks 1-7 = fine detail) allows progressive refinement.
Architecture
Input: codec tokens [batch, num_codebooks, seq_len]
(codebook 0 may be provided as conditioning)
|
Flatten and embed to [batch, num_codebooks * seq_len, hidden_dim]
|
+------------------------------------------+
| Conformer Backbone (num_layers) |
| - Self-attention (bidirectional) |
| - Convolution module |
| - FFN with Macaron structure |
+------------------------------------------+
|
Project to [batch, num_codebooks * seq_len, codebook_size]
|
Unflatten to [batch, num_codebooks, seq_len, codebook_size]
|
Apply mask: only predict masked positions
|
Cosine schedule: mask_ratio decreases over T steps
Iterative Refinement (Inference)
- Start with codebook 0 as conditioning (from text or audio prompt)
- Initialize codebooks 1-7 with [MASK] tokens
- For step t = 1..T:
- Forward pass: get logits for all positions
- Compute confidence (max prob) for masked positions
- Unmask top-k% most confident predictions (cosine schedule)
- Return final tokens
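The cosine schedule used in the loop above can be sketched in plain Elixir (a hypothetical helper over scalars; the actual implementation would compute this over Nx tensors):

```elixir
# MaskGIT-style cosine schedule: fraction of the currently masked
# positions that remain masked after refinement step `step` of
# `total_steps`.
mask_ratio = fn step, total_steps ->
  :math.cos(:math.pi() / 2 * step / total_steps)
end

# With total_steps = 16 the ratio falls from 1.0 toward 0:
# step 4 -> ~0.92, step 8 -> ~0.71, step 12 -> ~0.38
```

Early steps unmask only a few tokens (the model is unsure), while later steps unmask many at once.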
Usage
model = SoundStorm.build(
  num_codebooks: 8,
  codebook_size: 1024,
  hidden_dim: 512,
  num_layers: 12
)
{_init_fn, predict_fn} = Axon.build(model)

# One refinement step
new_tokens = SoundStorm.soundstorm_step(predict_fn, params, tokens, mask, 5, 16)

# Full generation
final_tokens = SoundStorm.generate(predict_fn, params, conditioning_tokens, num_steps: 16)
References
- Borsos et al., "SoundStorm: Efficient Parallel Audio Generation" (Google, 2023) — https://arxiv.org/abs/2305.09636
- AudioLM: https://arxiv.org/abs/2209.03143
- EnCodec: https://arxiv.org/abs/2210.13438
Summary
Functions
build/1 - Build a SoundStorm model.
generate/4 - Full SoundStorm generation loop.
output_size/1 - Get output size (codebook_size for per-position predictions).
soundstorm_step/6 - Perform one SoundStorm refinement step.
Types
@type build_opt() :: {:num_codebooks, pos_integer()} | {:codebook_size, pos_integer()} | {:hidden_dim, pos_integer()} | {:num_layers, pos_integer()} | {:num_heads, pos_integer()} | {:conv_kernel_size, pos_integer()} | {:dropout, float()}
Options for build/1.
Functions
Build a SoundStorm model.
Options
:num_codebooks - Number of codec codebooks (default: 8)
:codebook_size - Vocabulary size per codebook (default: 1024)
:hidden_dim - Conformer hidden dimension (default: 512)
:num_layers - Number of Conformer blocks (default: 12)
:num_heads - Number of attention heads (default: 8)
:conv_kernel_size - Depthwise conv kernel size (default: 31)
:dropout - Dropout rate (default: 0.1)
Returns
An Axon model that takes flattened codec tokens [batch, num_codebooks * seq_len]
and outputs logits [batch, num_codebooks * seq_len, codebook_size].
@spec generate((map(), map() -> Nx.Tensor.t()), map(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Full SoundStorm generation loop.
Starting from conditioning tokens (codebook 0), generates codebooks 1-7 through iterative refinement.
Parameters
predict_fn - Compiled prediction function
params - Model parameters
conditioning_tokens - Codebook 0 tokens [batch, seq_len]
Options
:num_steps - Number of refinement steps (default: 16)
:num_codebooks - Number of codebooks (default: 8)
:mask_token - Token ID used for masking (default: 0)
Returns
Generated tokens [batch, num_codebooks, seq_len].
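As an illustration, the generation loop can be sketched over plain Elixir lists instead of Nx tensors. `GenerateSketch`, its stub `predict_fn`, and the list representation are hypothetical stand-ins, not this module's API:

```elixir
defmodule GenerateSketch do
  @mask_token 0

  # Hypothetical list-based stand-in for the tensor implementation:
  # `predict_fn` maps the current token list to one {token, confidence}
  # prediction per position.
  def generate(predict_fn, tokens, total_steps) do
    mask = Enum.map(tokens, &(&1 == @mask_token))

    {final, _mask} =
      Enum.reduce(1..total_steps, {tokens, mask}, fn step, {toks, m} ->
        preds = predict_fn.(toks)

        # Cosine schedule: this fraction of masked positions stays masked.
        ratio = :math.cos(:math.pi() / 2 * step / total_steps)
        num_masked = Enum.count(m, & &1)
        # Unmask everything on the final step so no position is left masked.
        keep = if step == total_steps, do: 0, else: ceil(ratio * num_masked)

        # Masked positions ranked most-confident first; unmask the top ones.
        unmask =
          preds
          |> Enum.with_index()
          |> Enum.filter(fn {_p, i} -> Enum.at(m, i) end)
          |> Enum.sort_by(fn {{_tok, conf}, _i} -> -conf end)
          |> Enum.take(num_masked - keep)
          |> MapSet.new(fn {_p, i} -> i end)

        toks2 =
          Enum.with_index(toks, fn t, i ->
            if MapSet.member?(unmask, i), do: elem(Enum.at(preds, i), 0), else: t
          end)

        m2 = Enum.with_index(m, fn mm, i -> mm and not MapSet.member?(unmask, i) end)
        {toks2, m2}
      end)

    final
  end
end
```

Positions holding conditioning tokens are never masked, so they pass through every step unchanged, exactly as codebook 0 does in the real loop.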
@spec output_size(keyword()) :: pos_integer()
Get output size (codebook_size for per-position predictions).
@spec soundstorm_step((map(), map() -> Nx.Tensor.t()), map(), Nx.Tensor.t(), Nx.Tensor.t(), pos_integer(), pos_integer()) :: Nx.Tensor.t()
Perform one SoundStorm refinement step.
Given current tokens and a mask indicating which positions to predict, runs the model and selectively unmasks the most confident predictions according to the cosine schedule.
Parameters
predict_fn - Compiled prediction function from Axon.build/2
params - Model parameters
tokens - Current token tensor [batch, num_codebooks * seq_len]
mask - Boolean mask [batch, num_codebooks * seq_len], true = predict
step - Current refinement step (1-indexed)
total_steps - Total number of refinement steps
Returns
Updated tokens tensor with some positions unmasked.
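A single refinement step can likewise be illustrated on plain lists; `StepSketch` and its list-of-distributions input are hypothetical stand-ins for the tensor version:

```elixir
defmodule StepSketch do
  # Hypothetical list-based stand-in for one refinement step. `probs`
  # holds, per position, a probability distribution over the codebook
  # (the model's softmaxed logits); `mask` marks positions to predict.
  def step(tokens, probs, mask, step, total_steps) do
    # Greedy prediction and confidence at every masked position.
    preds =
      Enum.zip(probs, mask)
      |> Enum.map(fn
        {dist, true} ->
          {conf, tok} = dist |> Enum.with_index() |> Enum.max_by(fn {p, _i} -> p end)
          {tok, conf}

        {_dist, false} ->
          nil
      end)

    # Cosine schedule: this fraction of masked positions stays masked;
    # the final step unmasks everything.
    ratio = :math.cos(:math.pi() / 2 * step / total_steps)
    num_masked = Enum.count(mask, & &1)
    keep = if step == total_steps, do: 0, else: ceil(ratio * num_masked)

    # Unmask the most confident predictions; the rest stay masked.
    unmask =
      preds
      |> Enum.with_index()
      |> Enum.filter(fn {p, _i} -> p != nil end)
      |> Enum.sort_by(fn {{_tok, conf}, _i} -> -conf end)
      |> Enum.take(num_masked - keep)
      |> MapSet.new(fn {_p, i} -> i end)

    new_tokens =
      Enum.with_index(tokens, fn t, i ->
        if MapSet.member?(unmask, i), do: elem(Enum.at(preds, i), 0), else: t
      end)

    new_mask = Enum.with_index(mask, fn m, i -> m and not MapSet.member?(unmask, i) end)
    {new_tokens, new_mask}
  end
end
```

Unlike the real soundstorm_step/6, this sketch also returns the updated mask, making the keep/unmask bookkeeping explicit.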