Consistency Model: Single-step generation via consistency function.
Implements the Consistency Model from "Consistency Models" (Song et al., ICML 2023). Learns a function f(x_t, t) that maps any point on a probability flow ODE trajectory directly to its origin, enabling single-step generation without iterative denoising.
Key Innovation: Self-Consistency Property
A consistency function f satisfies: for all t, t' on the same trajectory, f(x_t, t) = f(x_{t'}, t'). In particular, f(x_t, t) = x_0 for all t.
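Schematically, the property can be checked numerically: predictions from any two points on the same trajectory should agree. In the sketch below, model_fn, params, and the trajectory points x_t and x_t2 are hypothetical stand-ins, not part of this module's API:

# Self-consistency check, schematically. x_t and x_t2 lie on the same
# probability flow ODE trajectory at noise levels sigma_t and sigma_t2.
pred_a = model_fn.(params, x_t, sigma_t)
pred_b = model_fn.(params, x_t2, sigma_t2)

# Near zero for a well-trained consistency model:
deviation = Nx.mean(Nx.abs(Nx.subtract(pred_a, pred_b)))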
Diffusion (many steps):
x_T -> x_{T-1} -> ... -> x_1 -> x_0 (N denoising steps)
Consistency Model (one step):
x_T -> x_0 (single forward pass!)
Or few-step refinement:
x_T -> x_0' -> add_noise(x_0', t') -> x_0'' (2 steps, better quality)
Training Approaches
- Consistency Distillation (CD): Distill from a pre-trained diffusion model
- Consistency Training (CT): Train from scratch without a teacher
Both enforce f(x_{t+1}, t+1) = f(x_t, t) for adjacent timesteps, as in the sketch below.
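Schematically, one training step draws adjacent sigmas from the schedule, corrupts the same clean batch at both levels, and penalizes disagreement between the online and target networks. In the sketch below, model_fn, params, ema_params, add_noise, x0, and noise are hypothetical stand-ins, not part of this module's API:

# One consistency-training step, schematically. add_noise.(x0, sigma, noise)
# is assumed to return x0 + sigma * noise, with the same noise draw reused
# at both levels so the two inputs lie on the same trajectory.
sigmas = ConsistencyModel.noise_schedule(n_steps: 40)
i = :rand.uniform(39) - 1                      # random adjacent pair, 0..38

x_curr = add_noise.(x0, sigmas[i], noise)      # corrupted at sigma_i
x_next = add_noise.(x0, sigmas[i + 1], noise)  # corrupted at sigma_{i+1}

pred_current = model_fn.(params, x_next, sigmas[i + 1])   # online network
pred_target = model_fn.(ema_params, x_curr, sigmas[i])    # EMA target

loss = ConsistencyModel.consistency_loss(pred_current, pred_target)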
Architecture
Input (x_t, sigma_t)
|
v
+-----------------------+
| Skip Connection |
| c_skip(t) * x_t + |
| c_out(t) * F(x_t, t) |
+-----------------------+
|
v
Output: predicted x_0
The skip connection ensures the boundary condition f(x, sigma_min) = x.
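One common choice of coefficients that satisfies this boundary condition is the parameterization of Song et al. (2023), sketched below. The sigma_data = 0.5 default comes from the paper and is an assumption here, not a value read from this module:

defmodule BoundarySketch do
  import Nx.Defn

  # c_skip/c_out from Song et al. (2023). At sigma = sigma_min this yields
  # c_skip = 1 and c_out = 0, so c_skip * x + c_out * F(x, sigma) = x.
  defn coefficients(sigma, opts \\ []) do
    opts = keyword!(opts, sigma_min: 0.002, sigma_data: 0.5)
    sigma_min = opts[:sigma_min]
    sigma_data = opts[:sigma_data]

    c_skip = sigma_data ** 2 / ((sigma - sigma_min) ** 2 + sigma_data ** 2)
    c_out = sigma_data * (sigma - sigma_min) / Nx.sqrt(sigma ** 2 + sigma_data ** 2)

    {c_skip, c_out}
  end
end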
Usage
model = ConsistencyModel.build(
  input_dim: 64,
  hidden_size: 256,
  num_layers: 4
)
# Single-step generation
x_0 = ConsistencyModel.single_step_sample(model, params, noise)

# Multi-step refinement
x_0 = ConsistencyModel.multi_step_sample(model, params, noise, steps: 3)
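Internally, multi-step refinement follows the paper's multistep consistency sampling loop: denoise once from sigma_max, then alternately re-noise to a lower sigma and denoise again. A minimal sketch under those assumptions, with model_fn and params as stand-ins for this module's actual forward pass:

defmodule MultiStepSketch do
  # Multistep consistency sampling, schematically (Song et al., 2023).
  # `sigmas` is a decreasing list of noise levels below sigma_max;
  # model_fn.(params, x, sigma) stands in for one forward pass.
  def sample(model_fn, params, noise, sigmas, opts \\ []) do
    sigma_min = Keyword.get(opts, :sigma_min, 0.002)
    sigma_max = Keyword.get(opts, :sigma_max, 80.0)

    # First step: map pure noise at sigma_max straight to a sample.
    x = model_fn.(params, Nx.multiply(noise, sigma_max), sigma_max)

    # Each refinement: re-noise the sample to sigma, then denoise again.
    {x, _key} =
      Enum.reduce(sigmas, {x, Nx.Random.key(0)}, fn sigma, {x, key} ->
        {z, key} = Nx.Random.normal(key, 0.0, 1.0, shape: Nx.shape(x))
        std = :math.sqrt(sigma * sigma - sigma_min * sigma_min)

        x_hat = Nx.add(x, Nx.multiply(z, std))
        {model_fn.(params, x_hat, sigma), key}
      end)

    x
  end
end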
Reference
- Paper: "Consistency Models" (Song, Dhariwal, Chen, Sutskever, ICML 2023)
- arXiv: https://arxiv.org/abs/2303.01469
Summary
Functions
build(opts) - Build a Consistency Model.
consistency_loss(pred_current, pred_target) - Consistency training loss using the pseudo-Huber function.
noise_schedule(opts) - Generate a noise schedule using the Karras et al. discretization.
output_size(opts) - Get the output size of a Consistency Model.
param_count(opts) - Calculate approximate parameter count.
recommended_defaults() - Get recommended defaults.
Types
@type build_opt() ::
  {:hidden_size, pos_integer()}
  | {:input_dim, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:sigma_max, float()}
  | {:sigma_min, float()}
Options for build/1.
Functions
build(opts)
Build a Consistency Model.
The model learns f(x, sigma) that maps noisy input at any noise level to the clean data, while maintaining self-consistency across the trajectory.
Options
- :input_dim - Input feature dimension (required)
- :hidden_size - Hidden dimension (default: 256)
- :num_layers - Number of residual blocks (default: 4)
- :sigma_min - Minimum noise level (default: 0.002)
- :sigma_max - Maximum noise level (default: 80.0)
Returns
An Axon model: (noisy_input, sigma) -> predicted clean input.
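A minimal sketch of initializing and running the returned model. The input names ("noisy_input", "sigma") and the template shapes below are assumptions for illustration, not guarantees of this module:

# Build, initialize, and run one forward pass. The {8, 64} batch shape,
# the {8, 1} sigma shape, and the input names are assumptions.
model = ConsistencyModel.build(input_dim: 64)
{init_fn, predict_fn} = Axon.build(model)

templates = %{
  "noisy_input" => Nx.template({8, 64}, :f32),
  "sigma" => Nx.template({8, 1}, :f32)
}
params = init_fn.(templates, %{})

{x_t, _key} = Nx.Random.normal(Nx.Random.key(0), 0.0, 80.0, shape: {8, 64})
sigma = Nx.broadcast(80.0, {8, 1})

x0_pred = predict_fn.(params, %{"noisy_input" => x_t, "sigma" => sigma})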
@spec consistency_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Consistency training loss using the pseudo-Huber function.
Enforces f(x_{t+1}, t+1) = f(x_t, t) for adjacent timesteps. Uses a target network (EMA of the online network) for stability.
The pseudo-Huber loss (Song et al., 2023) is: sqrt(||d||^2 + c^2) - c, where d = f_theta(x_{t+dt}, t+dt) - f_target(x_t, t) and c = 0.00054 * sqrt(d_dim).
Parameters
- pred_current - f_theta(x_{t+dt}, t+dt) from the online network
- pred_target - f_target(x_t, t) from the target (EMA) network
Returns
Scalar loss.
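A minimal Nx sketch of this loss under the definitions above; the batch-mean reduction is an assumption:

defmodule LossSketch do
  import Nx.Defn

  # Pseudo-Huber consistency loss: sqrt(||d||^2 + c^2) - c, averaged over
  # the batch, with c = 0.00054 * sqrt(dim) as in the docstring above.
  defn pseudo_huber(pred_current, pred_target) do
    dim = Nx.axis_size(pred_current, -1)
    c = 0.00054 * Nx.sqrt(dim)

    d = pred_current - pred_target
    sq_norm = Nx.sum(d * d, axes: [-1])

    Nx.mean(Nx.sqrt(sq_norm + c * c) - c)
  end
end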
@spec noise_schedule(keyword()) :: Nx.Tensor.t()
Generate a noise schedule using the Karras et al. discretization.
Produces N timesteps: t_i = (eps^{1/rho} + (i-1)/(N-1) * (T^{1/rho} - eps^{1/rho}))^rho, where eps = sigma_min, T = sigma_max, and rho = 7 (the default).
Options
- :n_steps - Number of timesteps (default: 40)
- :sigma_min - Minimum sigma (default: 0.002)
- :sigma_max - Maximum sigma (default: 80.0)
- :rho - Schedule curvature (default: 7)
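A minimal Nx sketch of this discretization, assuming only the formula above:

defmodule ScheduleSketch do
  # Karras et al. discretization, per the formula above. Returns an
  # {n_steps} tensor of sigmas increasing from sigma_min to sigma_max.
  def karras_sigmas(opts \\ []) do
    n = Keyword.get(opts, :n_steps, 40)
    sigma_min = Keyword.get(opts, :sigma_min, 0.002)
    sigma_max = Keyword.get(opts, :sigma_max, 80.0)
    rho = Keyword.get(opts, :rho, 7)

    min_inv = :math.pow(sigma_min, 1 / rho)
    max_inv = :math.pow(sigma_max, 1 / rho)

    # (i - 1) / (N - 1) for i = 1..N, i.e. a linear ramp over [0, 1]
    ramp = Nx.divide(Nx.iota({n}), n - 1)

    ramp
    |> Nx.multiply(max_inv - min_inv)
    |> Nx.add(min_inv)
    |> Nx.pow(rho)
  end
end

With rho = 7 the resulting sigmas cluster near sigma_min, spending most of the schedule where fine detail is resolved.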
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a Consistency Model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count.
@spec recommended_defaults() :: keyword()
Get recommended defaults.