Latent Diffusion: Diffusion in VAE latent space.
Implements the Latent Diffusion Model (LDM) concept from "High-Resolution Image Synthesis with Latent Diffusion Models" (Rombach et al., CVPR 2022). Instead of diffusing in the full input space, LDM first compresses data with a VAE encoder, runs diffusion in the compact latent space, then decodes back.
Key Innovation: Perceptual Compression + Diffusion
Training:
1. Train VAE: input -> encoder -> z -> decoder -> reconstruction
2. Freeze VAE
3. Train diffusion in latent space:
z_0 = encoder(input)
z_t = add_noise(z_0, t)
eps_hat = denoiser(z_t, t)
loss = MSE(eps, eps_hat)
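The forward-noising step above can be sketched per element in plain Elixir (the module itself operates on Nx tensors; the alpha_bar value here is illustrative):

```elixir
# Forward diffusion: z_t = sqrt(alpha_bar_t) * z_0 + sqrt(1 - alpha_bar_t) * eps
# Plain-Elixir sketch over a list of latent values, not the module's Nx version.
defmodule ForwardNoise do
  def add_noise(z0, eps, alpha_bar) when length(z0) == length(eps) do
    a = :math.sqrt(alpha_bar)
    b = :math.sqrt(1.0 - alpha_bar)

    # Blend signal and noise according to the schedule value alpha_bar.
    Enum.zip_with(z0, eps, fn z, e -> a * z + b * e end)
  end
end

z0 = [0.5, -1.0, 0.25]
eps = [0.1, 0.2, -0.3]
zt = ForwardNoise.add_noise(z0, eps, 0.9)
```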
Inference:
1. Sample z_T ~ N(0, I) in latent space
2. Denoise: z_0 = diffusion_sample(z_T)
3. Decode: output = decoder(z_0)
Advantages
| Feature | Full-space Diffusion | Latent Diffusion |
|---|---|---|
| Compute | O(input_dim) per step | O(latent_dim) per step |
| Quality | Good | Good (perceptual compression) |
| Speed | Slow | Fast (smaller dim) |
| Memory | High | Low |
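One reverse (denoising) step of the inference loop can be sketched per element in plain Elixir; the schedule constants below are illustrative, and a real sampler would call the trained denoiser on Nx tensors:

```elixir
# One DDPM-style reverse step in latent space (deterministic part only):
# z_{t-1} = (z_t - beta_t / sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_t)
defmodule ReverseStep do
  def denoise(zt, eps_hat, alpha_t, alpha_bar_t) do
    beta_t = 1.0 - alpha_t
    coeff = beta_t / :math.sqrt(1.0 - alpha_bar_t)

    # Remove the predicted noise, then rescale by 1 / sqrt(alpha_t).
    Enum.zip_with(zt, eps_hat, fn z, e ->
      (z - coeff * e) / :math.sqrt(alpha_t)
    end)
  end
end
```

Repeating this step from t = T down to t = 1 yields z_0, which is then passed to the decoder.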
Architecture
Returns a tuple of three models: {encoder, decoder, denoiser}
Input [batch, input_size]
|
v
+-------------------+
| Encoder (frozen) | -> z_0 [batch, latent_size]
+-------------------+
|
v (add noise)
+-------------------+
| Denoiser | -> eps_hat [batch, latent_size]
| (z_t, t) -> eps |
+-------------------+
|
v (denoise)
+-------------------+
| Decoder (frozen) | -> output [batch, input_size]
+-------------------+
Usage
{encoder, decoder, denoiser} = LatentDiffusion.build(
input_size: 287,
latent_size: 32,
hidden_size: 256,
num_layers: 4
)
# Train VAE first, then freeze and train denoiser
Reference
- Paper: "High-Resolution Image Synthesis with Latent Diffusion Models"
- arXiv: https://arxiv.org/abs/2112.10752
Summary
Functions
Build a Latent Diffusion Model.
Build the VAE decoder.
Build the latent-space denoiser.
Build the VAE encoder.
KL divergence for VAE training.
Create diffusion noise schedule.
Get the output size (latent dimension for the denoiser).
Calculate approximate parameter count for the full system.
Get recommended defaults.
Reparameterization trick for the encoder.
Types
@type build_opt() :: {:hidden_size, pos_integer()} | {:input_size, pos_integer()} | {:latent_size, pos_integer()} | {:num_layers, pos_integer()} | {:num_steps, pos_integer()}
Options for build/1.
Functions
Build a Latent Diffusion Model.
Returns {encoder, decoder, denoiser} where:
- Encoder: maps input to latent distribution (mu, log_var)
- Decoder: maps latent vector to reconstructed input
- Denoiser: predicts noise in latent space given (noisy_z, timestep)
Options
- :input_size - Input feature dimension (required)
- :latent_size - Latent space dimension (default: 32)
- :hidden_size - Hidden dimension for all sub-networks (default: 256)
- :num_layers - Number of layers in denoiser (default: 4)
- :num_steps - Number of diffusion timesteps (default: 1000)
Returns
{encoder, decoder, denoiser} - Tuple of Axon models.
@spec build_decoder(pos_integer(), pos_integer(), pos_integer()) :: Axon.t()
Build the VAE decoder.
Maps a latent vector back to input space.
@spec build_denoiser(pos_integer(), pos_integer(), pos_integer(), pos_integer()) :: Axon.t()
Build the latent-space denoiser.
Predicts noise from (noisy_z, timestep).
@spec build_encoder(pos_integer(), pos_integer(), pos_integer()) :: Axon.t()
Build the VAE encoder.
Maps input to latent distribution parameters (mu, log_var).
@spec kl_divergence(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
KL divergence for VAE training.
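The closed-form KL term against a standard normal prior can be sketched in plain Elixir (the module's version takes Nx tensors for mu and log_var):

```elixir
# KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
# Plain-Elixir sketch over lists of per-dimension mu and log_var values.
kl = fn mus, log_vars ->
  Enum.zip_with(mus, log_vars, fn mu, lv ->
    -0.5 * (1.0 + lv - mu * mu - :math.exp(lv))
  end)
  |> Enum.sum()
end
```

At mu = 0 and log_var = 0 (i.e. exactly the prior) the term is zero, which is a quick sanity check.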
Create diffusion noise schedule.
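A common choice is a linear beta schedule; the sketch below (plain Elixir, with the usual DDPM default endpoints as assumptions) shows the idea, though the exact schedule shape this module uses is not specified here:

```elixir
# Linear beta schedule: interpolate beta_start..beta_end over num_steps values.
# Endpoints 1.0e-4 and 0.02 are the common DDPM defaults, assumed here.
defmodule Schedule do
  def linear_betas(num_steps, beta_start \\ 1.0e-4, beta_end \\ 0.02) do
    step = (beta_end - beta_start) / (num_steps - 1)
    Enum.map(0..(num_steps - 1), fn t -> beta_start + t * step end)
  end
end
```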
@spec output_size(keyword()) :: non_neg_integer()
Get the output size (latent dimension for the denoiser).
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for the full system.
@spec recommended_defaults() :: keyword()
Get recommended defaults.
@spec reparameterize(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Reparameterization trick for the encoder.
Requires a PRNG key for sampling. Returns {z, new_key}.
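The trick itself can be sketched in plain Elixir using Erlang's :rand for the Gaussian sample (the module's version instead threads an Nx PRNG key and returns {z, new_key}):

```elixir
# Reparameterization: z = mu + exp(0.5 * log_var) * eps, with eps ~ N(0, I).
# Sampling via :rand.normal/0 here; the module uses a keyed Nx generator.
reparameterize = fn mus, log_vars ->
  Enum.zip_with(mus, log_vars, fn mu, lv ->
    eps = :rand.normal()
    mu + :math.exp(0.5 * lv) * eps
  end)
end
```

Writing the sample as a deterministic function of (mu, log_var, eps) is what lets gradients flow through mu and log_var during VAE training.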