Edifice.Generative.VAE (Edifice v0.2.0)


Variational Autoencoder (VAE).

Learns a smooth, continuous latent space by encoding inputs into distributions (parameterized by mu and log_var) rather than point estimates. The reparameterization trick enables gradient flow through the stochastic sampling step, making end-to-end training possible.

Architecture

Input [batch, input_size]
      |
      v
+-------------------+
| Encoder           |
| Dense layers      |
+---+----------+----+
    |          |
    v          v
  mu [L]   log_var [L]     L = latent_size
    |          |
    +----+-----+
         |
   reparameterize
   z = mu + eps * exp(0.5 * log_var)
         |
         v
+-------------------+
| Decoder           |
| Dense layers      |
+-------------------+
      |
      v
Reconstruction [batch, input_size]

Loss

The VAE loss combines reconstruction error with a KL divergence regularizer that pushes the learned posterior toward a standard normal prior:

L = reconstruction_loss + beta * KL(q(z|x) || p(z))

The beta parameter (default 1.0) controls the trade-off. Values greater than 1.0 yield a beta-VAE with a more disentangled, more regular latent space at the cost of blurrier reconstructions; values less than 1.0 weaken the regularization and favor reconstruction quality.
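
As a concrete sketch, the objective can be written as a single defn. This is illustrative only, assuming mean squared error for the reconstruction term (the default noted under loss/5 below); it is not necessarily the library's internal implementation:

import Nx.Defn

# Illustrative: MSE reconstruction term plus the closed-form Gaussian KL.
defn elbo_loss(reconstruction, target, mu, log_var, beta) do
  diff = reconstruction - target
  recon = Nx.mean(Nx.sum(diff * diff, axes: [-1]))
  kl = -0.5 * Nx.mean(Nx.sum(1 + log_var - mu * mu - Nx.exp(log_var), axes: [-1]))
  recon + beta * kl
end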

Usage

# Build full VAE
{encoder, decoder} = VAE.build(input_size: 784, latent_size: 32)

# Build encoder only (for inference / embedding)
encoder = VAE.build_encoder(input_size: 784, latent_size: 32)

# Reparameterization and loss (in a training loop; mu and log_var come from
# the encoder, reconstruction from the decoder)
key = Nx.Random.key(System.system_time())
{z, key} = VAE.reparameterize(mu, log_var, key)  # rebind key so later samples differ
kl = VAE.kl_divergence(mu, log_var)
total = VAE.loss(reconstruction, target, mu, log_var, beta: 1.0)
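
Putting the pieces together, a full forward pass might look like the sketch below, using Axon.build/1 to compile each network; the parameter names and the synthetic input are illustrative:

{encoder, decoder} = VAE.build(input_size: 784, latent_size: 32)

{enc_init, enc_predict} = Axon.build(encoder)
{dec_init, dec_predict} = Axon.build(decoder)

enc_params = enc_init.(Nx.template({1, 784}, :f32), %{})
dec_params = dec_init.(Nx.template({1, 32}, :f32), %{})

x = Nx.iota({8, 784}, type: :f32)

# Encode to a distribution, sample via the reparameterization trick, decode.
%{mu: mu, log_var: log_var} = enc_predict.(enc_params, x)
{z, _key} = VAE.reparameterize(mu, log_var, Nx.Random.key(1234))
reconstruction = dec_predict.(dec_params, z)

total = VAE.loss(reconstruction, x, mu, log_var, beta: 1.0)

A real training loop would differentiate total with respect to both parameter maps (e.g. with Nx.Defn.grad) and thread the PRNG key between steps.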

Summary

Types

Options for build/1.

Functions

Build a complete VAE (encoder + decoder).

Build the decoder network.

Build the encoder network.

KL divergence between the learned posterior q(z|x) and the prior p(z) = N(0, I).

Combined VAE loss: reconstruction + beta * KL divergence.

Reparameterization trick: sample z from q(z|x) = N(mu, sigma^2).

Types

build_opt()

@type build_opt() ::
  {:input_size, pos_integer()}
  | {:latent_size, pos_integer()}
  | {:encoder_sizes, [pos_integer()]}
  | {:decoder_sizes, [pos_integer()]}
  | {:activation, atom()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: {Axon.t(), Axon.t()}

Build a complete VAE (encoder + decoder).

Returns a tuple {encoder, decoder} where the encoder outputs a map %{mu: ..., log_var: ...} via Axon.container/1 and the decoder reconstructs the input from a latent vector.

Options

  • :input_size - Input feature dimension (required)
  • :latent_size - Latent space dimension (default: 32)
  • :encoder_sizes - Hidden layer sizes for encoder (default: [256, 128])
  • :decoder_sizes - Hidden layer sizes for decoder (default: [128, 256])
  • :activation - Activation function (default: :relu)

Returns

{encoder, decoder} - Tuple of Axon models.

  • Encoder: input -> %{mu: [batch, latent], log_var: [batch, latent]}
  • Decoder: latent -> [batch, input_size]

build_decoder(opts \\ [])

@spec build_decoder(keyword()) :: Axon.t()

Build the decoder network.

Maps a latent vector back to the input space.

Options

  • :input_size or :output_size - Reconstruction output dimension (required)
  • :latent_size - Latent dimension (default: 32)
  • :decoder_sizes - Hidden layer sizes (default: [128, 256])
  • :activation - Activation function (default: :relu)

Returns

An Axon model: [batch, latent_size] -> [batch, output_size].
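
One natural use is decoding samples drawn from the prior N(0, I) to generate new data. The sketch below is illustrative; it initializes fresh parameters only to stay self-contained, whereas real generation assumes trained decoder parameters:

decoder = VAE.build_decoder(output_size: 784, latent_size: 32)
{init_fn, predict_fn} = Axon.build(decoder)
dec_params = init_fn.(Nx.template({1, 32}, :f32), %{})

# Sample 16 latent vectors from the prior and decode them.
key = Nx.Random.key(0)
{z, _key} = Nx.Random.normal(key, shape: {16, 32})
samples = predict_fn.(dec_params, z)  # shape {16, 784}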

build_encoder(opts \\ [])

@spec build_encoder(keyword()) :: Axon.t()

Build the encoder network.

Maps input to a distribution in latent space, parameterized by mu (mean) and log_var (log variance). Uses Axon.container/1 to return both outputs as a map.

Options

  • :input_size - Input feature dimension (required)
  • :latent_size - Latent dimension (default: 32)
  • :encoder_sizes - Hidden layer sizes (default: [256, 128])
  • :activation - Activation function (default: :relu)

Returns

An Axon model outputting %{mu: [batch, latent_size], log_var: [batch, latent_size]}.
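
At inference time a common pattern is to use mu alone as a deterministic embedding and skip the sampling step entirely. A minimal sketch with untrained parameters and a synthetic batch:

encoder = VAE.build_encoder(input_size: 784, latent_size: 32)
{init_fn, predict_fn} = Axon.build(encoder)
params = init_fn.(Nx.template({1, 784}, :f32), %{})

batch = Nx.iota({4, 784}, type: :f32)
%{mu: embedding, log_var: _log_var} = predict_fn.(params, batch)
# embedding has shape {4, 32}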

kl_divergence(mu, log_var)

@spec kl_divergence(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

KL divergence between the learned posterior q(z|x) and the prior p(z) = N(0, I).

Computed in closed form:

KL = -0.5 * sum(1 + log_var - mu^2 - exp(log_var))

Parameters

  • mu - Mean [batch, latent_size]
  • log_var - Log variance [batch, latent_size]

Returns

KL divergence scalar (mean over batch).
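
A direct transcription of the closed form into defn, summing over the latent axis and averaging over the batch as described above (illustrative, not the library's source):

import Nx.Defn

defn kl_divergence_sketch(mu, log_var) do
  per_example = -0.5 * Nx.sum(1 + log_var - mu * mu - Nx.exp(log_var), axes: [-1])
  Nx.mean(per_example)
end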

loss(reconstruction, target, mu, log_var, opts \\ [])

Combined VAE loss: reconstruction + beta * KL divergence.

Uses mean squared error for reconstruction by default. The beta parameter controls the regularization strength:

  • beta = 1.0: standard VAE (ELBO)
  • beta < 1.0: weaker regularization, better reconstructions
  • beta > 1.0: beta-VAE, more disentangled but blurrier

Parameters

  • reconstruction - Decoder output [batch, input_size]
  • target - Original input [batch, input_size]
  • mu - Encoder mean [batch, latent_size]
  • log_var - Encoder log variance [batch, latent_size]
  • :beta - KL weight, passed via opts (default: 1.0)

Returns

Combined loss scalar.
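
While tuning beta it can help to log the two terms separately; an illustrative snippet using the functions documented here, with reconstruction, target, mu, and log_var as produced during a training step:

recon_only = VAE.loss(reconstruction, target, mu, log_var, beta: 0.0)  # pure reconstruction term
kl_term = VAE.kl_divergence(mu, log_var)

# Standard VAE vs. a beta-VAE-style objective:
standard = VAE.loss(reconstruction, target, mu, log_var, beta: 1.0)
disentangled = VAE.loss(reconstruction, target, mu, log_var, beta: 4.0)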

reparameterize(mu, log_var, key)

@spec reparameterize(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) ::
  {Nx.Tensor.t(), Nx.Tensor.t()}

Reparameterization trick: sample z from q(z|x) = N(mu, sigma^2).

Computes z = mu + eps * exp(0.5 * log_var) where eps ~ N(0, I). Because the randomness is isolated in eps, which does not depend on the model's parameters, z is a deterministic, differentiable function of mu and log_var, and gradients can flow through both during backpropagation.

Parameters

  • mu - Mean of the approximate posterior [batch, latent_size]
  • log_var - Log variance of the approximate posterior [batch, latent_size]
  • key - PRNG key from Nx.Random.key/1 (required for proper stochastic sampling)

Returns

{z, new_key} - Sampled latent vector and updated PRNG key.
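
The body is small enough to sketch directly in defn; this is illustrative and may differ from the library's actual implementation:

import Nx.Defn

defn reparameterize_sketch(mu, log_var, key) do
  # eps is sampled independently of mu/log_var, so gradients bypass the draw.
  {eps, new_key} = Nx.Random.normal(key, shape: Nx.shape(mu))
  z = mu + eps * Nx.exp(0.5 * log_var)
  {z, new_key}
end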