Edifice.Generative.ScoreSDE (Edifice v0.2.0)


Score-based SDE: Unified score matching framework for generative modeling.

Implements the Score SDE framework from "Score-Based Generative Modeling through Stochastic Differential Equations" (Song et al., ICLR 2021). The framework unifies DDPM, SMLD, and other score-based methods under a single SDE perspective and comes in two main variants:

SDE Variants

VP-SDE (Variance Preserving, generalizes DDPM):

dx = -0.5 * beta(t) * x dt + sqrt(beta(t)) dw
beta(t) = beta_min + t * (beta_max - beta_min)

VE-SDE (Variance Exploding, generalizes SMLD):

dx = sqrt(d[sigma^2(t)]/dt) dw
sigma(t) = sigma_min * (sigma_max/sigma_min)^t
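The VE-SDE diffusion term sqrt(d[sigma^2(t)]/dt) has a closed form for the geometric schedule: since sigma^2(t) grows exponentially in t, its derivative is 2 * ln(sigma_max / sigma_min) * sigma(t)^2. The sketch below, in plain Elixir, checks that closed form against a finite difference. The module name `VEDiffusionSketch` and its functions are hypothetical and not part of Edifice; the values 0.01 and 50.0 are the documented `ve_schedule/1` defaults.

```elixir
defmodule VEDiffusionSketch do
  # Hypothetical helper module for illustration; not part of Edifice.
  @sigma_min 0.01
  @sigma_max 50.0

  # sigma(t) = sigma_min * (sigma_max / sigma_min)^t
  def sigma(t), do: @sigma_min * :math.pow(@sigma_max / @sigma_min, t)

  # Closed form: d[sigma^2]/dt = 2 * ln(sigma_max / sigma_min) * sigma(t)^2,
  # so g(t) = sigma(t) * sqrt(2 * ln(sigma_max / sigma_min)).
  def diffusion(t) do
    sigma(t) * :math.sqrt(2.0 * :math.log(@sigma_max / @sigma_min))
  end

  # Central finite difference of sigma^2(t), used only as a sanity check.
  def diffusion_numeric(t, h \\ 1.0e-5) do
    d = (:math.pow(sigma(t + h), 2) - :math.pow(sigma(t - h), 2)) / (2.0 * h)
    :math.sqrt(d)
  end
end

# The two values should agree to many digits.
IO.inspect({VEDiffusionSketch.diffusion(0.5), VEDiffusionSketch.diffusion_numeric(0.5)})
```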

Key Innovation: Score Function

The score function s(x, t) = grad_x log p_t(x) is learned via denoising score matching. Once the score is learned, samples are generated by solving the reverse-time SDE from t = 1 down to t = 0 (shown here for the VP-SDE):

dx = [-0.5 * beta(t) * x - beta(t) * s(x, t)] dt + sqrt(beta(t)) dw_rev

Alternatively, samples can be generated by integrating the probability flow ODE, a deterministic process with the same marginals p_t(x):

dx = [-0.5 * beta(t) * x - 0.5 * beta(t) * s(x, t)] dt
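As a rough illustration of how the probability flow ODE is integrated, the sketch below runs explicit Euler steps from t = 1 down to t = 0 on a scalar. It does not use a trained network: it assumes the data is x_0 ~ N(0, 0.5^2), for which the score is known analytically, and substitutes that for s(x, t). The module `ProbFlowSketch`, the `@data_std` value, and the step count are all assumptions made for this example.

```elixir
defmodule ProbFlowSketch do
  # Hypothetical module for illustration; it does not call Edifice.
  @beta_min 0.1
  @beta_max 20.0
  @data_std 0.5

  def beta(t), do: @beta_min + t * (@beta_max - @beta_min)

  # Variance of p_t when x_0 ~ N(0, data_std^2):
  # alpha(t)^2 * data_std^2 + (1 - alpha(t)^2).
  def marginal_var(t) do
    alpha_sq = :math.exp(-(@beta_min * t + 0.5 * t * t * (@beta_max - @beta_min)))
    alpha_sq * @data_std * @data_std + 1.0 - alpha_sq
  end

  # Analytic score of a zero-mean Gaussian; a trained network would replace this.
  def score(x, t), do: -x / marginal_var(t)

  # Probability flow drift: -0.5 * beta(t) * x - 0.5 * beta(t) * s(x, t).
  def drift(x, t), do: -0.5 * beta(t) * x - 0.5 * beta(t) * score(x, t)

  # Explicit Euler from t = 1 down to t = 0 with a fixed step size.
  def sample(x1, steps \\ 1000) do
    dt = 1.0 / steps

    Enum.reduce(0..(steps - 1), x1, fn i, x ->
      t = 1.0 - i * dt
      x - dt * drift(x, t)
    end)
  end
end

# A prior sample x_1 = 1.0 is mapped to a value close to 0.5, i.e. the
# unit-variance prior is rescaled to the assumed data standard deviation.
IO.inspect(ProbFlowSketch.sample(1.0))
```

Because the ODE is deterministic, the same starting point always yields the same sample. A higher-order solver would reach the same accuracy in fewer steps; Euler is used here only for brevity.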

Architecture

Input (x_t, t)
      |
      v
+-----------------------+
| Score Network         |
| s_theta(x_t, t)       |
|                       |
| x_proj + time_embed   |
| -> residual blocks    |
| -> score prediction   |
+-----------------------+
      |
      v
Score [batch, input_dim]  (gradient of log density)

Usage

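alias Edifice.Generative.ScoreSDE

# Build the score network for the VP-SDE variant
model = ScoreSDE.build(
  input_dim: 64,
  hidden_size: 256,
  num_layers: 4,
  sde_type: :vp
)

# Training: denoising score matching.
# score_pred is the network output s_theta(x_t, t), noise is the epsilon
# added to x_0, and sigma is the noise standard deviation at time t.
loss = ScoreSDE.dsm_loss(score_pred, noise, sigma)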

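Reference

  • Song et al., "Score-Based Generative Modeling through Stochastic Differential Equations", ICLR 2021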

Summary

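Types

build_opt()
Options for build/1.

Functions

build(opts \\ [])
Build a score network s(x, t) for the Score SDE framework.

dsm_loss(score_pred, noise, sigma)
Denoising score matching loss.

output_size(opts \\ [])
Get the output size of a Score SDE model.

param_count(opts)
Calculate approximate parameter count.

Get recommended defaults.

ve_schedule(opts \\ [])
VE-SDE noise schedule.

ve_sigma(t, schedule)
Compute the noise level sigma at time t for VE-SDE.

vp_marginal(t, schedule)
Compute the marginal distribution parameters at time t for VP-SDE.

vp_schedule(opts \\ [])
VP-SDE noise schedule.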

Types

build_opt()

@type build_opt() ::
  {:hidden_size, pos_integer()}
  | {:input_dim, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:sde_type, :vp | :ve}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a score network s(x, t) for the Score SDE framework.

Options

  • :input_dim - Input feature dimension (required)
  • :hidden_size - Hidden dimension (default: 256)
  • :num_layers - Number of residual blocks (default: 4)
  • :sde_type - SDE variant: :vp or :ve (default: :vp)

Returns

An Axon model that predicts the score s(x, t) = grad_x log p_t(x).

dsm_loss(score_pred, noise, sigma)

@spec dsm_loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Denoising score matching loss.

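The score of the perturbation kernel at noise level sigma is:

s(x_t) = -(x_t - x_0) / sigma^2

Train by minimizing:

E[sigma^2 * ||s_theta(x_t, t) + (x_t - x_0) / sigma^2||^2]

Writing the perturbation as sigma * epsilon, the target score is -epsilon / sigma, so the objective can be expressed in terms of the noise and sigma arguments as E[||sigma * s_theta(x_t, t) + epsilon||^2].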

Parameters

  • score_pred - Predicted score s_theta(x_t, t)
  • noise - The noise added (epsilon)
  • sigma - Noise standard deviation at this timestep

Returns

Scalar loss.
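To make the objective concrete, here is a minimal plain-Elixir sketch of a sigma^2-weighted denoising score matching loss on nested lists. It is not the library's implementation: `DsmLossSketch` is a hypothetical module, and the reduction (sum over feature dimensions, mean over the batch) is an assumption that may differ from what `dsm_loss/3` does with Nx tensors.

```elixir
defmodule DsmLossSketch do
  # Hypothetical plain-list version for illustration only.

  # score_pred and noise are lists of rows (one row per sample);
  # sigma holds one noise level per sample.
  def loss(score_pred, noise, sigma) do
    per_sample =
      [score_pred, noise, sigma]
      |> Enum.zip()
      |> Enum.map(fn {s_row, eps_row, sig} ->
        s_row
        |> Enum.zip(eps_row)
        # sigma^2 * (s + eps / sigma)^2 == (sigma * s + eps)^2
        |> Enum.map(fn {s, eps} -> :math.pow(sig * s + eps, 2) end)
        |> Enum.sum()
      end)

    # Assumed reduction: mean over the batch.
    Enum.sum(per_sample) / length(per_sample)
  end
end

noise = [[0.3, -1.2], [0.5, 0.8]]
sigma = [0.5, 2.0]

# A perfect prediction equals the target -noise / sigma, so the loss vanishes.
perfect =
  Enum.zip(noise, sigma)
  |> Enum.map(fn {row, sig} -> Enum.map(row, fn eps -> -eps / sig end) end)

IO.inspect(DsmLossSketch.loss(perfect, noise, sigma))
IO.inspect(DsmLossSketch.loss([[0.0, 0.0], [0.0, 0.0]], noise, sigma))
```

The sigma^2 weighting keeps the per-sample loss on the scale of the noise itself, so small and large noise levels contribute comparably.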

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a Score SDE model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count.

ve_schedule(opts \\ [])

@spec ve_schedule(keyword()) :: map()

VE-SDE noise schedule.

sigma(t) = sigma_min * (sigma_max / sigma_min)^t

Options

  • :sigma_min - Minimum sigma (default: 0.01)
  • :sigma_max - Maximum sigma (default: 50.0)

Returns

Schedule map.

ve_sigma(t, schedule)

@spec ve_sigma(Nx.Tensor.t(), map()) :: Nx.Tensor.t()

Compute the noise level sigma at time t for VE-SDE.

sigma(t) = sigma_min * (sigma_max / sigma_min)^t
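The formula is a geometric interpolation between sigma_min and sigma_max, so equal steps in t multiply sigma by a constant factor. A small sketch with plain floats, assuming the documented `ve_schedule/1` defaults (0.01 and 50.0); `VeSigmaSketch` is a hypothetical stand-in, while the library function itself operates on Nx tensors and a schedule map:

```elixir
defmodule VeSigmaSketch do
  # Hypothetical stand-in for the documented formula; not part of Edifice.
  def sigma(t, sigma_min \\ 0.01, sigma_max \\ 50.0) do
    sigma_min * :math.pow(sigma_max / sigma_min, t)
  end
end

# sigma runs from sigma_min at t = 0 to sigma_max at t = 1; the midpoint
# is the geometric mean sqrt(sigma_min * sigma_max).
for t <- [0.0, 0.5, 1.0] do
  IO.puts("t=#{t} sigma=#{VeSigmaSketch.sigma(t)}")
end
```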

vp_marginal(t, schedule)

@spec vp_marginal(Nx.Tensor.t(), map()) :: {Nx.Tensor.t(), Nx.Tensor.t()}

Compute the marginal distribution parameters at time t for VP-SDE.

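For VP-SDE: p(x_t | x_0) = N(x_t; alpha(t) x_0, sigma(t)^2 I), where

alpha(t) = exp(-0.5 * integral(beta, 0, t))
sigma(t)^2 = 1 - alpha(t)^2

For the linear schedule, integral(beta, 0, t) = beta_min * t + 0.5 * t^2 * (beta_max - beta_min).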
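For the linear beta schedule the integral has a closed form, and the variance-preserving property gives sigma(t)^2 = 1 - alpha(t)^2 (the standard result from Song et al.). The sketch below evaluates both with plain floats. `VpMarginalSketch` is a hypothetical helper; the `{alpha, sigma}` return order is an assumption, not a statement about what `vp_marginal/2` returns.

```elixir
defmodule VpMarginalSketch do
  # Hypothetical helper for illustration; returns plain floats, not Nx tensors.
  def marginal(t, beta_min \\ 0.1, beta_max \\ 20.0) do
    # integral(beta, 0, t) for beta(t) = beta_min + t * (beta_max - beta_min)
    integral = beta_min * t + 0.5 * t * t * (beta_max - beta_min)
    alpha = :math.exp(-0.5 * integral)
    # Variance preserving: alpha^2 + sigma^2 = 1
    sigma = :math.sqrt(1.0 - alpha * alpha)
    {alpha, sigma}
  end
end

# At t = 0.5 with the default schedule, alpha is about 0.28 and sigma about 0.96.
IO.inspect(VpMarginalSketch.marginal(0.5))
```

At t = 0 the marginal collapses onto x_0 (alpha = 1, sigma = 0), and by t = 1 alpha is below 0.01, so x_1 is close to a standard normal sample.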

vp_schedule(opts \\ [])

@spec vp_schedule(keyword()) :: map()

VP-SDE noise schedule.

beta(t) = beta_min + t * (beta_max - beta_min)

Options

  • :beta_min - Minimum beta (default: 0.1)
  • :beta_max - Maximum beta (default: 20.0)

Returns

Schedule map with perturbation and marginal functions.
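The default range [0.1, 20.0] connects to DDPM when the continuous schedule is discretized into N steps with beta_i = beta(t_i) / N. The sketch below assumes t_i = (i - 1) / (N - 1) and N = 1000, which recovers the familiar linear DDPM schedule from 1e-4 to 0.02. `DdpmLinkSketch` and its functions are hypothetical and do not correspond to anything in Edifice.

```elixir
defmodule DdpmLinkSketch do
  # Hypothetical helper for illustration; not part of Edifice.
  def beta(t, beta_min \\ 0.1, beta_max \\ 20.0) do
    beta_min + t * (beta_max - beta_min)
  end

  # Discrete betas on n steps: beta_i = beta((i - 1) / (n - 1)) / n.
  def discrete_betas(n) when is_integer(n) and n >= 2 do
    for i <- 1..n, do: beta((i - 1) / (n - 1)) / n
  end
end

betas = DdpmLinkSketch.discrete_betas(1000)

# First and last entries are approximately 1.0e-4 and 0.02.
IO.inspect({List.first(betas), List.last(betas)})
```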