Edifice.Audio.EnCodec (Edifice v0.2.0)


EnCodec: High-Fidelity Neural Audio Compression.

EnCodec is Meta's neural audio codec that encodes raw waveforms into discrete tokens suitable for audio language models. It uses a convolutional encoder-decoder architecture with a Residual Vector Quantizer (RVQ) to produce multiple streams of discrete tokens at different levels of detail.

Architecture

Waveform [batch, 1, samples]
      |
+-------------------+
| Encoder           |
| 1D Conv           |
| Downsample blocks |
| (stride 2, 4, 5,  |
|  8 = 320x total)  |
+-------------------+
      |
Continuous embeddings [batch, T, dim]
      |
+-------------------+
| Residual VQ       |
| Q codebooks       |
| Each quantizes    |
| the residual      |
+-------------------+
      |
Discrete tokens [batch, Q, T]
      |
(Dequantize: sum codebook vectors)
      |
+-------------------+
| Decoder           |
| Upsample blocks   |
| (stride 8, 5, 4,  |
|  2 = 320x total)  |
| 1D Conv           |
+-------------------+
      |
Reconstructed waveform [batch, 1, samples]

Residual Vector Quantization (RVQ)

RVQ uses Q codebooks in sequence. The first codebook quantizes the encoder output; each subsequent codebook quantizes the residual (error) from the previous quantization. This provides coarse-to-fine representation:

  • Codebook 0: captures overall structure (coarse)
  • Codebooks 1-7: capture progressively finer details
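
The residual quantization loop can be illustrated with a minimal numpy sketch (batch dimension omitted for clarity; this is an illustration of the algorithm, not this module's Nx implementation):

```python
import numpy as np

def rvq_quantize(z_e, codebooks):
    """Residual VQ: each codebook quantizes what the previous ones missed.

    z_e:       [T, dim] encoder output (batch omitted)
    codebooks: [Q, codebook_size, dim]
    Returns (z_q, indices) with z_q [T, dim] and indices [Q, T].
    """
    residual = z_e.copy()
    z_q = np.zeros_like(z_e)
    indices = []
    for codebook in codebooks:
        # Nearest codebook entry to the current residual (L2 distance).
        dists = ((residual[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)
        idx = dists.argmin(axis=1)   # [T]
        quantized = codebook[idx]    # [T, dim]
        z_q += quantized             # accumulate coarse-to-fine
        residual -= quantized        # the next codebook sees only the error
        indices.append(idx)
    return z_q, np.stack(indices)
```

By construction, `z_q` is exactly the sum of the selected entries from each codebook, which is why dequantization is just a sum of codebook lookups.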

Bandwidth Control

Different bandwidths correspond to different numbers of active codebooks:

  • 1.5 kbps: 2 codebooks (very compressed)
  • 3.0 kbps: 4 codebooks
  • 6.0 kbps: 8 codebooks (high quality)
  • 12.0 kbps: 16 codebooks (studio quality)
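
These bitrates follow directly from the frame rate and the bits per codebook: at a 24 kHz sample rate with 320x downsampling the encoder emits 75 frames per second, and a 1024-entry codebook costs log2(1024) = 10 bits per frame, i.e. 0.75 kbps per codebook. A quick arithmetic check:

```python
import math

sample_rate = 24_000  # Hz (EnCodec default)
hop = 320             # total encoder downsampling factor
codebook_size = 1024  # 10 bits per token

frame_rate = sample_rate / hop                             # 75 frames/s
bits_per_codebook = math.log2(codebook_size)               # 10 bits
kbps_per_codebook = frame_rate * bits_per_codebook / 1000  # 0.75 kbps

for n_q in (2, 4, 8, 16):
    print(n_q, "codebooks ->", n_q * kbps_per_codebook, "kbps")
```

This reproduces the table above: 2 codebooks give 1.5 kbps, 4 give 3.0, 8 give 6.0, and 16 give 12.0.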

Usage

# Build full EnCodec (returns {encoder, decoder} Axon models)
{encoder, decoder} = EnCodec.build(
  num_codebooks: 8,
  codebook_size: 1024,
  hidden_dim: 128
)

# Encode waveform to tokens
tokens = EnCodec.encode(encoder_fn, params, codebooks, waveform)

# Decode tokens back to waveform
waveform = EnCodec.decode(decoder_fn, params, codebooks, tokens)

Summary

Types

Options for build/1 and related functions.

Functions

Build a complete EnCodec model (encoder + RVQ + decoder).

Build the EnCodec decoder.

Build the EnCodec encoder.

Build the Residual Vector Quantizer configuration.

Compute commitment loss for RVQ training.

Decode discrete tokens back to waveform.

Encode waveform to discrete tokens.

Initialize RVQ codebooks.

Combined EnCodec loss.

Get the output embedding dimension.

Dequantize RVQ tokens back to continuous embeddings.

Quantize encoder output using Residual Vector Quantization.

Compute spectral reconstruction loss (multi-scale STFT).

Types

build_opt()

@type build_opt() ::
  {:num_codebooks, pos_integer()}
  | {:codebook_size, pos_integer()}
  | {:hidden_dim, pos_integer()}
  | {:sample_rate, pos_integer()}
  | {:bandwidth, float()}

Options for build/1 and related functions.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: {Axon.t(), Axon.t()}

Build a complete EnCodec model (encoder + RVQ + decoder).

Options

  • :num_codebooks - Number of RVQ codebooks (default: 8)
  • :codebook_size - Vocabulary size per codebook (default: 1024)
  • :hidden_dim - Base hidden dimension (default: 128)
  • :sample_rate - Audio sample rate in Hz (default: 24000)
  • :bandwidth - Target bandwidth in kbps (default: 6.0)

Returns

A tuple {encoder, decoder} of Axon models.

  • Encoder: [batch, 1, samples] -> [batch, T, dim]
  • Decoder: [batch, T, dim] -> [batch, 1, samples]

Note: RVQ is not part of the Axon graph; it is implemented as standalone quantization functions (rvq_quantize/2, rvq_dequantize/2) operating on a separately initialized codebooks tensor.

build_decoder(opts \\ [])

@spec build_decoder([build_opt()]) :: Axon.t()

Build the EnCodec decoder.

The decoder mirrors the encoder with transposed convolutions for upsampling.

Options

  • :hidden_dim - Base hidden dimension (default: 128)

Returns

An Axon model: [batch, T, dim] -> [batch, 1, samples]

build_encoder(opts \\ [])

@spec build_encoder([build_opt()]) :: Axon.t()

Build the EnCodec encoder.

The encoder uses a stack of residual blocks with strided convolutions for downsampling. Each block doubles the channel count while reducing temporal resolution.

Options

  • :hidden_dim - Base hidden dimension, scaled up through layers (default: 128)

Returns

An Axon model: [batch, 1, samples] -> [batch, T, final_dim] where T = samples / 320 and final_dim = hidden_dim * 16.
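
A quick sanity check of the documented output shape, assuming one second of 24 kHz audio and the default hidden_dim of 128 (plain arithmetic, not a call into this module):

```python
sample_rate = 24_000
samples = sample_rate * 1  # 1 second of audio
hidden_dim = 128

t = samples // 320          # strides 2 * 4 * 5 * 8 = 320x downsampling
final_dim = hidden_dim * 16 # channels double at each of the 4 stages: 128 * 2**4

print(t, final_dim)  # 75 frames, 2048 channels
```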

build_rvq(opts \\ [])

@spec build_rvq([build_opt()]) :: map()

Build the Residual Vector Quantizer configuration.

RVQ is not an Axon model but a set of codebook parameters. This function returns the configuration; actual codebooks are initialized separately.

Options

  • :num_codebooks - Number of quantization levels (default: 8)
  • :codebook_size - Entries per codebook (default: 1024)
  • :hidden_dim - Embedding dimension (default: 128, scaled to final_dim)

Returns

A map with RVQ configuration.

commitment_loss(z_e, z_q)

@spec commitment_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Compute commitment loss for RVQ training.

Encourages encoder outputs to stay close to codebook entries.

Parameters

  • z_e - Encoder output [batch, T, dim]
  • z_q - Quantized output [batch, T, dim]

Returns

Commitment loss scalar.
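
Numerically this is just the mean squared distance between the two tensors; a minimal numpy sketch (during actual training the gradient through z_q is stopped, VQ-VAE style, so only the encoder is pulled toward the codebook):

```python
import numpy as np

def commitment_loss(z_e, z_q):
    # Mean squared distance between encoder outputs and their
    # quantized versions; z_q is treated as a constant in training.
    return np.mean((z_e - z_q) ** 2)
```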

decode(decoder_fn, params, codebooks, tokens)

@spec decode(function(), map(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Decode discrete tokens back to waveform.

Parameters

  • decoder_fn - Compiled decoder prediction function
  • params - Decoder parameters
  • codebooks - RVQ codebooks
  • tokens - Token indices [batch, num_codebooks, T]

Returns

Reconstructed waveform [batch, 1, samples].

encode(encoder_fn, params, codebooks, waveform)

@spec encode(function(), map(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Encode waveform to discrete tokens.

Parameters

  • encoder_fn - Compiled encoder prediction function
  • params - Encoder parameters
  • codebooks - RVQ codebooks
  • waveform - Input waveform [batch, 1, samples]

Returns

Token indices [batch, num_codebooks, T].

init_codebooks(num_codebooks, codebook_size, embedding_dim, key \\ Nx.Random.key(42))

@spec init_codebooks(pos_integer(), pos_integer(), pos_integer(), Nx.Tensor.t()) ::
  Nx.Tensor.t()

Initialize RVQ codebooks.

Parameters

  • num_codebooks - Number of codebooks
  • codebook_size - Entries per codebook
  • embedding_dim - Dimension of each embedding
  • key - PRNG key (default: Nx.Random.key(42))

Returns

Codebooks tensor [num_codebooks, codebook_size, embedding_dim].

loss(reconstruction, target, z_e, z_q, opts \\ [])

Combined EnCodec loss.

Parameters

  • reconstruction - Reconstructed waveform
  • target - Original waveform
  • z_e - Encoder output
  • z_q - Quantized output
  • opts - Options; :commitment_weight sets the weight of the commitment loss term (default: 0.25)

Returns

Combined loss scalar.

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output embedding dimension.

rvq_dequantize(indices, codebooks)

@spec rvq_dequantize(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Dequantize RVQ tokens back to continuous embeddings.

Parameters

  • indices - Token indices [batch, num_codebooks, T]
  • codebooks - RVQ codebooks [num_codebooks, codebook_size, dim]

Returns

Continuous embeddings [batch, T, dim].
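
Because RVQ builds z_q as a running sum, dequantization is simply a lookup into each codebook followed by a sum over codebooks. A numpy sketch with the batch dimension omitted (illustrative only):

```python
import numpy as np

def rvq_dequantize(indices, codebooks):
    """Sum each codebook's selected vectors.

    indices:   [Q, T] token indices (batch omitted)
    codebooks: [Q, codebook_size, dim]
    Returns continuous embeddings [T, dim].
    """
    return sum(codebooks[q][indices[q]] for q in range(codebooks.shape[0]))
```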

rvq_quantize(z_e, codebooks)

@spec rvq_quantize(Nx.Tensor.t(), Nx.Tensor.t()) :: {Nx.Tensor.t(), Nx.Tensor.t()}

Quantize encoder output using Residual Vector Quantization.

Each codebook quantizes the residual from the previous quantization, building up a coarse-to-fine representation.

Parameters

  • z_e - Encoder output [batch, T, dim]
  • codebooks - RVQ codebooks [num_codebooks, codebook_size, dim]

Returns

{z_q, indices} where:

  • z_q - Quantized vectors [batch, T, dim] (sum of all codebook contributions)
  • indices - Token indices [batch, num_codebooks, T]

spectral_loss(reconstruction, target)

@spec spectral_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Compute spectral reconstruction loss (multi-scale STFT).

Note: this implementation is simplified, using a single-scale MSE in the time domain. The full EnCodec objective combines a multi-scale STFT loss with an adversarial loss.

Parameters

  • reconstruction - Reconstructed waveform [batch, 1, samples]
  • target - Original waveform [batch, 1, samples]

Returns

Reconstruction loss scalar.