# `Edifice.Generative.VQVAE`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/generative/vq_vae.ex#L1)

Vector-Quantized Variational Autoencoder (VQ-VAE).

Instead of learning a continuous latent space like a standard VAE,
VQ-VAE learns a discrete codebook of embedding vectors. The encoder
output is quantized to the nearest codebook entry, producing a
discrete latent representation.

This avoids posterior collapse (a common VAE failure mode in which the
decoder learns to ignore the latent code) and tends to produce sharper
reconstructions, since the decoder receives codebook vectors rather
than noisy samples from a Gaussian posterior.

## Architecture

```
Input [batch, input_size]
      |
      v
+-------------------+
| Encoder           |
| Dense layers      |
+-------------------+
      |
      v
z_e [batch, embedding_dim]    (continuous encoder output)
      |
      v
+-------------------+
| Quantize          |
| nearest codebook  |
| vector lookup     |
+-------------------+
      |
      v
z_q [batch, embedding_dim]    (discrete quantized vector)
      |
      v
+-------------------+
| Decoder           |
| Dense layers      |
+-------------------+
      |
      v
Reconstruction [batch, input_size]
```

## Training Losses

VQ-VAE training uses three loss components:
1. **Reconstruction loss**: MSE between input and reconstruction
2. **Codebook loss**: `||sg(z_e) - e||^2` - moves the selected codebook vector `e` toward the encoder output (`sg` is stop-gradient)
3. **Commitment loss**: `||z_e - sg(e)||^2` - keeps encoder outputs close to their codebook vectors so they do not drift arbitrarily far from the codebook

The straight-through estimator copies gradients from the decoder input (`z_q`)
directly to the encoder output (`z_e`), bypassing the non-differentiable
nearest-neighbor lookup.
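
A minimal `Nx.Defn` sketch of the straight-through estimator, assuming the common `z_e + stop_grad(z_q - z_e)` formulation. The module `STESketch` is hypothetical and not part of Edifice. The forward value equals `z_q`, while the gradient with respect to `z_e` is the identity:

```elixir
# Standalone script; hypothetical sketch, not Edifice's implementation.
Mix.install([{:nx, "~> 0.7"}])

defmodule STESketch do
  import Nx.Defn

  # Forward: z_e + (z_q - z_e) = z_q.
  # Backward: stop_grad zeroes the second term, so d/dz_e is the identity.
  defn straight_through(z_e, z_q), do: z_e + stop_grad(z_q - z_e)

  # Gradient of sum(straight_through(z_e, z_q)) with respect to z_e.
  defn ste_grad(z_e, z_q) do
    grad(z_e, fn z -> Nx.sum(straight_through(z, z_q)) end)
  end
end

z_e = Nx.tensor([[0.3, -1.2]])
z_q = Nx.tensor([[0.0, -1.0]])

out = STESketch.straight_through(z_e, z_q)
g = STESketch.ste_grad(z_e, z_q)
```

Here `out` matches `z_q` up to float rounding and `g` is all ones, which is what allows the reconstruction loss to train the encoder despite the discrete lookup.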

## Usage

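    alias Edifice.Generative.VQVAE

    # Build full VQ-VAE
    {encoder, decoder} = VQVAE.build(input_size: 784, embedding_dim: 64, num_embeddings: 512)

    # Codebook of 512 vectors, each 64-dimensional
    codebook = VQVAE.init_codebook(512, 64)

    # z_e is the encoder output [batch, 64], e.g. Axon.predict(encoder, params, x)
    # (params and x are placeholders for trained parameters and an input batch)
    {z_q, indices} = VQVAE.quantize(z_e, codebook)

    # Training losses
    commit = VQVAE.commitment_loss(z_e, z_q)
    cb = VQVAE.codebook_loss(z_e, z_q)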

# `build_opt`

```elixir
@type build_opt() :: {:input_size, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: {Axon.t(), Axon.t()}
```

Build a complete VQ-VAE (encoder + decoder).

The encoder maps inputs to continuous vectors of size `embedding_dim`.
The decoder reconstructs from quantized codebook vectors of the same size.
Quantization is performed externally via `quantize/2` during training.

## Options
  - `:input_size` - Input feature dimension (required)
  - `:embedding_dim` - Codebook vector dimension (default: 64)
  - `:num_embeddings` - Number of codebook entries (default: 512)
  - `:encoder_sizes` - Hidden layer sizes for encoder (default: [256, 128])
  - `:decoder_sizes` - Hidden layer sizes for decoder (default: [128, 256])
  - `:activation` - Activation function (default: :relu)

## Returns
  `{encoder, decoder}` - Tuple of Axon models.
  - Encoder: `[batch, input_size]` -> `[batch, embedding_dim]`
  - Decoder: `[batch, embedding_dim]` -> `[batch, input_size]`

# `build_decoder`

```elixir
@spec build_decoder(keyword()) :: Axon.t()
```

Build the decoder network.

Maps quantized codebook vectors back to the input space.

## Options
  - `:input_size` or `:output_size` - Reconstruction dimension (required)
  - `:embedding_dim` - Input dimension from codebook (default: 64)
  - `:decoder_sizes` - Hidden layer sizes (default: [128, 256])
  - `:activation` - Activation function (default: :relu)

## Returns
  An Axon model: `[batch, embedding_dim]` -> `[batch, output_size]`.
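
A minimal Axon sketch of a decoder matching these options, assuming hidden `Axon.dense/3` layers with the chosen activation followed by a linear projection to the output size. `DecoderSketch`, the input name `"z_q"`, and the absence of an output activation are all assumptions; the library's real layer names or output activation may differ:

```elixir
# Standalone script; hypothetical sketch, not Edifice's implementation.
Mix.install([{:axon, "~> 0.6"}])

defmodule DecoderSketch do
  # Assumed layout: hidden dense layers with `activation`, then a linear
  # projection to output_size (no output activation assumed).
  def build_decoder(opts) do
    output_size = opts[:output_size] || Keyword.fetch!(opts, :input_size)
    embedding_dim = Keyword.get(opts, :embedding_dim, 64)
    sizes = Keyword.get(opts, :decoder_sizes, [128, 256])
    activation = Keyword.get(opts, :activation, :relu)

    sizes
    |> Enum.reduce(Axon.input("z_q", shape: {nil, embedding_dim}), fn units, acc ->
      Axon.dense(acc, units, activation: activation)
    end)
    |> Axon.dense(output_size)
  end
end

decoder = DecoderSketch.build_decoder(input_size: 784)
shape = Axon.get_output_shape(decoder, Nx.template({4, 64}, :f32))
```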

# `build_encoder`

```elixir
@spec build_encoder(keyword()) :: Axon.t()
```

Build the encoder network.

Maps input to a continuous vector in the embedding space. This output
is then quantized to the nearest codebook vector during training.

## Options
  - `:input_size` - Input feature dimension (required)
  - `:embedding_dim` - Output dimension, must match codebook (default: 64)
  - `:encoder_sizes` - Hidden layer sizes (default: [256, 128])
  - `:activation` - Activation function (default: :relu)

## Returns
  An Axon model: `[batch, input_size]` -> `[batch, embedding_dim]`.
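
A minimal Axon sketch of an encoder matching these options, assuming hidden `Axon.dense/3` layers with the chosen activation followed by a linear projection to `:embedding_dim`, so `z_e` is unconstrained. `EncoderSketch` and the input name `"input"` are hypothetical and may not match Edifice's code:

```elixir
# Standalone script; hypothetical sketch, not Edifice's implementation.
Mix.install([{:axon, "~> 0.6"}])

defmodule EncoderSketch do
  # Assumed layout: hidden dense layers with `activation`, then a linear
  # projection to embedding_dim so z_e is unconstrained.
  def build_encoder(opts) do
    input_size = Keyword.fetch!(opts, :input_size)
    embedding_dim = Keyword.get(opts, :embedding_dim, 64)
    sizes = Keyword.get(opts, :encoder_sizes, [256, 128])
    activation = Keyword.get(opts, :activation, :relu)

    sizes
    |> Enum.reduce(Axon.input("input", shape: {nil, input_size}), fn units, acc ->
      Axon.dense(acc, units, activation: activation)
    end)
    |> Axon.dense(embedding_dim)
  end
end

encoder = EncoderSketch.build_encoder(input_size: 784)
shape = Axon.get_output_shape(encoder, Nx.template({8, 784}, :f32))
```

Leaving the last layer linear keeps the encoder free to place `z_e` anywhere near the codebook; the commitment loss is what keeps it from drifting.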

# `codebook_loss`

```elixir
@spec codebook_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Codebook loss: moves codebook vectors toward encoder outputs.

Computed as `mean(||sg(z_e) - z_q||^2)` where `sg` is stop_gradient.
Updating the codebook with an exponential moving average (EMA) of the assigned encoder outputs is a common alternative to this loss term.

## Parameters
  - `z_e` - Encoder output `[batch, embedding_dim]` (will be detached)
  - `z_q` - Quantized vectors `[batch, embedding_dim]`

## Returns
  Codebook loss scalar.
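
A minimal sketch of this loss and its gradients, assuming `mean` averages over every element (as an element-wise MSE would) rather than over per-row squared norms. `CodebookLossSketch` is a hypothetical module. Only `z_q` receives a gradient:

```elixir
# Standalone script; hypothetical sketch, not Edifice's implementation.
Mix.install([{:nx, "~> 0.7"}])

defmodule CodebookLossSketch do
  import Nx.Defn

  # mean((sg(z_e) - z_q)^2): z_e is treated as a constant.
  defn codebook_loss(z_e, z_q) do
    diff = stop_grad(z_e) - z_q
    Nx.mean(diff * diff)
  end

  # Gradients with respect to z_e and z_q, in that order.
  defn grads(z_e, z_q) do
    {grad(z_e, fn e -> codebook_loss(e, z_q) end),
     grad(z_q, fn q -> codebook_loss(z_e, q) end)}
  end
end

z_e = Nx.tensor([[1.0, 2.0]])
z_q = Nx.tensor([[0.0, 0.0]])

loss = CodebookLossSketch.codebook_loss(z_e, z_q)
{g_e, g_q} = CodebookLossSketch.grads(z_e, z_q)
```

With these inputs the loss is `(1 + 4) / 2 = 2.5`, `g_e` is zero, and `g_q = -(z_e - z_q) = [[-1.0, -2.0]]`, pulling the codebook vector toward the encoder output.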

# `commitment_loss`

```elixir
@spec commitment_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Commitment loss: encourages encoder outputs to stay close to codebook vectors.

Computed as `mean(||z_e - sg(z_q)||^2)` where `sg` is stop_gradient.
This prevents the encoder output from growing unboundedly away from
the codebook entries.

## Parameters
  - `z_e` - Encoder output `[batch, embedding_dim]`
  - `z_q` - Quantized vectors `[batch, embedding_dim]` (will be detached)

## Returns
  Commitment loss scalar.
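
The mirror-image sketch for the commitment loss, under the same element-wise `mean` assumption; `CommitmentLossSketch` is hypothetical. Here the stop-gradient sits on `z_q`, so only the encoder output receives a gradient:

```elixir
# Standalone script; hypothetical sketch, not Edifice's implementation.
Mix.install([{:nx, "~> 0.7"}])

defmodule CommitmentLossSketch do
  import Nx.Defn

  # mean((z_e - sg(z_q))^2): z_q is treated as a constant.
  defn commitment_loss(z_e, z_q) do
    diff = z_e - stop_grad(z_q)
    Nx.mean(diff * diff)
  end

  # Gradients with respect to z_e and z_q, in that order.
  defn grads(z_e, z_q) do
    {grad(z_e, fn e -> commitment_loss(e, z_q) end),
     grad(z_q, fn q -> commitment_loss(z_e, q) end)}
  end
end

z_e = Nx.tensor([[1.0, 2.0]])
z_q = Nx.tensor([[0.0, 0.0]])

loss = CommitmentLossSketch.commitment_loss(z_e, z_q)
{g_e, g_q} = CommitmentLossSketch.grads(z_e, z_q)
```

The forward value equals the codebook loss (2.5 here), but the gradient now lands on `z_e` (`[[1.0, 2.0]]`, pushing it toward `z_q`) and `g_q` is zero.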

# `init_codebook`

```elixir
@spec init_codebook(pos_integer(), pos_integer(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Initialize a codebook with random normal vectors.

## Parameters
  - `num_embeddings` - Number of codebook entries
  - `embedding_dim` - Dimension of each codebook vector
  - `key` - Optional PRNG key (default: `Nx.Random.key(42)`)

## Returns
  Codebook tensor `[num_embeddings, embedding_dim]` initialized from N(0, 1).
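
A sketch using `Nx.Random.normal/4`, assuming the codebook is a single draw of shape `{num_embeddings, embedding_dim}` from the given key; `InitCodebookSketch` is hypothetical. Because `Nx.Random` is a functional PRNG, the same key always yields the same codebook:

```elixir
# Standalone script; hypothetical sketch, not Edifice's implementation.
Mix.install([{:nx, "~> 0.7"}])

defmodule InitCodebookSketch do
  # Draw a [num_embeddings, embedding_dim] matrix from N(0, 1).
  # Nx.Random.normal/4 returns {tensor, new_key}; the new key is discarded.
  def init_codebook(num_embeddings, embedding_dim, key \\ Nx.Random.key(42)) do
    {codebook, _new_key} =
      Nx.Random.normal(key, 0.0, 1.0, shape: {num_embeddings, embedding_dim})

    codebook
  end
end

codebook = InitCodebookSketch.init_codebook(512, 64)
```

Passing a different key, such as `Nx.Random.key(7)`, gives a different but equally reproducible codebook.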

# `loss`

```elixir
@spec loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) ::
  Nx.Tensor.t()
```

Combined VQ-VAE loss: reconstruction loss + codebook loss + `commitment_weight` * commitment loss.

## Parameters
  - `reconstruction` - Decoder output `[batch, input_size]`
  - `target` - Original input `[batch, input_size]`
  - `z_e` - Encoder output `[batch, embedding_dim]`
  - `z_q` - Quantized vectors `[batch, embedding_dim]`
  - Keyword options:
    - `:commitment_weight` - Weight for commitment loss (default: 0.25)

## Returns
  Combined loss scalar.
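
A minimal sketch of the combined loss, assuming all three terms are element-wise MSEs and that `:commitment_weight` scales only the commitment term. `VQLossSketch` and its private `mse/2` helper are hypothetical:

```elixir
# Standalone script; hypothetical sketch, not Edifice's implementation.
Mix.install([{:nx, "~> 0.7"}])

defmodule VQLossSketch do
  import Nx.Defn

  defnp mse(a, b) do
    diff = a - b
    Nx.mean(diff * diff)
  end

  # reconstruction + codebook + commitment_weight * commitment
  defn loss(reconstruction, target, z_e, z_q, opts \\ []) do
    opts = keyword!(opts, commitment_weight: 0.25)

    mse(reconstruction, target) +
      mse(stop_grad(z_e), z_q) +
      opts[:commitment_weight] * mse(z_e, stop_grad(z_q))
  end
end

recon = Nx.tensor([[1.0, 1.0]])
target = Nx.tensor([[0.0, 0.0]])
z_e = Nx.tensor([[1.0, 2.0]])
z_q = Nx.tensor([[0.0, 0.0]])

# 1.0 + 2.5 + 0.25 * 2.5 = 4.125
default_weight = VQLossSketch.loss(recon, target, z_e, z_q)
# 1.0 + 2.5 + 1.0 * 2.5 = 6.0
full_weight = VQLossSketch.loss(recon, target, z_e, z_q, commitment_weight: 1.0)
```

The codebook and commitment terms have the same forward value; they differ only in which side the stop-gradient blocks, which is why a weight below 1 on the commitment term softens the pull on the encoder without slowing codebook updates.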

# `quantize`

```elixir
@spec quantize(Nx.Tensor.t(), Nx.Tensor.t()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
```

Quantize continuous encoder outputs to nearest codebook vectors.

For each encoder output vector `z_e`, finds the nearest codebook
entry by L2 distance and returns the quantized vector along with
the codebook indices.

Uses the straight-through estimator: gradients flow from `z_q` to `z_e`
directly, bypassing the non-differentiable argmin.

## Parameters
  - `z_e` - Encoder output `[batch, embedding_dim]`
  - `codebook` - Codebook embeddings `[num_embeddings, embedding_dim]`

## Returns
  `{z_q, indices}` where:
  - `z_q` - Quantized vectors `[batch, embedding_dim]` (with straight-through gradient)
  - `indices` - Codebook indices `[batch]`
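
A minimal `Nx.Defn` sketch of this lookup, assuming squared L2 distance computed via the expansion `||z||^2 - 2 z·e + ||e||^2` and the `z_e + stop_grad(z_q - z_e)` form of the straight-through estimator. `QuantizeSketch` is hypothetical and may differ from Edifice's code:

```elixir
# Standalone script; hypothetical sketch, not Edifice's implementation.
Mix.install([{:nx, "~> 0.7"}])

defmodule QuantizeSketch do
  import Nx.Defn

  defn quantize(z_e, codebook) do
    # Squared distances [batch, num_embeddings] via ||z||^2 - 2 z.e + ||e||^2
    z_sq = Nx.sum(z_e * z_e, axes: [1], keep_axes: true)
    e_sq = Nx.sum(codebook * codebook, axes: [1])
    dists = z_sq - 2 * Nx.dot(z_e, Nx.transpose(codebook)) + e_sq

    # Nearest codebook entry per row, then gather those vectors
    indices = Nx.argmin(dists, axis: 1)
    z_q = Nx.take(codebook, indices, axis: 0)

    # Straight-through: forward value is z_q, gradient w.r.t. z_e is identity
    {z_e + stop_grad(z_q - z_e), indices}
  end
end

codebook = Nx.tensor([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
z_e = Nx.tensor([[0.1, 0.2], [0.9, 1.2], [4.0, 6.0]])

{z_q, indices} = QuantizeSketch.quantize(z_e, codebook)
```

Each row of `z_e` here sits closest to a different codebook entry, so `indices` is `[0, 1, 2]` and `z_q` reproduces the codebook rows. Expanding the squared distance avoids materializing a `[batch, num_embeddings, embedding_dim]` difference tensor, at the cost of some floating-point cancellation when `z_e` nearly equals a codebook vector.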

---

*Consult [api-reference.md](api-reference.md) for the complete listing*
