Edifice.Generative.ConsistencyModel (Edifice v0.2.0)

Consistency Model: Single-step generation via consistency function.

Implements the Consistency Model from "Consistency Models" (Song et al., ICML 2023). Learns a function f(x_t, t) that maps any point on a probability flow ODE trajectory directly to its origin, enabling single-step generation without iterative denoising.

Key Innovation: Self-Consistency Property

A consistency function f satisfies: for all t, t' on the same trajectory, f(x_t, t) = f(x_{t'}, t'). In particular, f(x_t, t) = x_0 for all t.

Diffusion (many steps):
  x_T -> x_{T-1} -> ... -> x_1 -> x_0    (N denoising steps)

Consistency Model (one step):
  x_T -> x_0    (single forward pass!)

Or few-step refinement:
  x_T -> x_0' -> add_noise(x_0', t') -> x_0''    (2 steps, better quality)
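
A minimal sketch of that few-step loop. `predict_fn.(params, x, sigma)` is a hypothetical stand-in for one forward pass of f(x, sigma) (not the library API), and the re-noising step ignores the small sigma_min correction from the paper for clarity:

# Each round maps the current estimate back to x_0, then re-noises it
# to the next (smaller) sigma before denoising again.
defmodule FewStepSketch do
  def sample(predict_fn, params, noise, [sigma_max | sigmas], key) do
    # First pass: map pure noise at sigma_max straight to a sample.
    x0 = predict_fn.(params, noise, sigma_max)

    # Each further step re-noises to a smaller sigma and denoises again.
    Enum.reduce(sigmas, {x0, key}, fn sigma, {x0, key} ->
      {eps, key} = Nx.Random.normal(key, 0.0, 1.0, shape: Nx.shape(x0))
      x_t = Nx.add(x0, Nx.multiply(eps, sigma))
      {predict_fn.(params, x_t, sigma), key}
    end)
    |> elem(0)
  end
end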

Training Approaches

  1. Consistency Distillation (CD): Distill from a pre-trained diffusion model
  2. Consistency Training (CT): Train from scratch without a teacher

Both enforce: f(x_{t+1}, t+1) = f(x_t, t) for adjacent timesteps.
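
A sketch of the forward side of one CT update under these assumptions: `predict_fn` is the compiled model, `x0` is a clean batch, and `sigmas` comes from noise_schedule/1. Everything except consistency_loss/2 and noise_schedule/1 is hypothetical scaffolding:

# Both networks see the SAME noise sample at adjacent timesteps, so the
# two inputs lie on the same trajectory. In a real training step, gradients
# flow only through the online prediction; the EMA target is held fixed.
defmodule CTSketch do
  def ct_loss(predict_fn, online_params, target_params, x0, sigmas, i, key) do
    sigma_n = sigmas[i]          # t_n (less noisy)
    sigma_n1 = sigmas[i + 1]     # t_{n+1} (noisier)

    {eps, key} = Nx.Random.normal(key, 0.0, 1.0, shape: Nx.shape(x0))

    pred_current =
      predict_fn.(online_params, Nx.add(x0, Nx.multiply(eps, sigma_n1)), sigma_n1)

    pred_target =
      predict_fn.(target_params, Nx.add(x0, Nx.multiply(eps, sigma_n)), sigma_n)

    {ConsistencyModel.consistency_loss(pred_current, pred_target), key}
  end
end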

Architecture

Input (x_t, sigma_t)
      |
      v
+------------------------+
| Skip Connection        |
| c_skip(t) * x_t +      |
| c_out(t) * F(x_t, t)   |
+------------------------+
      |
      v
Output: predicted x_0

The skip connection ensures the boundary condition f(x, sigma_min) = x.
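For reference, the coefficients from Song et al. (2023) can be written directly in Nx. The sigma_data = 0.5 constant and the exact form below are the paper's; whether build/1 uses precisely this parameterization is an assumption:

# Boundary behavior: at sigma = sigma_min, c_skip = 1 and c_out = 0,
# so the output collapses to the identity f(x, sigma_min) = x.
sigma_data = 0.5
sigma_min = 0.002

c_skip = fn sigma ->
  Nx.divide(
    sigma_data * sigma_data,
    Nx.add(Nx.pow(Nx.subtract(sigma, sigma_min), 2), sigma_data * sigma_data)
  )
end

c_out = fn sigma ->
  Nx.divide(
    Nx.multiply(sigma_data, Nx.subtract(sigma, sigma_min)),
    Nx.sqrt(Nx.add(sigma_data * sigma_data, Nx.pow(sigma, 2)))
  )
end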

Usage

model = ConsistencyModel.build(
  input_dim: 64,
  hidden_size: 256,
  num_layers: 4
)

# Single-step generation
x_0 = ConsistencyModel.single_step_sample(model, params, noise)

# Multi-step refinement
x_0 = ConsistencyModel.multi_step_sample(model, params, noise, steps: 3)

Reference

Song, Y., Dhariwal, P., Chen, M., & Sutskever, I. (2023). Consistency Models. ICML 2023. https://arxiv.org/abs/2303.01469

Summary

Types

Options for build/1.

Functions

Build a Consistency Model.

Consistency training loss using pseudo-Huber function.

Generate noise schedule using Karras et al. discretization.

Get the output size of a Consistency Model.

Calculate approximate parameter count.

Get recommended defaults.

Types

build_opt()

@type build_opt() ::
  {:hidden_size, pos_integer()}
  | {:input_dim, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:sigma_max, float()}
  | {:sigma_min, float()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Consistency Model.

The model learns f(x, sigma) that maps noisy input at any noise level to the clean data, while maintaining self-consistency across the trajectory.

Options

  • :input_dim - Input feature dimension (required)
  • :hidden_size - Hidden dimension (default: 256)
  • :num_layers - Number of residual blocks (default: 4)
  • :sigma_min - Minimum noise level (default: 0.002)
  • :sigma_max - Maximum noise level (default: 80.0)

Returns

An Axon model: (noisy_input, sigma) -> predicted clean input.
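
A quick smoke test of the built model. The input names "noisy_input" and "sigma", the sigma shape, and the Axon 0.6-style init call are assumptions inferred from the signature above; inspect the built graph for the actual names:

model = ConsistencyModel.build(input_dim: 64, hidden_size: 256, num_layers: 4)
{init_fn, predict_fn} = Axon.build(model)

# Initialize parameters from shape templates (no real data needed).
params =
  init_fn.(
    %{
      "noisy_input" => Nx.template({1, 64}, :f32),
      "sigma" => Nx.template({1, 1}, :f32)
    },
    %{}
  )

# One forward pass at the maximum noise level.
x0_pred =
  predict_fn.(params, %{
    "noisy_input" => Nx.broadcast(0.0, {1, 64}),
    "sigma" => Nx.tensor([[80.0]])
  })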

consistency_loss(pred_current, pred_target)

@spec consistency_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Consistency training loss using pseudo-Huber function.

Enforces f(x_{t+1}, t+1) = f(x_t, t) for adjacent timesteps. Uses a target network (EMA of the online network) for stability.

The pseudo-Huber loss (Song et al., 2023) is: sqrt(||d||^2 + c^2) - c, where d = f_theta(x_{t+dt}, t+dt) - f_target(x_t, t) and c = 0.00054 * sqrt(d_dim).
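
Spelled out in Nx, assuming {batch, features} predictions (a sketch of the formula above, not necessarily the library's exact reduction):

d = Nx.subtract(pred_current, pred_target)
{_batch, dim} = Nx.shape(d)
c = 0.00054 * :math.sqrt(dim)

loss =
  d
  |> Nx.pow(2)
  |> Nx.sum(axes: [-1])     # ||d||^2 per example
  |> Nx.add(c * c)
  |> Nx.sqrt()
  |> Nx.subtract(c)
  |> Nx.mean()              # average over the batch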

Parameters

  • pred_current - f_theta(x_{t+dt}, t+dt) from the online network
  • pred_target - f_target(x_t, t) from target (EMA) network

Returns

Scalar loss.
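
The target network mentioned above is typically refreshed after every optimizer step. A minimal sketch, assuming parameters are nested maps of Nx tensors (EMATarget is hypothetical, not part of Edifice):

defmodule EMATarget do
  # target <- decay * target + (1 - decay) * online, applied leaf-wise.
  def update(%Nx.Tensor{} = t, %Nx.Tensor{} = o, decay),
    do: Nx.add(Nx.multiply(t, decay), Nx.multiply(o, 1.0 - decay))

  # Recurse through nested parameter maps (tensor structs matched above).
  def update(t, o, decay) when is_map(t),
    do: Map.new(t, fn {k, v} -> {k, update(v, Map.fetch!(o, k), decay)} end)
end

target_params = EMATarget.update(target_params, online_params, 0.999)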

noise_schedule(opts \\ [])

@spec noise_schedule(keyword()) :: Nx.Tensor.t()

Generate noise schedule using Karras et al. discretization.

Produces N timesteps: t_i = (eps^{1/rho} + (i-1)/(N-1) * (T^{1/rho} - eps^{1/rho}))^rho where eps = sigma_min, T = sigma_max, and rho = 7 (default).

Options

  • :n_steps - Number of timesteps (default: 40)
  • :sigma_min - Minimum sigma (default: 0.002)
  • :sigma_max - Maximum sigma (default: 80.0)
  • :rho - Schedule curvature (default: 7)
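
The discretization can be reproduced by hand to sanity-check the output (ascending from sigma_min to sigma_max, matching the formula above; a re-derivation, not the internal implementation):

{n, rho, smin, smax} = {40, 7, 0.002, 80.0}

inv_rho = fn s -> :math.pow(s, 1 / rho) end

sigmas =
  Nx.linspace(0.0, 1.0, n: n)                       # (i - 1) / (N - 1)
  |> Nx.multiply(inv_rho.(smax) - inv_rho.(smin))
  |> Nx.add(inv_rho.(smin))
  |> Nx.pow(rho)

# sigmas[0] ~= sigma_min and sigmas[n - 1] ~= sigma_max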

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a Consistency Model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count.