Edifice.Audio.SoundStorm (Edifice v0.2.0)


SoundStorm: Efficient Parallel Audio Generation via Masked Prediction.

SoundStorm generates audio non-autoregressively by iteratively refining masked neural codec tokens (e.g. EnCodec). Instead of predicting one token at a time left-to-right, SoundStorm predicts all masked tokens in parallel, keeps the most confident predictions, and re-masks the rest, repeating for T refinement steps. A Conformer backbone operates on the flattened codec token sequence.

Motivation

Autoregressive audio generation (WaveNet, AudioLM) is slow because each token depends on all previous tokens. SoundStorm achieves roughly a 100x speedup by predicting entire codebook layers in parallel. The coarse-to-fine hierarchy of neural codecs (codebook 0 carries coarse structure; codebooks 1-7 add fine detail) enables progressive refinement.

Architecture

Input: codec tokens [batch, num_codebooks, seq_len]
       (codebook 0 may be provided as conditioning)
      |
Embed and flatten to [batch, num_codebooks * seq_len, hidden_dim]
      |
+------------------------------------------+
| Conformer Backbone (num_layers)          |
|   - Self-attention (bidirectional)       |
|   - Convolution module                   |
|   - FFN with Macaron structure           |
+------------------------------------------+
      |
Project to [batch, num_codebooks * seq_len, codebook_size]
      |
Unflatten to [batch, num_codebooks, seq_len, codebook_size]
      |
Apply mask: only predict masked positions
      |
Cosine schedule: mask_ratio decreases over T steps
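
The cosine schedule in the final stage above can be sketched in plain Python. This is an illustrative, language-agnostic sketch (the library's Elixir/Nx implementation may differ in rounding details), following the MaskGIT-style schedule that SoundStorm inherits:

```python
import math

def mask_ratio(step, total_steps):
    """Fraction of positions still masked after `step` of `total_steps`.

    Starts at 1.0 (everything masked) and decays along a cosine curve
    to 0.0 (everything unmasked) at the final step.
    """
    return math.cos(math.pi / 2 * step / total_steps)

# With 16 steps, early steps unmask few tokens (flat top of the cosine)
# and later steps unmask many (steep tail), matching the schedule's intent
# of committing to easy tokens first.
ratios = [mask_ratio(t, 16) for t in range(17)]
```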

Iterative Refinement (Inference)

  1. Start with codebook 0 as conditioning (from text or audio prompt)
  2. Initialize codebooks 1-7 with [MASK] tokens
  3. For step t = 1..T:
    • Forward pass: get logits for all positions
    • Compute confidence (max prob) for masked positions
    • Unmask top-k% most confident predictions (cosine schedule)
  4. Return final tokens
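
The loop above can be sketched end-to-end in plain Python with a toy stand-in for the model. Everything here (the `MASK` sentinel, the `predict` callback returning `(token, confidence)` pairs) is illustrative scaffolding, not the library's API; the point is the schedule-driven unmasking logic:

```python
import math
import random

MASK = -1  # illustrative sentinel; the real model uses a reserved token ID

def refine(tokens, num_steps, predict):
    """Iterative parallel decoding: each step unmasks the most confident
    predictions until no masked positions remain."""
    total = sum(tok == MASK for tok in tokens)
    for step in range(1, num_steps + 1):
        masked = [i for i, tok in enumerate(tokens) if tok == MASK]
        if not masked:
            break
        preds = predict(tokens)  # (token, confidence) for every position
        # How many positions should remain masked after this step
        keep_masked = math.floor(total * math.cos(math.pi / 2 * step / num_steps))
        # Unmask the most confident masked positions, re-mask the rest
        masked.sort(key=lambda i: preds[i][1], reverse=True)
        for i in masked[: len(masked) - keep_masked]:
            tokens[i] = preds[i][0]
    return tokens

# Toy "model": random tokens with random confidences
random.seed(0)
toy_predict = lambda toks: [(random.randrange(1024), random.random()) for _ in toks]
out = refine([MASK] * 32, num_steps=16, predict=toy_predict)
```

At the final step the cosine term reaches zero, so `keep_masked` is 0 and every remaining position is committed.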

Usage

model = SoundStorm.build(
  num_codebooks: 8,
  codebook_size: 1024,
  hidden_dim: 512,
  num_layers: 12
)

{_init_fn, predict_fn} = Axon.build(model)

# One refinement step (step 5 of 16)
new_tokens = SoundStorm.soundstorm_step(predict_fn, params, tokens, mask, 5, 16)

# Full generation
final_tokens = SoundStorm.generate(predict_fn, params, conditioning_tokens, num_steps: 16)


Summary

Types

Options for build/1.

Functions

Build a SoundStorm model.

Full SoundStorm generation loop.

Get output size (codebook_size for per-position predictions).

Perform one SoundStorm refinement step.

Types

build_opt()

@type build_opt() ::
  {:num_codebooks, pos_integer()}
  | {:codebook_size, pos_integer()}
  | {:hidden_dim, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:conv_kernel_size, pos_integer()}
  | {:dropout, float()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a SoundStorm model.

Options

  • :num_codebooks - Number of codec codebooks (default: 8)
  • :codebook_size - Vocabulary size per codebook (default: 1024)
  • :hidden_dim - Conformer hidden dimension (default: 512)
  • :num_layers - Number of Conformer blocks (default: 12)
  • :num_heads - Number of attention heads (default: 8)
  • :conv_kernel_size - Depthwise conv kernel size (default: 31)
  • :dropout - Dropout rate (default: 0.1)

Returns

An Axon model that takes flattened codec tokens [batch, num_codebooks * seq_len] and outputs logits [batch, num_codebooks * seq_len, codebook_size].
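
The flattened layout implies a simple index mapping between the `[batch, num_codebooks, seq_len]` view and the `[batch, num_codebooks * seq_len]` view. Assuming row-major (codebook-major) flattening, which the shapes above suggest but the docs do not state explicitly, the mapping can be sketched as:

```python
num_codebooks, seq_len = 8, 4

def flat_index(k, s, seq_len):
    """Row-major position of (codebook k, frame s) in the flattened sequence."""
    return k * seq_len + s

def unflatten(flat, num_codebooks, seq_len):
    """Reshape a flat token list back to [num_codebooks][seq_len]."""
    return [flat[k * seq_len:(k + 1) * seq_len] for k in range(num_codebooks)]

flat = list(range(num_codebooks * seq_len))
grid = unflatten(flat, num_codebooks, seq_len)
# (codebook 2, frame 3) sits at flat position 2 * 4 + 3 = 11
```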

generate(predict_fn, params, conditioning_tokens, opts \\ [])

@spec generate(
  (map(), map() -> Nx.Tensor.t()),
  map(),
  Nx.Tensor.t(),
  keyword()
) :: Nx.Tensor.t()

Full SoundStorm generation loop.

Starting from conditioning tokens (codebook 0), generates the remaining codebooks (1 through num_codebooks - 1, i.e. 1-7 with the defaults) through iterative refinement.

Parameters

  • predict_fn - Compiled prediction function
  • params - Model parameters
  • conditioning_tokens - Codebook 0 tokens [batch, seq_len]

Options

  • :num_steps - Number of refinement steps (default: 16)
  • :num_codebooks - Number of codebooks (default: 8)
  • :mask_token - Token ID used for masking (default: 0)

Returns

Generated tokens [batch, num_codebooks, seq_len].

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get output size (codebook_size for per-position predictions).

soundstorm_step(predict_fn, params, tokens, mask, step, total_steps)

@spec soundstorm_step(
  (map(), map() -> Nx.Tensor.t()),
  map(),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  pos_integer(),
  pos_integer()
) :: Nx.Tensor.t()

Perform one SoundStorm refinement step.

Given current tokens and a mask indicating which positions to predict, runs the model and selectively unmasks the most confident predictions according to the cosine schedule.

Parameters

  • predict_fn - Compiled prediction function from Axon.build/2
  • params - Model parameters
  • tokens - Current token tensor [batch, num_codebooks * seq_len]
  • mask - Boolean mask [batch, num_codebooks * seq_len], true = predict
  • step - Current refinement step (1-indexed)
  • total_steps - Total number of refinement steps

Returns

Updated tokens tensor with some positions unmasked.
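
The per-position confidence described above (the maximum probability at each masked position) can be sketched with a plain softmax. This is a conceptual sketch, not the library's implementation; `step_predictions` and its `(token, confidence)` output shape are illustrative names:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def step_predictions(logits_per_position, mask):
    """For each masked position, pick the argmax token and record its
    probability as the confidence score used for unmasking decisions.
    Unmasked positions are left alone (returned as None)."""
    out = []
    for logits, is_masked in zip(logits_per_position, mask):
        if not is_masked:
            out.append(None)
            continue
        probs = softmax(logits)
        conf = max(probs)
        token = probs.index(conf)
        out.append((token, conf))
    return out

# Position 0 is masked (predict it); position 1 is already fixed.
preds = step_predictions([[2.0, 0.5, 0.1], [0.3, 0.2, 0.1]], [True, False])
```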