Vector-Quantized Variational Autoencoder (VQ-VAE).
Instead of learning a continuous latent space like a standard VAE, VQ-VAE learns a discrete codebook of embedding vectors. The encoder output is quantized to the nearest codebook entry, producing a discrete latent representation.
This avoids posterior collapse (a common VAE failure mode) and produces sharper reconstructions since the decoder receives high-fidelity codebook vectors rather than noisy samples.
Architecture
Input [batch, input_size]
|
v
+-------------------+
| Encoder |
| Dense layers |
+-------------------+
|
v
z_e [batch, embedding_dim] (continuous encoder output)
|
v
+-------------------+
| Quantize |
| nearest codebook |
| vector lookup |
+-------------------+
|
v
z_q [batch, embedding_dim] (discrete quantized vector)
|
v
+-------------------+
| Decoder |
| Dense layers |
+-------------------+
|
v
Reconstruction [batch, input_size]

Training Losses
VQ-VAE training uses three loss components:
- Reconstruction loss: MSE between input and reconstruction
- Codebook loss: ||sg(z_e) - z_q||^2 (sg denotes stop_gradient); moves codebook vectors toward encoder outputs
- Commitment loss: ||z_e - sg(z_q)||^2; keeps encoder outputs from drifting too far from the codebook
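As a concrete reference, the three terms can be sketched in Nx as follows. This is a minimal illustration, not the library's implementation; the module name and the beta argument are assumptions (beta is the commitment weight, 0.25 in the original paper).

defmodule VQLossSketch do
  import Nx.Defn

  # reconstruction/target: [batch, input_size]; z_e/z_q: [batch, embedding_dim].
  defn total_loss(reconstruction, target, z_e, z_q, beta) do
    recon_diff = reconstruction - target
    reconstruction_loss = Nx.mean(recon_diff * recon_diff)

    # stop_grad detaches z_e, so this term only moves the codebook side.
    cb_diff = stop_grad(z_e) - z_q
    codebook_loss = Nx.mean(cb_diff * cb_diff)

    # stop_grad detaches z_q, so this term only updates the encoder.
    commit_diff = z_e - stop_grad(z_q)
    commitment_loss = Nx.mean(commit_diff * commit_diff)

    reconstruction_loss + codebook_loss + beta * commitment_loss
  end
end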
The straight-through estimator passes gradients from decoder directly to encoder, bypassing the non-differentiable quantization step.
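In Nx terms the trick is the identity z_e + stop_grad(z_q - z_e): the forward value equals z_q, while the gradient with respect to z_e is the identity. A minimal sketch (illustrative module name):

defmodule StraightThroughSketch do
  import Nx.Defn

  defn straight_through(z_e, z_q) do
    # Forward: z_e + (z_q - z_e) == z_q.
    # Backward: stop_grad blocks the (z_q - z_e) branch,
    # so gradients flow to z_e as if quantization were the identity.
    z_e + stop_grad(z_q - z_e)
  end
end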
Usage
# Build full VQ-VAE
{encoder, decoder} = VQVAE.build(input_size: 784, embedding_dim: 64, num_embeddings: 512)
# Initialize the codebook
codebook = VQVAE.init_codebook(512, 64)
# Quantize encoder output z_e against the codebook
{z_q, indices} = VQVAE.quantize(z_e, codebook)
# Training losses
commit = VQVAE.commitment_loss(z_e, z_q)
cb = VQVAE.codebook_loss(z_e, z_q)
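Putting the pieces together, one forward pass might look like the sketch below. The Axon.build/1 compilation step and the :commitment_weight option name are assumptions based on the spec, not confirmed API details.

# Build models and codebook
{encoder, decoder} = VQVAE.build(input_size: 784, embedding_dim: 64, num_embeddings: 512)
codebook = VQVAE.init_codebook(512, 64)

# Compile the Axon models into init/predict functions
{enc_init, enc_fwd} = Axon.build(encoder)
{dec_init, dec_fwd} = Axon.build(decoder)
enc_params = enc_init.(Nx.template({1, 784}, :f32), %{})
dec_params = dec_init.(Nx.template({1, 64}, :f32), %{})

# One forward pass on a dummy batch
batch = Nx.broadcast(0.5, {32, 784})
z_e = enc_fwd.(enc_params, batch)
{z_q, _indices} = VQVAE.quantize(z_e, codebook)
reconstruction = dec_fwd.(dec_params, z_q)
total = VQVAE.loss(reconstruction, batch, z_e, z_q, commitment_weight: 0.25)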
Summary
Functions
Build a complete VQ-VAE (encoder + decoder).
Build the decoder network.
Build the encoder network.
Codebook loss: moves codebook vectors toward encoder outputs.
Commitment loss: encourages encoder outputs to stay close to codebook vectors.
Initialize a codebook with random normal vectors.
Combined VQ-VAE loss: reconstruction + codebook + commitment.
Quantize continuous encoder outputs to nearest codebook vectors.
Types
@type build_opt() :: {:input_size, pos_integer()}
Options for build/1.
Functions
Build a complete VQ-VAE (encoder + decoder).
The encoder maps inputs to continuous vectors of size embedding_dim.
The decoder reconstructs from quantized codebook vectors of the same size.
Quantization is performed externally via quantize/2 during training.
Options
- :input_size - Input feature dimension (required)
- :embedding_dim - Codebook vector dimension (default: 64)
- :num_embeddings - Number of codebook entries (default: 512)
- :encoder_sizes - Hidden layer sizes for encoder (default: [256, 128])
- :decoder_sizes - Hidden layer sizes for decoder (default: [128, 256])
- :activation - Activation function (default: :relu)
Returns
{encoder, decoder} - Tuple of Axon models.
- Encoder: [batch, input_size] -> [batch, embedding_dim]
- Decoder: [batch, embedding_dim] -> [batch, input_size]
Build the decoder network.
Maps quantized codebook vectors back to the input space.
Options
- :input_size or :output_size - Reconstruction dimension (required)
- :embedding_dim - Input dimension from codebook (default: 64)
- :decoder_sizes - Hidden layer sizes (default: [128, 256])
- :activation - Activation function (default: :relu)
Returns
An Axon model: [batch, embedding_dim] -> [batch, output_size].
Build the encoder network.
Maps input to a continuous vector in the embedding space. This output is then quantized to the nearest codebook vector during training.
Options
- :input_size - Input feature dimension (required)
- :embedding_dim - Output dimension, must match codebook (default: 64)
- :encoder_sizes - Hidden layer sizes (default: [256, 128])
- :activation - Activation function (default: :relu)
Returns
An Axon model: [batch, input_size] -> [batch, embedding_dim].
@spec codebook_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Codebook loss: moves codebook vectors toward encoder outputs.
Computed as mean(||sg(z_e) - z_q||^2) where sg is stop_gradient.
The original VQ-VAE paper notes that this term can alternatively be replaced by an exponential moving average (EMA) update of the codebook embeddings.
Parameters
- z_e - Encoder output [batch, embedding_dim] (will be detached)
- z_q - Quantized vectors [batch, embedding_dim]
Returns
Codebook loss scalar.
@spec commitment_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Commitment loss: encourages encoder outputs to stay close to codebook vectors.
Computed as mean(||z_e - sg(z_q)||^2) where sg is stop_gradient.
This prevents the encoder output from growing unboundedly away from
the codebook entries.
Parameters
- z_e - Encoder output [batch, embedding_dim]
- z_q - Quantized vectors [batch, embedding_dim] (will be detached)
Returns
Commitment loss scalar.
@spec init_codebook(pos_integer(), pos_integer(), Nx.Tensor.t()) :: Nx.Tensor.t()
Initialize a codebook with random normal vectors.
Parameters
- num_embeddings - Number of codebook entries
- embedding_dim - Dimension of each codebook vector
- key - Optional PRNG key (default: Nx.Random.key(42))
Returns
Codebook tensor [num_embeddings, embedding_dim] initialized from N(0, 1).
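A plausible implementation sketch using Nx.Random (the actual function body may differ):

def init_codebook(num_embeddings, embedding_dim, key \\ Nx.Random.key(42)) do
  # Draw a [num_embeddings, embedding_dim] sample from N(0, 1),
  # discarding the advanced PRNG key.
  {codebook, _key} = Nx.Random.normal(key, shape: {num_embeddings, embedding_dim})
  codebook
end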
@spec loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Combined VQ-VAE loss: reconstruction + codebook + commitment.
Parameters
- reconstruction - Decoder output [batch, input_size]
- target - Original input [batch, input_size]
- z_e - Encoder output [batch, embedding_dim]
- z_q - Quantized vectors [batch, embedding_dim]
- commitment_weight - Weight for commitment loss (default: 0.25)
Returns
Combined loss scalar.
@spec quantize(Nx.Tensor.t(), Nx.Tensor.t()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Quantize continuous encoder outputs to nearest codebook vectors.
For each encoder output vector z_e, finds the nearest codebook entry by L2 distance and returns the quantized vector along with the codebook indices.
Uses the straight-through estimator: gradients flow from z_q to z_e directly, bypassing the non-differentiable argmin.
Parameters
- z_e - Encoder output [batch, embedding_dim]
- codebook - Codebook embeddings [num_embeddings, embedding_dim]
Returns
{z_q, indices} where:
- z_q - Quantized vectors [batch, embedding_dim] (with straight-through gradient)
- indices - Codebook indices [batch]
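For reference, the nearest-codebook lookup can be sketched in Nx as below (illustrative, not the library's code):

defmodule QuantizeSketch do
  import Nx.Defn

  defn quantize(z_e, codebook) do
    # Squared L2 distances [batch, num_embeddings] via the expansion
    # ||z - e||^2 = ||z||^2 - 2 z.e + ||e||^2
    z_sq = Nx.sum(z_e * z_e, axes: [1], keep_axes: true)
    e_sq = Nx.sum(codebook * codebook, axes: [1])
    dists = z_sq - 2 * Nx.dot(z_e, Nx.transpose(codebook)) + e_sq

    indices = Nx.argmin(dists, axis: 1)
    z_q = Nx.take(codebook, indices)

    # Straight-through estimator: forward value is z_q,
    # gradients w.r.t. z_e bypass the argmin.
    {z_e + stop_grad(z_q - z_e), indices}
  end
end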