Edifice.Generative.VQVAE (Edifice v0.2.0)


Vector-Quantized Variational Autoencoder (VQ-VAE).

Instead of learning a continuous latent space like a standard VAE, VQ-VAE learns a discrete codebook of embedding vectors. The encoder output is quantized to the nearest codebook entry, producing a discrete latent representation.

Discrete latents avoid posterior collapse, a common VAE failure mode in which the decoder learns to ignore the latent code. Reconstructions also tend to be sharper because the decoder receives deterministic codebook vectors rather than noisy samples from a learned posterior.

Architecture

Input [batch, input_size]
      |
      v
+-------------------+
| Encoder           |
| Dense layers      |
+-------------------+
      |
      v
z_e [batch, embedding_dim]    (continuous encoder output)
      |
      v
+-------------------+
| Quantize          |
| nearest codebook  |
| vector lookup     |
+-------------------+
      |
      v
z_q [batch, embedding_dim]    (discrete quantized vector)
      |
      v
+-------------------+
| Decoder           |
| Dense layers      |
+-------------------+
      |
      v
Reconstruction [batch, input_size]

Training Losses

VQ-VAE training uses three loss components:

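  1. Reconstruction loss: MSE between input and reconstruction
  2. Codebook loss: ||sg(z_e) - e||^2, which moves codebook vectors toward encoder outputs
  3. Commitment loss: ||z_e - sg(e)||^2, which keeps encoder outputs from drifting or growing away from their codebook vectors

Here e is the codebook vector selected for z_e (that is, z_q), and sg(·) is the stop-gradient operator.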
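Summed, with the commitment term scaled by a weight β (loss/5 exposes it as :commitment_weight, default 0.25), the objective can be written as follows. Writing the reconstruction term as a plain MSE and leaving the batch averaging implicit are assumptions of this sketch.

```latex
\mathcal{L} = \operatorname{MSE}(x, \hat{x})
  + \lVert \operatorname{sg}(z_e) - e \rVert_2^2
  + \beta \, \lVert z_e - \operatorname{sg}(e) \rVert_2^2
```

In the forward pass sg is the identity, so the last two terms have the same value. They differ only in where their gradients go: the codebook term updates e, and the commitment term updates the encoder.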

The straight-through estimator passes gradients from decoder directly to encoder, bypassing the non-differentiable quantization step.
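A common way to implement this is to rewrite the quantized vector with a stop-gradient, as below. Whether Edifice uses exactly this expression is an assumption; it is shown to make the "bypass" concrete.

```latex
z_q^{\text{ST}} = z_e + \operatorname{sg}(z_q - z_e)
\qquad\Longrightarrow\qquad
z_q^{\text{ST}} = z_q \ \text{(forward)}, \qquad
\frac{\partial z_q^{\text{ST}}}{\partial z_e} = I \ \text{(backward)}
```

The decoder therefore sees the true codebook vector, while the gradient of the loss with respect to z_q is passed unchanged to z_e.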

Usage

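alias Edifice.Generative.VQVAE

# Build full VQ-VAE: returns separate encoder and decoder Axon models
{encoder, decoder} = VQVAE.build(input_size: 784, embedding_dim: 64, num_embeddings: 512)

# Codebook shape must be [num_embeddings, embedding_dim]
codebook = VQVAE.init_codebook(512, 64)

# z_e is the encoder output [batch, 64], obtained by running `encoder`
{z_q, indices} = VQVAE.quantize(z_e, codebook)

# Training losses
commit = VQVAE.commitment_loss(z_e, z_q)
cb = VQVAE.codebook_loss(z_e, z_q)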

Summary

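Types

build_opt()

Options for build/1.

Functions

build(opts \\ [])

Build a complete VQ-VAE (encoder + decoder).

build_decoder(opts \\ [])

Build the decoder network.

build_encoder(opts \\ [])

Build the encoder network.

codebook_loss(z_e, z_q)

Codebook loss: moves codebook vectors toward encoder outputs.

commitment_loss(z_e, z_q)

Commitment loss: encourages encoder outputs to stay close to codebook vectors.

init_codebook(num_embeddings, embedding_dim, key \\ Nx.Random.key(42))

Initialize a codebook with random normal vectors.

loss(reconstruction, target, z_e, z_q, opts \\ [])

Combined VQ-VAE loss: reconstruction + codebook + commitment.

quantize(z_e, codebook)

Quantize continuous encoder outputs to nearest codebook vectors.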

Types

build_opt()

@type build_opt() :: {:input_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: {Axon.t(), Axon.t()}

Build a complete VQ-VAE (encoder + decoder).

The encoder maps inputs to continuous vectors of size embedding_dim. The decoder reconstructs from quantized codebook vectors of the same size. Quantization is performed externally via quantize/2 during training.

Options

  • :input_size - Input feature dimension (required)
  • :embedding_dim - Codebook vector dimension (default: 64)
  • :num_embeddings - Number of codebook entries (default: 512)
  • :encoder_sizes - Hidden layer sizes for encoder (default: [256, 128])
  • :decoder_sizes - Hidden layer sizes for decoder (default: [128, 256])
  • :activation - Activation function (default: :relu)

Returns

{encoder, decoder} - Tuple of Axon models.

  • Encoder: [batch, input_size] -> [batch, embedding_dim]
  • Decoder: [batch, embedding_dim] -> [batch, input_size]

build_decoder(opts \\ [])

@spec build_decoder(keyword()) :: Axon.t()

Build the decoder network.

Maps quantized codebook vectors back to the input space.

Options

  • :input_size or :output_size - Reconstruction dimension (required)
  • :embedding_dim - Input dimension from codebook (default: 64)
  • :decoder_sizes - Hidden layer sizes (default: [128, 256])
  • :activation - Activation function (default: :relu)

Returns

An Axon model: [batch, embedding_dim] -> [batch, output_size].

build_encoder(opts \\ [])

@spec build_encoder(keyword()) :: Axon.t()

Build the encoder network.

Maps input to a continuous vector in the embedding space. This output is then quantized to the nearest codebook vector during training.

Options

  • :input_size - Input feature dimension (required)
  • :embedding_dim - Output dimension, must match codebook (default: 64)
  • :encoder_sizes - Hidden layer sizes (default: [256, 128])
  • :activation - Activation function (default: :relu)

Returns

An Axon model: [batch, input_size] -> [batch, embedding_dim].

codebook_loss(z_e, z_q)

@spec codebook_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Codebook loss: moves codebook vectors toward encoder outputs.

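Computed as mean(||sg(z_e) - z_q||^2), where sg is stop_gradient. Because z_e is detached, this term is meant to train the codebook rather than the encoder. An exponential-moving-average (EMA) update of the codebook embeddings is a common alternative to this loss, not an equivalent of it.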

Parameters

  • z_e - Encoder output [batch, embedding_dim] (will be detached)
  • z_q - Quantized vectors [batch, embedding_dim]

Returns

Codebook loss scalar.

commitment_loss(z_e, z_q)

@spec commitment_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Commitment loss: encourages encoder outputs to stay close to codebook vectors.

Computed as mean(||z_e - sg(z_q)||^2) where sg is stop_gradient. This prevents the encoder output from growing unboundedly away from the codebook entries.

Parameters

  • z_e - Encoder output [batch, embedding_dim]
  • z_q - Quantized vectors [batch, embedding_dim] (will be detached)

Returns

Commitment loss scalar.
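codebook_loss/2 and commitment_loss/2 compute the same number in the forward pass; the stop-gradient only decides whether the codebook or the encoder receives the update. The dependency-free sketch below computes that shared forward value on plain lists of floats. It reads mean(||·||^2) as the batch mean of squared L2 distances; the library may instead average over every element, which would divide the result by embedding_dim. The module name VQLossSketch is hypothetical, and gradients are not modeled.

```elixir
defmodule VQLossSketch do
  # Hypothetical helper module, not part of Edifice.

  # Squared L2 distance between two equal-length vectors.
  defp sq_dist(a, b) do
    a
    |> Enum.zip(b)
    |> Enum.reduce(0.0, fn {x, y}, acc -> acc + (x - y) * (x - y) end)
  end

  # Batch mean of squared L2 distances between matching rows of z_e and z_q.
  # This is the forward value of both the codebook and the commitment loss.
  def vq_distance(z_e, z_q) do
    dists = Enum.zip_with(z_e, z_q, &sq_dist/2)
    Enum.sum(dists) / length(dists)
  end
end
```

For z_e = [[1.0, 2.0], [0.0, 0.0]] and z_q = [[1.0, 0.0], [3.0, 4.0]], the row distances are 4.0 and 25.0, so both losses evaluate to 14.5.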

init_codebook(num_embeddings, embedding_dim, key \\ Nx.Random.key(42))

@spec init_codebook(pos_integer(), pos_integer(), Nx.Tensor.t()) :: Nx.Tensor.t()

Initialize a codebook with random normal vectors.

Parameters

  • num_embeddings - Number of codebook entries
  • embedding_dim - Dimension of each codebook vector
  • key - Optional PRNG key (default: Nx.Random.key(42))

Returns

Codebook tensor [num_embeddings, embedding_dim] initialized from N(0, 1).
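init_codebook/3 draws from Nx.Random with a fixed default key, so repeated calls give the same codebook. The sketch below only illustrates the shape and that reproducibility; it uses Erlang's :rand.normal/0 with an explicit seed instead, so its numbers will not match the library's. CodebookInitSketch is a hypothetical name.

```elixir
defmodule CodebookInitSketch do
  # Hypothetical helper module, not part of Edifice.

  # Returns num_embeddings rows, each a list of embedding_dim floats drawn
  # from N(0, 1). Re-seeding the process-local :rand state makes the output
  # reproducible, loosely mirroring the fixed Nx.Random.key(42) default.
  def init_codebook(num_embeddings, embedding_dim, seed \\ 42) do
    :rand.seed(:exsss, {seed, seed, seed})

    for _ <- 1..num_embeddings do
      for _ <- 1..embedding_dim, do: :rand.normal()
    end
  end
end
```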

loss(reconstruction, target, z_e, z_q, opts \\ [])

Combined VQ-VAE loss: reconstruction + codebook + commitment.

Parameters

  • reconstruction - Decoder output [batch, input_size]
  • target - Original input [batch, input_size]
  • z_e - Encoder output [batch, embedding_dim]
  • z_q - Quantized vectors [batch, embedding_dim]
  • opts - Keyword options; :commitment_weight sets the weight for the commitment loss (default: 0.25)

Returns

Combined loss scalar.
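Since the codebook and commitment terms share a forward value d, the combined loss evaluates to MSE + d + commitment_weight · d. The sketch below shows that arithmetic on plain lists. Treating the reconstruction term as an element-wise mean and the VQ terms as a batch mean of squared distances are assumptions, and VQTotalLossSketch is a hypothetical name.

```elixir
defmodule VQTotalLossSketch do
  # Hypothetical helper module, not part of Edifice.

  # Squared L2 distance between two equal-length vectors.
  defp sq_dist(a, b) do
    a
    |> Enum.zip(b)
    |> Enum.reduce(0.0, fn {x, y}, acc -> acc + (x - y) * (x - y) end)
  end

  # Mean squared error over every element of two [batch, n] lists.
  defp mse(a, b) do
    flat_a = List.flatten(a)
    sq_dist(flat_a, List.flatten(b)) / length(flat_a)
  end

  # Batch mean of squared distances between rows of z_e and z_q.
  defp vq_distance(z_e, z_q) do
    dists = Enum.zip_with(z_e, z_q, &sq_dist/2)
    Enum.sum(dists) / length(dists)
  end

  # reconstruction + codebook + commitment_weight * commitment
  def loss(reconstruction, target, z_e, z_q, opts \\ []) do
    beta = Keyword.get(opts, :commitment_weight, 0.25)
    d = vq_distance(z_e, z_q)
    mse(reconstruction, target) + d + beta * d
  end
end
```

For example, a reconstruction error of 0.5 and d = 4.0 give 0.5 + 4.0 + 0.25 · 4.0 = 5.5 with the default weight.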

quantize(z_e, codebook)

@spec quantize(Nx.Tensor.t(), Nx.Tensor.t()) :: {Nx.Tensor.t(), Nx.Tensor.t()}

Quantize continuous encoder outputs to nearest codebook vectors.

For each encoder output vector z_e, finds the nearest codebook entry by L2 distance and returns the quantized vector along with the codebook indices.

Uses the straight-through estimator: gradients flow from z_q to z_e directly, bypassing the non-differentiable argmin.

Parameters

  • z_e - Encoder output [batch, embedding_dim]
  • codebook - Codebook embeddings [num_embeddings, embedding_dim]

Returns

{z_q, indices} where:

  • z_q - Quantized vectors [batch, embedding_dim] (with straight-through gradient)
  • indices - Codebook indices [batch]
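The lookup itself is a nearest-neighbour search. The dependency-free sketch below uses plain lists instead of Nx tensors; QuantizeSketch is a hypothetical name, not part of Edifice. It returns {z_q, indices} like quantize/2, but since plain lists carry no gradients it does not model the straight-through estimator. Squared distance has the same argmin as L2 distance, so the square root is skipped.

```elixir
defmodule QuantizeSketch do
  # Hypothetical helper module, not part of Edifice.

  # Squared L2 distance between two equal-length vectors.
  defp sq_dist(a, b) do
    a
    |> Enum.zip(b)
    |> Enum.reduce(0.0, fn {x, y}, acc -> acc + (x - y) * (x - y) end)
  end

  # For each row of z_e, select the codebook row with the smallest distance.
  # Ties go to the lowest index, because Enum.min_by/2 keeps the first minimum.
  def quantize(z_e, codebook) do
    indexed = Enum.with_index(codebook)

    indices =
      Enum.map(z_e, fn z ->
        {_vec, idx} = Enum.min_by(indexed, fn {e, _i} -> sq_dist(z, e) end)
        idx
      end)

    z_q = Enum.map(indices, &Enum.at(codebook, &1))
    {z_q, indices}
  end
end
```

With codebook [[0.0, 0.0], [1.0, 0.0], [-1.0, 0.0]], the inputs [0.9, 0.1] and [-0.8, 0.3] map to indices 1 and 2. Tensor implementations commonly expand ||z - e||^2 = ||z||^2 - 2 z·e + ||e||^2 so that all distances come from one matrix multiply; whether Edifice does this is not stated here.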