# `Edifice.Generative.LatentDiffusion`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/generative/latent_diffusion.ex#L1)

Latent Diffusion: Diffusion in VAE latent space.

Implements the Latent Diffusion Model (LDM) concept from "High-Resolution
Image Synthesis with Latent Diffusion Models" (Rombach et al., CVPR 2022).
Instead of diffusing in the full input space, LDM first compresses data
with a VAE encoder, runs diffusion in the compact latent space, then
decodes back.

## Key Innovation: Perceptual Compression + Diffusion

```
Training:
  1. Train VAE: input -> encoder -> z -> decoder -> reconstruction
  2. Freeze VAE
  3. Train diffusion in latent space:
     z_0 = encoder(input)
     eps ~ N(0, I)
     z_t = add_noise(z_0, eps, t)
     eps_hat = denoiser(z_t, t)
     loss = MSE(eps, eps_hat)

Inference:
  1. Sample z_T ~ N(0, I) in latent space
  2. Denoise: z_0 = diffusion_sample(z_T)
  3. Decode: output = decoder(z_0)
```
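
Concretely, the `add_noise` step in training follows the standard DDPM closed form for `q(z_t | z_0)`. A minimal Nx sketch, assuming `alpha_bars` is the cumulative product of `1 - beta` from `make_schedule/1`, `t` is a `[batch]` tensor of integer timesteps, and `noise` is standard Gaussian:

```elixir
# Forward noising: z_t = sqrt(a_bar_t) * z_0 + sqrt(1 - a_bar_t) * noise.
# `alpha_bars` is assumed to be a {num_steps} tensor from make_schedule/1.
def add_noise(z_0, noise, t, alpha_bars) do
  a_bar = alpha_bars |> Nx.take(t) |> Nx.reshape({:auto, 1})

  Nx.add(
    Nx.multiply(Nx.sqrt(a_bar), z_0),
    Nx.multiply(Nx.sqrt(Nx.subtract(1.0, a_bar)), noise)
  )
end
```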

## Advantages

| Feature | Full-space Diffusion | Latent Diffusion |
|---------|---------------------|-----------------|
| Compute | O(input_dim) per step | O(latent_dim) per step |
| Quality | Good | Good (perceptual compression) |
| Speed | Slow | Fast (smaller dim) |
| Memory | High | Low |

## Architecture

`build/1` returns a tuple of three models: `{encoder, decoder, denoiser}`.

```
Input [batch, input_size]
      |
      v
+-------------------+
| Encoder (frozen)  |  -> z_0 [batch, latent_size]
+-------------------+
      |
      v (add noise)
+-------------------+
| Denoiser          |  -> eps_hat [batch, latent_size]
| (z_t, t) -> eps   |
+-------------------+
      |
      v (denoise)
+-------------------+
| Decoder (frozen)  |  -> output [batch, input_size]
+-------------------+
```

## Usage

    {encoder, decoder, denoiser} = LatentDiffusion.build(
      input_size: 287,
      latent_size: 32,
      hidden_size: 256,
      num_layers: 4
    )

    # Train VAE first, then freeze and train denoiser
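
Once both stages are trained, inference samples in latent space and decodes the result. A hedged sketch: `sample_latent/3` below stands in for a reverse-diffusion loop over the denoiser and is hypothetical, not part of this module; only `make_schedule/1` and `Axon.predict/3` are real calls.

    schedule = LatentDiffusion.make_schedule(num_steps: 1000)

    # Hypothetical helper: run the reverse-diffusion loop in latent space
    z_0 = sample_latent(denoiser, denoiser_params, schedule)

    # Decode the denoised latent back to input space
    output = Axon.predict(decoder, decoder_params, z_0)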

## Reference

- Paper: "High-Resolution Image Synthesis with Latent Diffusion Models"
- arXiv: https://arxiv.org/abs/2112.10752

# `build_opt`

```elixir
@type build_opt() ::
  {:hidden_size, pos_integer()}
  | {:input_size, pos_integer()}
  | {:latent_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_steps, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: {Axon.t(), Axon.t(), Axon.t()}
```

Build a Latent Diffusion Model.

Returns `{encoder, decoder, denoiser}` where:
- Encoder: maps input to latent distribution (mu, log_var)
- Decoder: maps latent vector to reconstructed input
- Denoiser: predicts noise in latent space given (noisy_z, timestep)

## Options

  - `:input_size` - Input feature dimension (required)
  - `:latent_size` - Latent space dimension (default: 32)
  - `:hidden_size` - Hidden dimension for all sub-networks (default: 256)
  - `:num_layers` - Number of layers in denoiser (default: 4)
  - `:num_steps` - Number of diffusion timesteps (default: 1000)

## Returns

  `{encoder, decoder, denoiser}` - Tuple of Axon models.
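
Each element of the tuple is a regular Axon graph and can be initialized independently. A hedged sketch; the `{1, 287}` template is an assumption tied to `input_size: 287`:

```elixir
{encoder, _decoder, _denoiser} =
  LatentDiffusion.build(input_size: 287, latent_size: 32)

# Build and initialize the encoder on its own; template shape assumes input_size: 287.
{init_fn, _predict_fn} = Axon.build(encoder)
encoder_params = init_fn.(Nx.template({1, 287}, :f32), %{})
```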

# `build_decoder`

```elixir
@spec build_decoder(pos_integer(), pos_integer(), pos_integer()) :: Axon.t()
```

Build the VAE decoder.

Maps a latent vector back to input space.

# `build_denoiser`

```elixir
@spec build_denoiser(pos_integer(), pos_integer(), pos_integer(), pos_integer()) ::
  Axon.t()
```

Build the latent-space denoiser.

Predicts noise from (noisy_z, timestep).
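
A hedged illustration of calling the built denoiser with `Axon.predict/3`; the input names `"noisy_z"` and `"timestep"` are assumptions about how the graph declares its inputs, not confirmed by the source:

```elixir
# Input names are hypothetical; use whatever names the built graph actually declares.
eps_hat =
  Axon.predict(denoiser, denoiser_params, %{
    "noisy_z" => z_t,   # [batch, latent_size] noisy latents
    "timestep" => t     # [batch] integer timesteps
  })
```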

# `build_encoder`

```elixir
@spec build_encoder(pos_integer(), pos_integer(), pos_integer()) :: Axon.t()
```

Build the VAE encoder.

Maps input to latent distribution parameters (mu, log_var).

# `kl_divergence`

```elixir
@spec kl_divergence(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

KL divergence for VAE training.
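
For a diagonal Gaussian posterior against a standard-normal prior this is the usual closed-form VAE term. A minimal sketch; summing over features and averaging over the batch is an assumption, the module may reduce differently:

```elixir
# KL( N(mu, sigma^2) || N(0, I) ) = -1/2 * sum(1 + log_var - mu^2 - exp(log_var))
def kl_divergence(mu, log_var) do
  Nx.add(1.0, log_var)
  |> Nx.subtract(Nx.multiply(mu, mu))
  |> Nx.subtract(Nx.exp(log_var))
  |> Nx.sum(axes: [-1])
  |> Nx.multiply(-0.5)
  |> Nx.mean()
end
```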

# `make_schedule`

```elixir
@spec make_schedule(keyword()) :: map()
```

Create diffusion noise schedule.
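
A hedged sketch of what such a schedule typically contains, assuming a linear beta schedule; the beta range and the keys of the returned map are assumptions, not the module's actual values:

```elixir
def make_schedule(opts \\ []) do
  num_steps = Keyword.get(opts, :num_steps, 1000)

  # Linear beta schedule; 1.0e-4..2.0e-2 is the common DDPM range, assumed here.
  betas = Nx.linspace(1.0e-4, 2.0e-2, n: num_steps)
  alphas = Nx.subtract(1.0, betas)
  alpha_bars = Nx.cumulative_product(alphas)

  %{betas: betas, alphas: alphas, alpha_bars: alpha_bars, num_steps: num_steps}
end
```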

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size (latent dimension for the denoiser).

# `param_count`

```elixir
@spec param_count(keyword()) :: non_neg_integer()
```

Calculate approximate parameter count for the full system.

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Get recommended defaults.

# `reparameterize`

```elixir
@spec reparameterize(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) ::
  {Nx.Tensor.t(), Nx.Tensor.t()}
```

Reparameterization trick for the encoder.

Requires a PRNG `key` for sampling. Returns `{z, new_key}`.
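
A minimal sketch of the standard trick using `Nx.Random`, assuming `key` is an `Nx.Random` PRNG key:

```elixir
# z = mu + exp(0.5 * log_var) * eps, with eps ~ N(0, I) drawn from the PRNG key.
def reparameterize(mu, log_var, key) do
  std = Nx.exp(Nx.multiply(0.5, log_var))
  {eps, new_key} = Nx.Random.normal(key, shape: Nx.shape(mu))
  {Nx.add(mu, Nx.multiply(std, eps)), new_key}
end
```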

---

*Consult [api-reference.md](api-reference.md) for the complete listing.*
