# `Edifice.Generative.ConsistencyModel`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/generative/consistency_model.ex#L1)

Consistency Model: single-step generation via a consistency function.

Implements the Consistency Model from "Consistency Models" (Song et al.,
ICML 2023). Learns a function f(x_t, t) that maps any point on a
probability flow ODE trajectory directly to its origin, enabling
single-step generation without iterative denoising.

## Key Innovation: Self-Consistency Property

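A consistency function f satisfies f(x_t, t) = f(x_{t'}, t') for all t, t'
on the same trajectory. Every point on a trajectory therefore maps to that
trajectory's origin: f(x_t, t) = x_0 for all t. In practice the origin is
taken at the smallest noise level, sigma_min, rather than at exactly 0.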
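
As a minimal sketch of the property, consider the toy case where all data sits at a single point `mu`. Under a variance-exploding probability flow ODE the trajectories are then straight lines `x_t = mu + c * t`, and the consistency function has a closed form. The module `ToyConsistency` and its functions are hypothetical names used for illustration only; they are not part of Edifice.

```elixir
defmodule ToyConsistency do
  # Illustration only: closed-form consistency function for point-mass data.
  # Trajectories are x_t = mu + c * t, so the origin at t = eps is
  # mu + (eps / t) * (x_t - mu).
  @eps 0.002

  def f(x_t, t, mu), do: mu + @eps / t * (x_t - mu)

  # Point at time t on the trajectory with slope c.
  def on_trajectory(t, c, mu), do: mu + c * t
end

mu = 1.5
c = -0.7

# Evaluate f at several points of the same trajectory.
origins =
  for t <- [0.002, 0.5, 10.0, 80.0] do
    ToyConsistency.f(ToyConsistency.on_trajectory(t, c, mu), t, mu)
  end

# All entries agree up to floating-point rounding.
IO.inspect(origins)
```

A trained model only approximates this behaviour; the closed form exists here because the toy data distribution is a single point.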

```
Diffusion (many steps):
  x_T -> x_{T-1} -> ... -> x_1 -> x_0    (N denoising steps)

Consistency Model (one step):
  x_T -> x_0    (single forward pass!)

Or few-step refinement:
  x_T -> x_0' -> add_noise(x_0', t') -> x_0''    (2 steps, better quality)
```
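
The few-step path in the diagram can be sketched along the lines of the multistep sampling procedure in the paper: denoise once from the largest noise level, then repeatedly re-noise to an intermediate level `tau` and denoise again. The sketch below works on scalars and uses a closed-form stand-in for the trained model (exact only for data concentrated at `mu`). `FewStepSketch` and its arguments are hypothetical and say nothing about how `multi_step_sample` is implemented.

```elixir
defmodule FewStepSketch do
  @eps 0.002

  # Stand-in for a trained model (assumption): exact consistency function
  # for data concentrated at the single point `mu`.
  def f(x_t, t, mu), do: mu + @eps / t * (x_t - mu)

  # One-step sample from pure noise, then one refinement per entry of `taus`.
  def sample(x_noise, sigma_max, taus, mu) do
    x0 = f(x_noise, sigma_max, mu)

    Enum.reduce(taus, x0, fn tau, x ->
      z = :rand.normal()
      # Re-noise the current estimate to level tau, then denoise again.
      x_tau = x + :math.sqrt(tau * tau - @eps * @eps) * z
      f(x_tau, tau, mu)
    end)
  end
end

:rand.seed(:exsss, {1, 2, 3})
sigma_max = 80.0
mu = 1.5
x_noise = sigma_max * :rand.normal()

one_step = FewStepSketch.sample(x_noise, sigma_max, [], mu)
three_step = FewStepSketch.sample(x_noise, sigma_max, [10.0, 1.0], mu)
IO.inspect({one_step, three_step})
```

With an empty list of intermediate levels the procedure reduces to single-step generation.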

## Training Approaches

1. **Consistency Distillation (CD)**: Distill from a pre-trained diffusion model
2. **Consistency Training (CT)**: Train from scratch without a teacher

Both enforce: f(x_{t+1}, t+1) = f(x_t, t) for adjacent timesteps.
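
In Consistency Training as described in the paper, the adjacent pair is built by perturbing one clean sample with the same noise draw at two neighbouring noise levels. The sketch below shows that construction on scalars; whether Edifice builds its pairs this way is an assumption, and `CtPairSketch` is a hypothetical name. The stand-in `f` is exact only for data concentrated at `mu`, which is why the gap comes out as (numerically) zero here.

```elixir
defmodule CtPairSketch do
  @eps 0.002

  # Stand-in for the network (assumption): exact for point-mass data at `mu`.
  def f(x_t, t, mu), do: mu + @eps / t * (x_t - mu)

  # Perturb one clean sample x with the SAME noise z at two adjacent levels.
  def adjacent_pair(x, z, t_n, t_next), do: {x + t_n * z, x + t_next * z}
end

mu = 1.5
z = 0.3
{x_n, x_next} = CtPairSketch.adjacent_pair(mu, z, 2.0, 2.5)

# The training objective drives this difference toward zero.
gap = CtPairSketch.f(x_next, 2.5, mu) - CtPairSketch.f(x_n, 2.0, mu)
IO.inspect(gap)
```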

## Architecture

```
Input (x_t, sigma_t)
      |
      v
+-----------------------+
| Skip Connection       |
| c_skip(t) * x_t +     |
| c_out(t) * F(x_t, t)  |
+-----------------------+
      |
      v
Output: predicted x_0
```

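The skip connection enforces the boundary condition f(x, sigma_min) = x by
construction: c_skip(sigma_min) = 1 and c_out(sigma_min) = 0, so the network
output F is ignored at the smallest noise level.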
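
One common choice of coefficients is the one given in the Consistency Models paper. The sketch below uses that parameterization with `sigma_data = 0.5`. Both the value of `sigma_data` and the assumption that Edifice uses exactly these formulas are unverified here, and `SkipCoefficients` is a hypothetical module name.

```elixir
defmodule SkipCoefficients do
  # Assumed constants: sigma_data follows the paper, sigma_min the documented default.
  @sigma_data 0.5
  @sigma_min 0.002

  # c_skip(sigma) = sigma_data^2 / ((sigma - sigma_min)^2 + sigma_data^2)
  def c_skip(sigma) do
    d = sigma - @sigma_min
    @sigma_data * @sigma_data / (d * d + @sigma_data * @sigma_data)
  end

  # c_out(sigma) = sigma_data * (sigma - sigma_min) / sqrt(sigma_data^2 + sigma^2)
  def c_out(sigma) do
    @sigma_data * (sigma - @sigma_min) /
      :math.sqrt(@sigma_data * @sigma_data + sigma * sigma)
  end

  # Combine the noisy input with the raw network output F(x, sigma).
  def apply_skip(x, sigma, network_out) do
    c_skip(sigma) * x + c_out(sigma) * network_out
  end
end

# At sigma_min the network output has no influence on the result.
at_boundary = SkipCoefficients.apply_skip(3.0, 0.002, 123.0)
IO.inspect(at_boundary)
```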

## Usage

    alias Edifice.Generative.ConsistencyModel

    model = ConsistencyModel.build(
      input_dim: 64,
      hidden_size: 256,
      num_layers: 4
    )

    # Single-step generation
    x_0 = ConsistencyModel.single_step_sample(model, params, noise)

    # Multi-step refinement
    x_0 = ConsistencyModel.multi_step_sample(model, params, noise, steps: 3)

## Reference

- Paper: "Consistency Models"
- arXiv: https://arxiv.org/abs/2303.01469

# `build_opt`

```elixir
@type build_opt() ::
  {:hidden_size, pos_integer()}
  | {:input_dim, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:sigma_max, float()}
  | {:sigma_min, float()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a Consistency Model.

The model learns f(x, sigma) that maps noisy input at any noise level
to the clean data, while maintaining self-consistency across the trajectory.

## Options

  - `:input_dim` - Input feature dimension (required)
  - `:hidden_size` - Hidden dimension (default: 256)
  - `:num_layers` - Number of residual blocks (default: 4)
  - `:sigma_min` - Minimum noise level (default: 0.002)
  - `:sigma_max` - Maximum noise level (default: 80.0)

## Returns

  An Axon model: (noisy_input, sigma) -> predicted clean input.

# `consistency_loss`

```elixir
@spec consistency_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

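Consistency training loss using the pseudo-Huber function.

Enforces f(x_{t+dt}, t+dt) = f(x_t, t) for adjacent timesteps. The function
only compares two predictions; for stability, `pred_target` is expected to
come from a target network (an EMA of the online network).

The pseudo-Huber loss (Song & Dhariwal, 2023, "Improved Techniques for
Training Consistency Models") is sqrt(||d||^2 + c^2) - c, where
d = f_theta(x_{t+dt}, t+dt) - f_target(x_t, t), c = 0.00054*sqrt(d_dim),
and d_dim is the dimensionality of the data.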

## Parameters

  - `pred_current` - f_theta(x_{t+dt}, t+dt) from online network
  - `pred_target` - f_target(x_t, t) from target (EMA) network

## Returns

  Scalar loss.
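
To make the formula concrete, here is a minimal sketch on plain Elixir lists rather than `Nx` tensors. It handles a single sample only; how `consistency_loss/2` reduces over a batch is not stated in this document, and `PseudoHuberSketch` is a hypothetical name.

```elixir
defmodule PseudoHuberSketch do
  # sqrt(||d||^2 + c^2) - c with c = 0.00054 * sqrt(dim), for one sample
  # given as two equal-length lists of floats.
  def loss(pred_current, pred_target) do
    dim = length(pred_current)
    c = 0.00054 * :math.sqrt(dim)

    sq_norm =
      pred_current
      |> Enum.zip(pred_target)
      |> Enum.reduce(0.0, fn {a, b}, acc -> acc + (a - b) * (a - b) end)

    :math.sqrt(sq_norm + c * c) - c
  end
end

# Identical predictions give a loss of (numerically) zero.
zero_loss = PseudoHuberSketch.loss([1.0, 2.0], [1.0, 2.0])

# For large differences the loss is close to the L2 norm of d minus c.
large_loss = PseudoHuberSketch.loss([3.0, 0.0], [0.0, 4.0])
IO.inspect({zero_loss, large_loss})
```

The small constant `c` makes the loss behave quadratically near zero and roughly linearly for large differences, which reduces sensitivity to outliers compared with a squared L2 loss.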

# `noise_schedule`

```elixir
@spec noise_schedule(keyword()) :: Nx.Tensor.t()
```

Generate a noise schedule using the Karras et al. (2022) discretization.

Produces N timesteps, for i = 1..N:
t_i = (eps^{1/rho} + (i-1)/(N-1) * (T^{1/rho} - eps^{1/rho}))^rho
where eps = sigma_min, T = sigma_max, and rho = 7 (default).

## Options

  - `:n_steps` - Number of timesteps (default: 40)
  - `:sigma_min` - Minimum sigma (default: 0.002)
  - `:sigma_max` - Maximum sigma (default: 80.0)
  - `:rho` - Schedule curvature (default: 7)
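
The discretization formula can be sketched in plain Elixir as follows. It returns a list of floats in ascending order, as the formula implies, whereas `noise_schedule/1` returns an `Nx` tensor whose ordering is not confirmed here. `KarrasScheduleSketch` is a hypothetical name, and the sketch assumes `n_steps >= 2`.

```elixir
defmodule KarrasScheduleSketch do
  # t_i = (eps^(1/rho) + (i-1)/(N-1) * (T^(1/rho) - eps^(1/rho)))^rho, i = 1..N
  # Requires n_steps >= 2 to avoid dividing by zero.
  def sigmas(opts \\ []) do
    n = Keyword.get(opts, :n_steps, 40)
    sigma_min = Keyword.get(opts, :sigma_min, 0.002)
    sigma_max = Keyword.get(opts, :sigma_max, 80.0)
    rho = Keyword.get(opts, :rho, 7)

    lo = :math.pow(sigma_min, 1 / rho)
    hi = :math.pow(sigma_max, 1 / rho)

    for i <- 1..n do
      :math.pow(lo + (i - 1) / (n - 1) * (hi - lo), rho)
    end
  end
end

schedule = KarrasScheduleSketch.sigmas()
IO.inspect({hd(schedule), List.last(schedule)})
```

Interpolating in `sigma^(1/rho)` space rather than linearly places more of the steps at small noise levels when `rho > 1`.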

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of a Consistency Model.

# `param_count`

```elixir
@spec param_count(keyword()) :: non_neg_integer()
```

Calculate approximate parameter count.

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Get recommended defaults.

