Edifice.Generative.LatentDiffusion (Edifice v0.2.0)


Latent Diffusion: Diffusion in VAE latent space.

Implements the Latent Diffusion Model (LDM) concept from "High-Resolution Image Synthesis with Latent Diffusion Models" (Rombach et al., CVPR 2022). Instead of diffusing in the full input space, LDM first compresses data with a VAE encoder, runs diffusion in the compact latent space, then decodes back.

Key Innovation: Perceptual Compression + Diffusion

Training:
  1. Train VAE: input -> encoder -> z -> decoder -> reconstruction
  2. Freeze VAE
  3. Train diffusion in latent space:
     z_0 = encoder(input)
     z_t = add_noise(z_0, t)
     eps_hat = denoiser(z_t, t)
     loss = MSE(eps, eps_hat)
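The noising and loss in step 3 can be sketched with Nx. Here `predict_fn` is assumed to come from `Axon.build/2` on the denoiser; the input names `"z"`/`"t"` and the schedule field names are illustrative assumptions, not this module's documented API:

```elixir
# One diffusion training step on frozen-VAE latents (illustrative sketch).
key = Nx.Random.key(0)
{eps, key} = Nx.Random.normal(key, shape: Nx.shape(z0))

# Forward-diffuse: z_t = sqrt(a_bar_t) * z_0 + sqrt(1 - a_bar_t) * eps
z_t =
  Nx.add(
    Nx.multiply(schedule.sqrt_alpha_bars[t], z0),
    Nx.multiply(schedule.sqrt_one_minus_alpha_bars[t], eps)
  )

# Predict the noise and score it with MSE against the true noise.
eps_hat = predict_fn.(denoiser_params, %{"z" => z_t, "t" => Nx.tensor([t])})
loss = Nx.mean(Nx.pow(Nx.subtract(eps, eps_hat), 2))
```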

Inference:
  1. Sample z_T ~ N(0, I) in latent space
  2. Denoise: z_0 = diffusion_sample(z_T)
  3. Decode: output = decoder(z_0)
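The sampling loop above, sketched in the same style. The DDPM-style update `denoise_step/4` is a hypothetical helper, and the `decoder_fn`/`predict_fn` pairs are assumed to come from `Axon.build/2`:

```elixir
# Ancestral sampling in latent space (illustrative sketch).
key = Nx.Random.key(1)
{z_init, _key} = Nx.Random.normal(key, shape: {batch_size, latent_size})

z0 =
  Enum.reduce((num_steps - 1)..0//-1, z_init, fn t, z ->
    eps_hat = predict_fn.(denoiser_params, %{"z" => z, "t" => Nx.tensor([t])})
    denoise_step(z, eps_hat, t, schedule)  # hypothetical DDPM update rule
  end)

output = decoder_fn.(decoder_params, z0)
```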

Advantages

Feature    Full-space Diffusion       Latent Diffusion
-------    --------------------       ----------------
Compute    O(input_dim) per step      O(latent_dim) per step
Quality    Good                       Good (perceptual compression)
Speed      Slow                       Fast (smaller dimension)
Memory     High                       Low

Architecture

Returns a tuple of three models: {encoder, decoder, denoiser}

Input [batch, input_size]
      |
      v
+-------------------+
| Encoder (frozen)  |  -> z_0 [batch, latent_size]
+-------------------+
      |
      v (add noise)
+-------------------+
| Denoiser          |  -> eps_hat [batch, latent_size]
| (z_t, t) -> eps   |
+-------------------+
      |
      v (denoise)
+-------------------+
| Decoder (frozen)  |  -> output [batch, input_size]
+-------------------+

Usage

{encoder, decoder, denoiser} = LatentDiffusion.build(
  input_size: 287,
  latent_size: 32,
  hidden_size: 256,
  num_layers: 4
)

# Train VAE first, then freeze and train denoiser

Reference

Rombach, R., Blattmann, A., Lorenz, D., Esser, P., & Ommer, B. "High-Resolution Image Synthesis with Latent Diffusion Models." CVPR 2022.

Summary

Types

Options for build/1.

Functions

Build a Latent Diffusion Model.

KL divergence for VAE training.

Create diffusion noise schedule.

Get the output size (latent dimension for the denoiser).

Calculate approximate parameter count for the full system.

Get recommended defaults.

Reparameterization trick for the encoder.

Types

build_opt()

@type build_opt() ::
  {:hidden_size, pos_integer()}
  | {:input_size, pos_integer()}
  | {:latent_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_steps, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: {Axon.t(), Axon.t(), Axon.t()}

Build a Latent Diffusion Model.

Returns {encoder, decoder, denoiser} where:

  • Encoder: maps input to latent distribution (mu, log_var)
  • Decoder: maps latent vector to reconstructed input
  • Denoiser: predicts noise in latent space given (noisy_z, timestep)

Options

  • :input_size - Input feature dimension (required)
  • :latent_size - Latent space dimension (default: 32)
  • :hidden_size - Hidden dimension for all sub-networks (default: 256)
  • :num_layers - Number of layers in denoiser (default: 4)
  • :num_steps - Number of diffusion timesteps (default: 1000)

Returns

{encoder, decoder, denoiser} - Tuple of Axon models.
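A minimal sketch of building the three graphs and materializing one of them with `Axon.build/2`. The single `{1, 287}` input template for the encoder is an assumption about its input shape:

```elixir
{encoder, decoder, denoiser} =
  Edifice.Generative.LatentDiffusion.build(input_size: 287, latent_size: 32)

# Axon.build/2 returns {init_fn, predict_fn} for a graph.
{enc_init, enc_pred} = Axon.build(encoder)
enc_params = enc_init.(Nx.template({1, 287}, :f32), %{})
```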

build_decoder(input_size, latent_size, hidden_size)

@spec build_decoder(pos_integer(), pos_integer(), pos_integer()) :: Axon.t()

Build the VAE decoder.

Maps a latent vector back to input space.

build_denoiser(latent_size, hidden_size, num_layers, num_steps)

@spec build_denoiser(pos_integer(), pos_integer(), pos_integer(), pos_integer()) ::
  Axon.t()

Build the latent-space denoiser.

Predicts noise from (noisy_z, timestep).

build_encoder(input_size, latent_size, hidden_size)

@spec build_encoder(pos_integer(), pos_integer(), pos_integer()) :: Axon.t()

Build the VAE encoder.

Maps input to latent distribution parameters (mu, log_var).

kl_divergence(mu, log_var)

@spec kl_divergence(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

KL divergence for VAE training.
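For a diagonal Gaussian posterior against N(0, I), the KL term has the standard closed form. A sketch of the computation in Nx (the module's own reduction axes may differ):

```elixir
# D_KL(N(mu, exp(log_var)) || N(0, I))
#   = -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
kl =
  Nx.multiply(
    -0.5,
    Nx.sum(
      Nx.add(1.0, log_var)
      |> Nx.subtract(Nx.pow(mu, 2))
      |> Nx.subtract(Nx.exp(log_var))
    )
  )
```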

make_schedule(opts \\ [])

@spec make_schedule(keyword()) :: map()

Create diffusion noise schedule.
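The returned map's contents are not documented here; a common choice is a linear beta schedule with cumulative products, sketched below (all key names are assumptions, not necessarily what `make_schedule/1` returns):

```elixir
# Linear beta schedule (illustrative sketch).
num_steps = 1000
betas = Nx.linspace(1.0e-4, 0.02, n: num_steps)
alphas = Nx.subtract(1.0, betas)
alpha_bars = Nx.cumulative_product(alphas)

schedule = %{
  betas: betas,
  sqrt_alpha_bars: Nx.sqrt(alpha_bars),
  sqrt_one_minus_alpha_bars: Nx.sqrt(Nx.subtract(1.0, alpha_bars))
}
```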

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size (latent dimension for the denoiser).

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for the full system.

reparameterize(mu, log_var, key)

@spec reparameterize(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) ::
  {Nx.Tensor.t(), Nx.Tensor.t()}

Reparameterization trick for the encoder.

Requires a PRNG key for sampling. Returns {z, new_key}.
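The trick itself is z = mu + exp(0.5 * log_var) * eps with eps ~ N(0, I). A usage sketch showing the equivalent Nx computation and the key threading described above:

```elixir
key = Nx.Random.key(42)

# Equivalent computation to what reparameterize/3 performs (sketch):
{eps, new_key} = Nx.Random.normal(key, shape: Nx.shape(mu))
z = Nx.add(mu, Nx.multiply(Nx.exp(Nx.multiply(0.5, log_var)), eps))

# Or via the module function, which returns the threaded key:
{z, new_key} = Edifice.Generative.LatentDiffusion.reparameterize(mu, log_var, key)
```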