# `Edifice.Generative.VAE`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/generative/vae.ex#L1)

Variational Autoencoder (VAE).

Learns a smooth, continuous latent space by encoding inputs into
distributions (parameterized by mu and log_var) rather than point
estimates. The reparameterization trick enables gradient flow through
the stochastic sampling step, making end-to-end training possible.

## Architecture

```
Input [batch, input_size]
      |
      v
+-------------------+
| Encoder           |
| Dense layers      |
+---+----------+----+
    |          |
    v          v
  mu [L]   log_var [L]     L = latent_size
    |          |
    +----+-----+
         |
   reparameterize
   z = mu + eps * exp(0.5 * log_var)
         |
         v
+-------------------+
| Decoder           |
| Dense layers      |
+-------------------+
      |
      v
Reconstruction [batch, input_size]
```

## Loss

The VAE loss combines reconstruction error with a KL divergence
regularizer that pushes the learned posterior toward a standard
normal prior:

    L = reconstruction_loss + beta * KL(q(z|x) || p(z))

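The beta parameter (default 1.0) controls the trade-off. Values
below 1.0 weaken the KL term, giving better reconstructions at the
cost of a less regular latent space; values above 1.0 give a beta-VAE,
which favors a more disentangled latent space at the cost of blurrier
reconstructions.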

## Usage

    alias Edifice.Generative.VAE

    # Build full VAE
    {encoder, decoder} = VAE.build(input_size: 784, latent_size: 32)

    # Build encoder only (for inference / embedding)
    encoder = VAE.build_encoder(input_size: 784, latent_size: 32)

    # Reparameterization and loss (in a training loop). `mu` and `log_var`
    # come from the encoder's output map, `reconstruction` is the decoder's
    # output, and `target` is the original input batch.
    key = Nx.Random.key(System.system_time())
    {z, _key} = VAE.reparameterize(mu, log_var, key)
    kl = VAE.kl_divergence(mu, log_var)
    total = VAE.loss(reconstruction, target, mu, log_var, beta: 1.0)

# `build_opt`

```elixir
@type build_opt() ::
  {:input_size, pos_integer()}
  | {:latent_size, pos_integer()}
  | {:encoder_sizes, [pos_integer()]}
  | {:decoder_sizes, [pos_integer()]}
  | {:activation, atom()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: {Axon.t(), Axon.t()}
```

Build a complete VAE (encoder + decoder).

Returns a tuple `{encoder, decoder}` where the encoder outputs the map
`%{mu: mu, log_var: log_var}` via `Axon.container/1` and the decoder
maps a latent vector back to the input space.

## Options
  - `:input_size` - Input feature dimension (required)
  - `:latent_size` - Latent space dimension (default: 32)
  - `:encoder_sizes` - Hidden layer sizes for encoder (default: [256, 128])
  - `:decoder_sizes` - Hidden layer sizes for decoder (default: [128, 256])
  - `:activation` - Activation function (default: :relu)

## Returns
  `{encoder, decoder}` - Tuple of Axon models.
  - Encoder: input -> `%{mu: [batch, latent], log_var: [batch, latent]}`
  - Decoder: latent -> `[batch, input_size]`

# `build_decoder`

```elixir
@spec build_decoder(keyword()) :: Axon.t()
```

Build the decoder network.

Maps a latent vector back to the input space.

## Options
  - `:input_size` or `:output_size` - Reconstruction output dimension (required)
  - `:latent_size` - Latent dimension (default: 32)
  - `:decoder_sizes` - Hidden layer sizes (default: [128, 256])
  - `:activation` - Activation function (default: :relu)

## Returns
  An Axon model: `[batch, latent_size]` -> `[batch, output_size]`.

# `build_encoder`

```elixir
@spec build_encoder(keyword()) :: Axon.t()
```

Build the encoder network.

Maps input to a distribution in latent space, parameterized by
mu (mean) and log_var (log variance). Uses `Axon.container/1` to
return both outputs as a map.

## Options
  - `:input_size` - Input feature dimension (required)
  - `:latent_size` - Latent dimension (default: 32)
  - `:encoder_sizes` - Hidden layer sizes (default: [256, 128])
  - `:activation` - Activation function (default: :relu)

## Returns
  An Axon model outputting `%{mu: [batch, latent_size], log_var: [batch, latent_size]}`.
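To illustrate the map-shaped output, here is a minimal Axon sketch of an encoder with this structure. It is not Edifice's implementation: the module name `EncoderSketch`, the layer names `"mu"` and `"log_var"`, and the choice of two unactivated dense heads on a shared dense trunk are assumptions drawn from the architecture diagram, and the script assumes Axon 0.7 can be fetched with `Mix.install/1`.

```elixir
# Assumption: Axon 0.7 (and its Nx dependency) can be installed here.
Mix.install([{:axon, "~> 0.7"}])

defmodule EncoderSketch do
  # Hypothetical encoder: a dense trunk followed by two linear heads,
  # combined into one map output with Axon.container/1.
  def build(input_size, latent_size, hidden_sizes \\ [256, 128], activation \\ :relu) do
    input = Axon.input("input", shape: {nil, input_size})

    # Shared trunk: one dense layer per entry in hidden_sizes.
    trunk =
      Enum.reduce(hidden_sizes, input, fn units, acc ->
        Axon.dense(acc, units, activation: activation)
      end)

    # No activation on either head, so log_var can take negative values.
    mu = Axon.dense(trunk, latent_size, name: "mu")
    log_var = Axon.dense(trunk, latent_size, name: "log_var")

    Axon.container(%{mu: mu, log_var: log_var})
  end
end

encoder = EncoderSketch.build(784, 32)

# Output shapes for a single-example template input.
shapes = Axon.get_output_shape(encoder, Nx.template({1, 784}, :f32))
```

Keeping the `log_var` head linear is the key design choice: a log variance must be free to go below zero, which a `:relu` head would prevent.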

# `kl_divergence`

```elixir
@spec kl_divergence(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

KL divergence between the learned posterior q(z|x) and the prior p(z) = N(0, I).

Computed in closed form, summing over the latent dimensions:

    KL = -0.5 * sum(1 + log_var - mu^2 - exp(log_var))

## Parameters
  - `mu` - Mean `[batch, latent_size]`
  - `log_var` - Log variance `[batch, latent_size]`

## Returns
  KL divergence scalar (mean over batch).
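The formula can be checked with a plain-Elixir sketch on nested lists (one row per batch element). `KLSketch` is a hypothetical module for illustration only; it sums over the latent dimensions and then averages over the batch, matching the documented return value, but it is not the library's Nx code.

```elixir
defmodule KLSketch do
  # Hypothetical plain-list version of the closed-form KL against N(0, I).
  # mu and log_var are lists of rows: [batch][latent_size].
  def kl_divergence(mu, log_var) do
    per_example =
      Enum.zip_with(mu, log_var, fn mu_row, lv_row ->
        # Sum over the latent dimensions of one example.
        terms = Enum.zip_with(mu_row, lv_row, fn m, lv -> 1 + lv - m * m - :math.exp(lv) end)
        -0.5 * Enum.sum(terms)
      end)

    # Mean over the batch gives a single scalar.
    Enum.sum(per_example) / length(per_example)
  end
end

# Batch of two examples, one latent dimension each.
kl = KLSketch.kl_divergence([[0.0], [1.0]], [[0.0], [0.0]])
```

With `mu = 0` and `log_var = 0` every term is `1 + 0 - 0 - 1 = 0`, so that example contributes nothing: its posterior already equals the prior. Shifting the mean to `1.0` contributes `0.5`, so the batch mean here is `0.25`.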

# `loss`

```elixir
@spec loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) ::
  Nx.Tensor.t()
```

Combined VAE loss: reconstruction + beta * KL divergence.

Uses mean squared error for reconstruction by default. The beta
parameter controls the regularization strength:
- beta = 1.0: standard VAE (ELBO)
- beta < 1.0: weaker regularization, better reconstructions
- beta > 1.0: beta-VAE, more disentangled but blurrier

## Parameters
  - `reconstruction` - Decoder output `[batch, input_size]`
  - `target` - Original input `[batch, input_size]`
  - `mu` - Encoder mean `[batch, latent_size]`
  - `log_var` - Encoder log variance `[batch, latent_size]`
  - `opts` - Keyword list; `:beta` sets the KL weight (default: 1.0)

## Returns
  Combined loss scalar.
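A plain-list sketch of the combined objective shows that `:beta` scales only the KL term. `VAELossSketch` is hypothetical, and averaging the squared error over every element is an assumption; the exact MSE reduction used by `loss/5` is not stated here.

```elixir
defmodule VAELossSketch do
  # Hypothetical: mean squared error over all elements + beta * KL,
  # where KL is summed over latent dims and averaged over the batch.
  def loss(reconstruction, target, mu, log_var, opts \\ []) do
    beta = Keyword.get(opts, :beta, 1.0)
    mse(reconstruction, target) + beta * kl(mu, log_var)
  end

  defp mse(reconstruction, target) do
    squared =
      Enum.zip_with(reconstruction, target, fn r_row, t_row ->
        Enum.zip_with(r_row, t_row, fn r, t -> (r - t) * (r - t) end)
      end)
      |> List.flatten()

    Enum.sum(squared) / length(squared)
  end

  defp kl(mu, log_var) do
    per_example =
      Enum.zip_with(mu, log_var, fn mu_row, lv_row ->
        terms = Enum.zip_with(mu_row, lv_row, fn m, lv -> 1 + lv - m * m - :math.exp(lv) end)
        -0.5 * Enum.sum(terms)
      end)

    Enum.sum(per_example) / length(per_example)
  end
end

# One example: half the outputs wrong (MSE 0.5), mean shifted by 1 (KL 0.5).
standard = VAELossSketch.loss([[1.0, 0.0]], [[1.0, 1.0]], [[1.0]], [[0.0]], beta: 1.0)
heavy = VAELossSketch.loss([[1.0, 0.0]], [[1.0, 1.0]], [[1.0]], [[0.0]], beta: 4.0)
```

Here `standard` is `1.0` and `heavy` is `2.5`: raising beta from 1.0 to 4.0 quadruples the KL penalty while leaving the reconstruction term unchanged.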

# `reparameterize`

```elixir
@spec reparameterize(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) ::
  {Nx.Tensor.t(), Nx.Tensor.t()}
```

Reparameterization trick: sample z from q(z|x) = N(mu, sigma^2).

Computes `z = mu + eps * exp(0.5 * log_var)` where `eps ~ N(0, I)`.
Because the randomness enters only through the independent noise `eps`,
`z` is a deterministic, differentiable function of `mu` and `log_var`,
so gradients can flow through both.

## Parameters
  - `mu` - Mean of the approximate posterior `[batch, latent_size]`
  - `log_var` - Log variance of the approximate posterior `[batch, latent_size]`
  - `key` - PRNG key from `Nx.Random.key/1` (required for proper stochastic sampling)

## Returns
  `{z, new_key}` - Sampled latent vector and updated PRNG key.
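For intuition, the same trick can be sketched on plain lists. `ReparamSketch` is hypothetical, and an Erlang `:rand` state (from `:rand.seed_s/2`, advanced by `:rand.normal_s/1`) stands in for the Nx PRNG key. Like `reparameterize/3`, it threads the random state through and returns the updated one, so the same input state always yields the same sample.

```elixir
defmodule ReparamSketch do
  # Hypothetical plain-list analogue of reparameterize/3. An Erlang :rand
  # state replaces the Nx PRNG key and is threaded through every draw.
  def reparameterize(mu, log_var, rand_state) do
    Enum.zip(mu, log_var)
    |> Enum.map_reduce(rand_state, fn {mu_row, lv_row}, row_state ->
      Enum.zip(mu_row, lv_row)
      |> Enum.map_reduce(row_state, fn {m, lv}, state ->
        # eps ~ N(0, 1); z = mu + eps * sigma with sigma = exp(0.5 * log_var).
        {eps, state} = :rand.normal_s(state)
        {m + eps * :math.exp(0.5 * lv), state}
      end)
    end)
  end
end

state = :rand.seed_s(:exsss, {1, 2, 3})
{z, new_state} = ReparamSketch.reparameterize([[1.0, -2.0]], [[0.0, 0.0]], state)
```

As `log_var` becomes very negative, `exp(0.5 * log_var)` shrinks toward zero and `z` collapses to `mu`, the deterministic limit of the sampler.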

---

*Consult [api-reference.md](api-reference.md) for complete listing*
