Edifice.Generative.DDIM (Edifice v0.2.0)


DDIM: Denoising Diffusion Implicit Models.

Implements DDIM from "Denoising Diffusion Implicit Models" (Song et al., ICLR 2021). Uses the same training objective as DDPM but enables deterministic sampling with far fewer steps via a non-Markovian reverse process.

Key Innovation: Deterministic Sampling

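DDPM samples through a Markov chain that injects fresh Gaussian noise at every step, so it typically needs on the order of 1000 reverse steps. DDIM reformulates the reverse process so that it can be deterministic and can skip timesteps: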

DDPM (stochastic, ~1000 steps):
  x_{t-1} = mu(x_t, eps_theta) + sigma_t * z,  z ~ N(0, I)

DDIM (deterministic, ~50 steps):
  x_{t-1} = sqrt(alpha_{t-1}) * pred_x0 + pred_dir
  where:
    pred_x0 = (x_t - sqrt(1-alpha_t) * eps_theta) / sqrt(alpha_t)
    pred_dir = sqrt(1-alpha_{t-1}) * eps_theta

The eta parameter interpolates between deterministic (eta=0) and stochastic DDPM (eta=1).
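
For reference, the role of eta can be written out using the general update from Song et al. This restates the paper rather than the library source; it assumes that alpha_t denotes the cumulative noise-schedule product (written as alpha-bar in the DDPM paper) and that Edifice follows the same parameterization.

```latex
x_{t-1} = \sqrt{\alpha_{t-1}}\,\hat{x}_0
        + \sqrt{1 - \alpha_{t-1} - \sigma_t^2}\;\epsilon_\theta(x_t, t)
        + \sigma_t z, \qquad z \sim \mathcal{N}(0, I)

\hat{x}_0 = \frac{x_t - \sqrt{1 - \alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\alpha_t}}

\sigma_t = \eta\,\sqrt{\frac{1 - \alpha_{t-1}}{1 - \alpha_t}}\,
           \sqrt{1 - \frac{\alpha_t}{\alpha_{t-1}}}
```

Setting eta = 0 makes sigma_t vanish, which removes the noise term and leaves a deterministic update; eta = 1 gives the DDPM posterior standard deviation.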

Architecture

Uses the same denoising network as DDPM (noise predictor conditioned on timestep and observations). The difference is in sampling, not training.

Same training as DDPM:
  1. Sample x_0, noise, timestep
  2. Compute noisy x_t
  3. Predict noise: eps_theta(x_t, t, obs)
  4. Loss: MSE(eps, eps_theta)
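
The training steps above can be sketched in plain Elixir on lists of floats. This is a minimal illustration, not Edifice code: the module `TrainingStepSketch` and its functions are hypothetical, and `alpha_t` is assumed to be the cumulative schedule product at the sampled timestep.

```elixir
defmodule TrainingStepSketch do
  # Hypothetical helper module for illustration; not part of Edifice.

  # Forward noising: x_t = sqrt(alpha_t) * x_0 + sqrt(1 - alpha_t) * eps,
  # where alpha_t is the cumulative schedule product at timestep t.
  def add_noise(x0, eps, alpha_t) do
    Enum.zip_with(x0, eps, fn x, e ->
      :math.sqrt(alpha_t) * x + :math.sqrt(1.0 - alpha_t) * e
    end)
  end

  # Mean squared error between the true noise and the predicted noise.
  def mse(eps, eps_pred) do
    squared = Enum.zip_with(eps, eps_pred, fn e, p -> (e - p) * (e - p) end)
    Enum.sum(squared) / length(squared)
  end
end

x0 = [1.0, -0.5]
eps = [0.3, 0.1]
x_t = TrainingStepSketch.add_noise(x0, eps, 0.64)

# A perfect predictor would return eps itself, giving zero loss.
loss = TrainingStepSketch.mse(eps, eps)
```

In the real model the prediction comes from the denoising network evaluated on (x_t, t, obs); the sketch only shows the data flow around it.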

DDIM sampling (fewer steps):
  1. Choose stride S (e.g., S = 20 visits every 20th timestep)
  2. For t in [1000, 980, 960, ..., 40, 20]:
     x_{t-S} = ddim_step(x_t, eps_theta, t, t - S, schedule, eta)
  3. Return x_0
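
As a small sketch of the strided loop above, the visited timesteps and their targets can be derived from `num_steps` and `ddim_steps`. The variable names are illustrative, and the indexing (starting at 1000 and ending at 0) simply mirrors the listing; how Edifice indexes timesteps internally is not shown here and may differ.

```elixir
num_steps = 1000
ddim_steps = 50
stride = div(num_steps, ddim_steps)

# Timesteps visited, from most noisy to least noisy: [1000, 980, ..., 40, 20].
timesteps = Enum.to_list(Range.new(num_steps, stride, -stride))

# Pair each timestep with the one it jumps to: {1000, 980}, ..., {20, 0}.
pairs = Enum.map(timesteps, fn t -> {t, t - stride} end)
```

Each pair corresponds to one call of the reverse step, so 50 pairs replace the 1000 steps of the full chain.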

Usage

# Build denoising network (same as DDPM)
model = DDIM.build(
  obs_size: 287,
  action_dim: 64,
  action_horizon: 8,
  num_steps: 1000
)

# DDIM sampling with fewer steps
schedule = DDIM.make_schedule(num_steps: 1000)
actions = DDIM.ddim_sample(params, predict_fn, obs, noise,
  schedule: schedule,
  ddim_steps: 50,
  eta: 0.0
)

Reference

Summary

Types

Options for build/1.

Functions

Build a DDIM denoising network.

DDIM sampling: deterministic reverse process with stride.

Precompute diffusion schedule (same as DDPM).

Get the output size of a DDIM model.

Calculate approximate parameter count.

Get recommended defaults for fast DDIM sampling.

Types

build_opt()

@type build_opt() ::
  {:action_dim, pos_integer()}
  | {:action_horizon, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_steps, pos_integer()}
  | {:obs_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a DDIM denoising network.

The network architecture is identical to DDPM; only the sampling procedure differs.

Options

  • :obs_size - Size of observation embedding (required)
  • :action_dim - Dimension of action space (required)
  • :action_horizon - Number of actions to predict (default: 8)
  • :hidden_size - Hidden dimension (default: 256)
  • :num_layers - Number of denoiser layers (default: 4)
  • :num_steps - Number of diffusion timesteps (default: 1000)

Returns

An Axon model that predicts noise given (noisy_actions, timestep, obs).

ddim_sample(params, predict_fn, observations, initial_noise, opts \\ [])

@spec ddim_sample(
  map(),
  (map(), map() -> Nx.Tensor.t()),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  keyword()
) :: Nx.Tensor.t()

DDIM sampling: reverse process over a strided subsequence of timesteps, deterministic when eta is 0.0.

Parameters

  • params - Model parameters
  • predict_fn - Noise prediction function; per the spec it takes the parameters and an input map and returns the predicted noise tensor
  • observations - Conditioning [batch, obs_size]
  • initial_noise - Starting noise [batch, action_horizon, action_dim]
  • opts:
    • :schedule - Precomputed schedule from make_schedule/1
    • :ddim_steps - Number of DDIM steps (default: 50)
    • :eta - Stochasticity: 0.0 = deterministic, 1.0 = DDPM (default: 0.0)

Returns

Denoised actions [batch, action_horizon, action_dim].

ddim_step(x_t, predicted_noise, t, t_prev, schedule, eta)

@spec ddim_step(Nx.Tensor.t(), Nx.Tensor.t(), integer(), integer(), map(), float()) ::
  Nx.Tensor.t()

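Single DDIM reverse step from timestep t to t_prev.

pred_x0 = (x_t - sqrt(1-alpha_t) * eps) / sqrt(alpha_t)
direction = sqrt(1-alpha_{t-1} - sigma^2) * eps
x_{t-1} = sqrt(alpha_{t-1}) * pred_x0 + direction + sigma * noise

Here eps is predicted_noise, the index t-1 stands for t_prev, alpha is the cumulative noise-schedule product taken from schedule, sigma is the noise scale controlled by eta (sigma = 0 when eta = 0), and noise ~ N(0, I).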
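
A scalar sketch of this step in plain Elixir may help when checking the arithmetic. It is not Edifice's implementation: the module `DDIMStepSketch` is hypothetical, it takes the two cumulative alpha values directly instead of a schedule map, and the sigma expression is the one from Song et al., which is assumed to match what the library uses.

```elixir
defmodule DDIMStepSketch do
  # Hypothetical scalar version of one DDIM reverse step; not Edifice code.
  # alpha_t and alpha_prev are cumulative schedule products at t and t_prev.
  def step(x_t, eps, alpha_t, alpha_prev, eta, noise \\ 0.0) do
    # Noise scale from Song et al.; it is 0 when eta == 0.
    sigma =
      eta * :math.sqrt((1.0 - alpha_prev) / (1.0 - alpha_t)) *
        :math.sqrt(1.0 - alpha_t / alpha_prev)

    pred_x0 = (x_t - :math.sqrt(1.0 - alpha_t) * eps) / :math.sqrt(alpha_t)
    direction = :math.sqrt(1.0 - alpha_prev - sigma * sigma) * eps

    :math.sqrt(alpha_prev) * pred_x0 + direction + sigma * noise
  end
end

# With a perfect noise estimate and eta = 0, the step lands on the
# less-noisy point sqrt(alpha_prev) * x0 + sqrt(1 - alpha_prev) * eps.
x0 = 2.0
eps = 0.5
x_t = :math.sqrt(0.5) * x0 + :math.sqrt(0.5) * eps
x_prev = DDIMStepSketch.step(x_t, eps, 0.5, 0.8, 0.0)
```

The sketch does not guard against alpha_t equal to 1.0, where the sigma expression divides by zero; a tensor implementation would need to handle that boundary.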

make_schedule(opts \\ [])

@spec make_schedule(keyword()) :: map()

Precompute diffusion schedule (same as DDPM).
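
To make the idea of a precomputed schedule concrete, here is a minimal sketch in plain Elixir. It assumes a linear beta schedule with the endpoints commonly used for DDPM (1.0e-4 to 0.02); the schedule type, the module `ScheduleSketch`, and the map keys are all assumptions for illustration, and the map returned by make_schedule/1 may be organized differently.

```elixir
defmodule ScheduleSketch do
  # Hypothetical helper; not Edifice's make_schedule/1.
  # Assumes a linear beta schedule from beta_start to beta_end, num_steps >= 1.
  def linear(num_steps, beta_start \\ 1.0e-4, beta_end \\ 0.02) do
    betas =
      for i <- 0..(num_steps - 1) do
        beta_start + (beta_end - beta_start) * i / max(num_steps - 1, 1)
      end

    # Cumulative product of (1 - beta): the alpha_t used in the DDIM formulas.
    alphas_cumprod =
      Enum.scan(betas, 1.0, fn beta, acc -> acc * (1.0 - beta) end)

    %{betas: betas, alphas_cumprod: alphas_cumprod}
  end
end

schedule = ScheduleSketch.linear(1000)
```

The cumulative products decrease monotonically from just below 1.0 toward 0.0, which is why early timesteps are nearly clean and late timesteps are nearly pure noise.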

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a DDIM model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count.