# `Edifice.Audio.SoundStorm`

SoundStorm: Efficient Parallel Audio Generation via masked prediction.

SoundStorm generates audio *non-autoregressively* by iteratively refining
masked neural codec tokens (e.g. from EnCodec). Instead of predicting one token
at a time left-to-right, SoundStorm predicts all masked tokens in parallel,
keeps the most confident predictions, and re-masks the rest, repeating for T
refinement steps.
A Conformer backbone operates on the flattened codec token sequence.

## Motivation

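Autoregressive audio generation (WaveNet, AudioLM) is slow because each
output depends on all previous outputs, so producing N tokens takes N
sequential forward passes. SoundStorm instead predicts many tokens in
parallel over a fixed number of refinement passes, independent of sequence
length; the paper reports generation two orders of magnitude faster than
AudioLM's acoustic generator. The coarse-to-fine hierarchy of neural codecs
(codebook 0 = coarse structure, codebooks 1-7 = fine detail) allows
progressive refinement.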

## Architecture

```
Input: codec tokens [batch, num_codebooks, seq_len]
       (codebook 0 may be provided as conditioning)
      |
Flatten to [batch, num_codebooks * seq_len], embed to [..., hidden_dim]
      |
+------------------------------------------+
| Conformer Backbone (num_layers)          |
|   - Self-attention (bidirectional)       |
|   - Convolution module                   |
|   - FFN with Macaron structure           |
+------------------------------------------+
      |
Project to [batch, num_codebooks * seq_len, codebook_size]
      |
Unflatten to [batch, num_codebooks, seq_len, codebook_size]
      |
Apply mask: only predict masked positions
      |
Cosine schedule: mask_ratio decreases over T steps
```
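The flatten and unflatten steps in the diagram are plain reshapes. A minimal sketch on nested lists for a single example, assuming a row-major (codebook-major) layout like the one `Nx.reshape/2` produces: all of codebook 0 comes first, then codebook 1, and so on. The module name `FlattenSketch` and its functions are hypothetical illustrations, not part of `Edifice`.

```elixir
defmodule FlattenSketch do
  # Hypothetical helpers (not part of Edifice) for a single example.
  # Assumes a row-major, codebook-major layout, as Nx.reshape/2 would produce
  # when reshaping {num_codebooks, seq_len} to {num_codebooks * seq_len}.

  # Flat position of token t in codebook c.
  def flat_index(c, t, seq_len), do: c * seq_len + t

  # [[codebook 0 tokens], [codebook 1 tokens], ...] -> one flat list.
  def flatten(codebooks), do: Enum.concat(codebooks)

  # Inverse of flatten/1: split the flat list back into rows of seq_len tokens.
  def unflatten(flat, seq_len), do: Enum.chunk_every(flat, seq_len)
end

tokens = [[1, 2, 3], [4, 5, 6]]
flat = FlattenSketch.flatten(tokens)
IO.inspect(flat)
# prints [1, 2, 3, 4, 5, 6]
IO.inspect(Enum.at(flat, FlattenSketch.flat_index(1, 0, 3)))
# prints 4
```

Under this layout, token `t` of codebook `c` sits at flat index `c * seq_len + t`, so the first `seq_len` flat positions hold the codebook 0 conditioning.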

## Iterative Refinement (Inference)

1. Start with codebook 0 as conditioning (from text or audio prompt)
2. Initialize codebooks 1-7 with [MASK] tokens
3. For step t = 1..T:
   - Forward pass: get logits for all positions
   - Compute confidence (max prob) for masked positions
   - Unmask the most confident predictions, leaving a masked fraction that
     shrinks with t according to the cosine schedule
4. Return final tokens
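The cosine schedule in step 3 can be sketched as follows, assuming the MaskGIT-style form in which a fraction `cos(π/2 · t/T)` of the initially masked positions remains masked after step `t`. The module name, the use of `floor/1` for rounding, and the exact formula are assumptions; the module's implementation may differ in these details.

```elixir
defmodule CosineScheduleSketch do
  # Hypothetical helper: MaskGIT-style cosine schedule (an assumption about
  # the exact form used by Edifice.Audio.SoundStorm).

  # Fraction of the initially masked positions still masked after step t
  # (t = 0 is before any refinement, t = total_steps is the end).
  def mask_ratio(t, total_steps) when t >= 0 and t <= total_steps do
    :math.cos(:math.pi() / 2 * t / total_steps)
  end

  # Positions still masked after step t, rounding down so that the final
  # step leaves nothing masked.
  def remaining_masked(num_masked, t, total_steps) do
    floor(num_masked * mask_ratio(t, total_steps))
  end
end

IO.inspect(Enum.map(0..4, &CosineScheduleSketch.remaining_masked(100, &1, 4)))
# prints [100, 92, 70, 38, 0]
```

With 100 masked positions and T = 4, the per-step increments are 8, 22, 32 and 38: the schedule commits few tokens early, when the model has the least context, and more toward the end.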

## Usage

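    model = SoundStorm.build(
      num_codebooks: 8,
      codebook_size: 1024,
      hidden_dim: 512,
      num_layers: 12
    )

    # soundstorm_step/6 and generate/4 take the compiled predict_fn, not the model
    {_init_fn, predict_fn} = Axon.build(model)

    # One refinement step (step and total_steps are positional arguments)
    new_tokens = SoundStorm.soundstorm_step(predict_fn, params, tokens, mask, 5, 16)

    # Full generation
    final_tokens = SoundStorm.generate(predict_fn, params, conditioning_tokens, num_steps: 16)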

## References

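- Borsos et al., "SoundStorm: Efficient Parallel Audio Generation"
  (Google, 2023). https://arxiv.org/abs/2305.09636
- Borsos et al., "AudioLM: a Language Modeling Approach to Audio Generation"
  (Google, 2022). https://arxiv.org/abs/2209.03143
- Défossez et al., "High Fidelity Neural Audio Compression" (EnCodec)
  (Meta, 2022). https://arxiv.org/abs/2210.13438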

# `build_opt`

```elixir
@type build_opt() ::
  {:num_codebooks, pos_integer()}
  | {:codebook_size, pos_integer()}
  | {:hidden_dim, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:conv_kernel_size, pos_integer()}
  | {:dropout, float()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a SoundStorm model.

## Options

  - `:num_codebooks` - Number of codec codebooks (default: 8)
  - `:codebook_size` - Vocabulary size per codebook (default: 1024)
  - `:hidden_dim` - Conformer hidden dimension (default: 512)
  - `:num_layers` - Number of Conformer blocks (default: 12)
  - `:num_heads` - Number of attention heads (default: 8)
  - `:conv_kernel_size` - Depthwise conv kernel size (default: 31)
  - `:dropout` - Dropout rate (default: 0.1)

## Returns

  An Axon model that takes flattened codec tokens `[batch, num_codebooks * seq_len]`
  and outputs logits `[batch, num_codebooks * seq_len, codebook_size]`.
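The logits give one distribution over `codebook_size` entries per position. A minimal sketch of turning one position's logits into a predicted token and a confidence score (the maximum softmax probability used in the refinement loop), on plain lists. Greedy argmax is an assumption made here for determinism; sampling-based decoders instead use the probability of the sampled token. `ConfidenceSketch` is hypothetical.

```elixir
defmodule ConfidenceSketch do
  # Hypothetical helpers for ONE position's logits (a list of codebook_size floats).

  # Numerically stable softmax: subtract the max logit before exponentiating.
  def softmax(logits) do
    max_logit = Enum.max(logits)
    exps = Enum.map(logits, fn x -> :math.exp(x - max_logit) end)
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end

  # Greedy prediction: {token id with the highest probability, that probability}.
  def predict(logits) do
    {prob, id} =
      logits
      |> softmax()
      |> Enum.with_index()
      |> Enum.max_by(fn {p, _id} -> p end)

    {id, prob}
  end
end

{token, confidence} = ConfidenceSketch.predict([1.0, 2.0, 3.0])
IO.inspect({token, Float.round(confidence, 3)})
# prints {2, 0.665}
```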

# `generate`

```elixir
@spec generate(
  (map(), map() -> Nx.Tensor.t()),
  map(),
  Nx.Tensor.t(),
  keyword()
) :: Nx.Tensor.t()
```

Full SoundStorm generation loop.

Starting from conditioning tokens (codebook 0), generates the remaining
codebooks (1 through `num_codebooks - 1`, i.e. 1-7 with the default of 8)
through iterative refinement.

## Parameters

  - `predict_fn` - Compiled prediction function from `Axon.build/2`
  - `params` - Model parameters
  - `conditioning_tokens` - Codebook 0 tokens `[batch, seq_len]`

## Options

  - `:num_steps` - Number of refinement steps (default: 16)
  - `:num_codebooks` - Number of codebooks (default: 8)
  - `:mask_token` - Token ID used for masking (default: 0)

## Returns

  Generated tokens `[batch, num_codebooks, seq_len]`.
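A minimal sketch of how the starting point for `generate/4` could be laid out for a single example: codebook 0 holds the conditioning tokens, every other codebook starts as `:mask_token`, and the mask is `true` exactly where tokens must be predicted. The codebook-major flattening and the `GenerateInitSketch` name are assumptions; only the defaults (8 codebooks, mask token 0) come from the options listed above.

```elixir
defmodule GenerateInitSketch do
  # Hypothetical setup for ONE example (not the Edifice implementation).
  # Codebook 0 holds the conditioning tokens; codebooks 1..num_codebooks-1
  # start as mask_token. The flat layout is assumed to be codebook-major.
  # Returns {flat_tokens, mask}, where mask is true where tokens must be predicted.
  def init(conditioning, num_codebooks \\ 8, mask_token \\ 0) do
    seq_len = length(conditioning)
    num_masked = (num_codebooks - 1) * seq_len

    tokens = conditioning ++ List.duplicate(mask_token, num_masked)
    mask = List.duplicate(false, seq_len) ++ List.duplicate(true, num_masked)

    {tokens, mask}
  end
end

IO.inspect(GenerateInitSketch.init([7, 8], 3))
# prints {[7, 8, 0, 0, 0, 0], [false, false, true, true, true, true]}
```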

# `output_size`

```elixir
@spec output_size(keyword()) :: pos_integer()
```

Returns the per-position output size, i.e. `codebook_size` (the number of
logits predicted at each token position).

# `soundstorm_step`

```elixir
@spec soundstorm_step(
  (map(), map() -> Nx.Tensor.t()),
  map(),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  pos_integer(),
  pos_integer()
) :: Nx.Tensor.t()
```

Perform one SoundStorm refinement step.

Given current tokens and a mask indicating which positions to predict,
runs the model and selectively unmasks the most confident predictions
according to the cosine schedule.

## Parameters

  - `predict_fn` - Compiled prediction function from `Axon.build/2`
  - `params` - Model parameters
  - `tokens` - Current token tensor `[batch, num_codebooks * seq_len]`
  - `mask` - Boolean mask `[batch, num_codebooks * seq_len]`, true = predict
  - `step` - Current refinement step (1-indexed)
  - `total_steps` - Total number of refinement steps

## Returns

  Updated token tensor in which the most confident masked positions have been
  replaced by their predicted tokens.
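The selective unmasking performed by a refinement step can be sketched on plain lists for a single example. Assumptions, all hypothetical: predictions arrive precomputed as `{token_id, confidence}` pairs instead of coming from `predict_fn`, and the caller passes `num_initial` (how many positions were masked before step 1), because the cosine schedule is defined relative to that count. The documented `soundstorm_step/6` takes neither argument and may derive these quantities differently; `StepSketch.refine/6` is illustrative only.

```elixir
defmodule StepSketch do
  # Hypothetical single-example refinement step (not the Edifice implementation).
  #   tokens      - current flat tokens
  #   mask        - true where a position is still masked
  #   predictions - {token_id, confidence} per position, precomputed (assumption)
  #   num_initial - positions masked before step 1 (assumption)
  #   step        - current step, 1-indexed; total_steps - T
  # Returns {new_tokens, new_mask}.
  def refine(tokens, mask, predictions, num_initial, step, total_steps) do
    # Cosine schedule: positions that should remain masked after this step.
    keep_masked = floor(num_initial * :math.cos(:math.pi() / 2 * step / total_steps))

    # {index, confidence} for every currently masked position.
    masked =
      for {{true, {_tok, conf}}, i} <- Enum.with_index(Enum.zip(mask, predictions)),
          do: {i, conf}

    # Unmask the most confident masked positions until keep_masked remain.
    num_to_unmask = max(length(masked) - keep_masked, 0)

    unmask =
      masked
      |> Enum.sort_by(fn {_i, conf} -> conf end, :desc)
      |> Enum.take(num_to_unmask)
      |> MapSet.new(fn {i, _conf} -> i end)

    # Newly unmasked positions take the prediction; all others are unchanged.
    new_tokens =
      for {{tok, {pred, _conf}}, i} <- Enum.with_index(Enum.zip(tokens, predictions)) do
        if MapSet.member?(unmask, i), do: pred, else: tok
      end

    new_mask = for {m, i} <- Enum.with_index(mask), do: m and not MapSet.member?(unmask, i)

    {new_tokens, new_mask}
  end
end

preds = [{7, 1.0}, {3, 0.9}, {5, 0.2}, {6, 0.6}]
{t1, m1} = StepSketch.refine([7, 0, 0, 0], [false, true, true, true], preds, 3, 1, 2)
IO.inspect({t1, m1})
# prints {[7, 3, 0, 0], [false, false, true, true]}
{t2, m2} = StepSketch.refine(t1, m1, preds, 3, 2, 2)
IO.inspect({t2, m2})
# prints {[7, 3, 5, 6], [false, false, false, false]}
```

With three masked positions and a 2-step schedule, step 1 keeps floor(3 · cos(π/4)) = 2 positions masked and fills in only the most confident one; step 2 fills in the rest. Already-committed tokens are never revisited in this sketch.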

