Variational Autoencoder (VAE).
Learns a smooth, continuous latent space by encoding inputs into distributions (parameterized by mu and log_var) rather than point estimates. The reparameterization trick enables gradient flow through the stochastic sampling step, making end-to-end training possible.
Architecture
Input [batch, input_size]
|
v
+-------------------+
| Encoder |
| Dense layers |
+---+----------+----+
| |
v v
mu [L] log_var [L] L = latent_size
| |
+----+-----+
|
reparameterize
z = mu + eps * exp(0.5 * log_var)
|
v
+-------------------+
| Decoder |
| Dense layers |
+-------------------+
|
v
Reconstruction [batch, input_size]
Loss
The VAE loss combines reconstruction error with a KL divergence regularizer that pushes the learned posterior toward a standard normal prior:
L = reconstruction_loss + beta * KL(q(z|x) || p(z))
The beta parameter (default 1.0) controls the trade-off. Values below 1.0 weaken the regularizer and favor reconstruction quality at the cost of a less regular latent space; values above 1.0 yield a beta-VAE with a more disentangled latent space but blurrier reconstructions.
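For illustration, the two terms can be composed by hand with Nx (a minimal sketch, not the library's implementation; recon, target, mu, log_var, and beta are assumed to come from a forward pass, and VAE.loss/5 packages the same computation):
# Sketch only: reconstruction error plus weighted KL
recon_loss = Nx.mean(Nx.pow(Nx.subtract(recon, target), 2))
kl = VAE.kl_divergence(mu, log_var)
total = Nx.add(recon_loss, Nx.multiply(beta, kl))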
Usage
# Build full VAE
{encoder, decoder} = VAE.build(input_size: 784, latent_size: 32)
# Build encoder only (for inference / embedding)
encoder = VAE.build_encoder(input_size: 784, latent_size: 32)
# Reparameterization and loss (in training loop)
key = Nx.Random.key(System.system_time())
{z, _key} = VAE.reparameterize(mu, log_var, key)
kl = VAE.kl_divergence(mu, log_var)
total = VAE.loss(reconstruction, target, mu, log_var, beta: 1.0)
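A fuller end-to-end sketch of one forward pass, assuming Axon's {init_fn, predict_fn} API from Axon.build/1; the random batch stands in for real training data, and the gradient/optimizer step is omitted:
# Build and initialize both halves
{encoder, decoder} = VAE.build(input_size: 784, latent_size: 32)
{enc_init, enc_fwd} = Axon.build(encoder)
{dec_init, dec_fwd} = Axon.build(decoder)
enc_params = enc_init.(Nx.template({1, 784}, :f32), %{})
dec_params = dec_init.(Nx.template({1, 32}, :f32), %{})
# Forward pass: encode, sample, decode, score
key = Nx.Random.key(42)
{batch, key} = Nx.Random.normal(key, shape: {64, 784})
%{mu: mu, log_var: log_var} = enc_fwd.(enc_params, batch)
{z, _key} = VAE.reparameterize(mu, log_var, key)
reconstruction = dec_fwd.(dec_params, z)
total = VAE.loss(reconstruction, batch, mu, log_var, beta: 1.0)
# Differentiate total w.r.t. the parameters and apply an optimizer update here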
Summary
Functions
Build a complete VAE (encoder + decoder).
Build the decoder network.
Build the encoder network.
KL divergence between the learned posterior q(z|x) and the prior p(z) = N(0, I).
Combined VAE loss: reconstruction + beta * KL divergence.
Reparameterization trick: sample z from q(z|x) = N(mu, sigma^2).
Types
@type build_opt() ::
  {:input_size, pos_integer()}
  | {:latent_size, pos_integer()}
  | {:encoder_sizes, [pos_integer()]}
  | {:decoder_sizes, [pos_integer()]}
  | {:activation, atom()}
Options for build/1.
Functions
Build a complete VAE (encoder + decoder).
Returns a tuple {encoder, decoder} where the encoder outputs
{mu, log_var} via Axon.container/1 and the decoder reconstructs
from a latent vector.
Options
- :input_size - Input feature dimension (required)
- :latent_size - Latent space dimension (default: 32)
- :encoder_sizes - Hidden layer sizes for encoder (default: [256, 128])
- :decoder_sizes - Hidden layer sizes for decoder (default: [128, 256])
- :activation - Activation function (default: :relu)
Returns
{encoder, decoder} - Tuple of Axon models.
- Encoder: input -> %{mu: [batch, latent], log_var: [batch, latent]}
- Decoder: latent -> [batch, input_size]
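To inspect the two graphs after building them, Axon.Display can print a layer table (a sketch; requires the optional :table_rex dependency):
{encoder, decoder} = VAE.build(input_size: 784, latent_size: 32)
encoder |> Axon.Display.as_table(Nx.template({1, 784}, :f32)) |> IO.puts()
decoder |> Axon.Display.as_table(Nx.template({1, 32}, :f32)) |> IO.puts()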
Build the decoder network.
Maps a latent vector back to the input space.
Options
- :input_size or :output_size - Reconstruction output dimension (required)
- :latent_size - Latent dimension (default: 32)
- :decoder_sizes - Hidden layer sizes (default: [128, 256])
- :activation - Activation function (default: :relu)
Returns
An Axon model: [batch, latent_size] -> [batch, output_size].
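For generation, latent vectors can be drawn from the prior p(z) = N(0, I) and decoded (a sketch; the builder name VAE.build_decoder/1 is assumed to mirror build_encoder/1, and in practice dec_params would be trained parameters rather than a fresh init):
decoder = VAE.build_decoder(output_size: 784, latent_size: 32)
{dec_init, dec_fwd} = Axon.build(decoder)
dec_params = dec_init.(Nx.template({1, 32}, :f32), %{})  # use trained parameters in practice
key = Nx.Random.key(7)
{z, _key} = Nx.Random.normal(key, shape: {16, 32})  # 16 draws from the prior
samples = dec_fwd.(dec_params, z)                   # [16, 784] generated outputs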
Build the encoder network.
Maps input to a distribution in latent space, parameterized by
mu (mean) and log_var (log variance). Uses Axon.container/1 to
return both outputs as a map.
Options
- :input_size - Input feature dimension (required)
- :latent_size - Latent dimension (default: 32)
- :encoder_sizes - Hidden layer sizes (default: [256, 128])
- :activation - Activation function (default: :relu)
Returns
An Axon model outputting %{mu: [batch, latent_size], log_var: [batch, latent_size]}.
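As noted in the Usage section, the encoder alone works as an embedding model; mu is typically taken as the deterministic embedding (a sketch with random stand-in inputs and freshly initialized parameters):
encoder = VAE.build_encoder(input_size: 784, latent_size: 32)
{enc_init, enc_fwd} = Axon.build(encoder)
enc_params = enc_init.(Nx.template({1, 784}, :f32), %{})  # use trained parameters in practice
{inputs, _key} = Nx.Random.normal(Nx.Random.key(0), shape: {8, 784})
%{mu: embedding, log_var: _} = enc_fwd.(enc_params, inputs)  # [8, 32] embeddings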
@spec kl_divergence(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
KL divergence between the learned posterior q(z|x) and the prior p(z) = N(0, I).
Computed in closed form:
KL = -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
Parameters
- mu - Mean [batch, latent_size]
- log_var - Log variance [batch, latent_size]
Returns
KL divergence scalar (mean over batch).
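The same closed form written directly with Nx, for reference (illustrative only; the exact reduction inside kl_divergence/2 may differ, though the description above states it averages over the batch):
# Per-example KL, summed over latent dimensions
kl_per_example =
  Nx.multiply(-0.5,
    Nx.sum(
      Nx.subtract(Nx.add(1, log_var), Nx.add(Nx.pow(mu, 2), Nx.exp(log_var))),
      axes: [-1]
    ))
kl = Nx.mean(kl_per_example)  # mean over the batch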
@spec loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Combined VAE loss: reconstruction + beta * KL divergence.
Uses mean squared error for reconstruction by default. The beta parameter controls the regularization strength:
- beta = 1.0: standard VAE (ELBO)
- beta < 1.0: weaker regularization, better reconstructions
- beta > 1.0: beta-VAE, more disentangled but blurrier
Parameters
- reconstruction - Decoder output [batch, input_size]
- target - Original input [batch, input_size]
- mu - Encoder mean [batch, latent_size]
- log_var - Encoder log variance [batch, latent_size]
- beta - KL weight (default: 1.0)
Returns
Combined loss scalar.
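For example, to compare regularization strengths on the same batch (the beta values here are illustrative):
standard = VAE.loss(reconstruction, target, mu, log_var, beta: 1.0)  # plain ELBO
loose    = VAE.loss(reconstruction, target, mu, log_var, beta: 0.5)  # favors reconstruction
strict   = VAE.loss(reconstruction, target, mu, log_var, beta: 4.0)  # beta-VAE regime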
@spec reparameterize(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Reparameterization trick: sample z from q(z|x) = N(mu, sigma^2).
Computes z = mu + eps * exp(0.5 * log_var) where eps ~ N(0, I).
This moves the stochasticity outside the computational graph,
allowing gradients to flow through mu and log_var.
Parameters
- mu - Mean of the approximate posterior [batch, latent_size]
- log_var - Log variance of the approximate posterior [batch, latent_size]
- key - PRNG key from Nx.Random.key/1 (required for proper stochastic sampling)
Returns
{z, new_key} - Sampled latent vector and updated PRNG key.
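Because each call returns an updated key, thread it into the next call rather than reusing the original; reusing the same key reproduces the same noise (a short sketch):
key = Nx.Random.key(1234)
{z1, key} = VAE.reparameterize(mu, log_var, key)  # first sample
{z2, key} = VAE.reparameterize(mu, log_var, key)  # second, independent sample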