Score-based SDE: Unified score matching framework for generative modeling.
Implements the Score SDE framework from "Score-Based Generative Modeling through Stochastic Differential Equations" (Song et al., ICLR 2021). Unifies DDPM, SMLD, and other score-based methods under a single SDE perspective with two main variants:
SDE Variants
VP-SDE (Variance Preserving, generalizes DDPM):
dx = -0.5 * beta(t) * x dt + sqrt(beta(t)) dw
beta(t) = beta_min + t * (beta_max - beta_min)
VE-SDE (Variance Exploding, generalizes SMLD):
dx = sqrt(d[sigma^2(t)]/dt) dw
sigma(t) = sigma_min * (sigma_max/sigma_min)^t
Key Innovation: Score Function
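The score function s(x, t) = grad_x log p_t(x) is learned via denoising score matching. Once learned, samples are generated by integrating backward in time from t = 1 to t = 0 (so dt is negative). In general the reverse-time SDE is dx = [f(x, t) - g(t)^2 * s(x, t)] dt + g(t) dw_rev, where f and g are the drift and diffusion of the forward SDE and w_rev is a reverse-time Wiener process. For the VP-SDE this becomes:
dx = [-0.5 * beta(t) * x - beta(t) * s(x, t)] dt + sqrt(beta(t)) dw_rev
Alternatively, the deterministic probability flow ODE, which has the same marginal densities p_t(x), can be solved:
dx = [-0.5 * beta(t) * x - 0.5 * beta(t) * s(x, t)] dt
Architecture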
     Input (x_t, t)
            |
            v
+-----------------------+
|     Score Network     |
|    s_theta(x_t, t)    |
|                       |
|  x_proj + time_embed  |
|  -> residual blocks   |
|  -> score prediction  |
+-----------------------+
            |
            v
Score [batch, input_dim] (gradient of log density)
Usage
model = ScoreSDE.build(
input_dim: 64,
hidden_size: 256,
num_layers: 4,
sde_type: :vp
)
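# Training: denoising score matching
# score_pred: network output on the perturbed input x_t
# noise: the epsilon used to build x_t; sigma: noise std at time t
loss = ScoreSDE.dsm_loss(score_pred, noise, sigma)
Reference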
- Paper: "Score-Based Generative Modeling through Stochastic Differential Equations"
- arXiv: https://arxiv.org/abs/2011.13456
Types
@type build_opt() :: {:hidden_size, pos_integer()} | {:input_dim, pos_integer()} | {:num_layers, pos_integer()} | {:sde_type, :vp | :ve}
Options for build/1.
Functions
Build a score network s(x, t) for the Score SDE framework.
Options
- :input_dim - Input feature dimension (required)
- :hidden_size - Hidden dimension (default: 256)
- :num_layers - Number of residual blocks (default: 4)
- :sde_type - SDE variant: :vp or :ve (default: :vp)
Returns
An Axon model that predicts the score s(x, t) = grad_x log p_t(x).
@spec dsm_loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Denoising score matching loss.
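For a Gaussian perturbation x_t = x_0 + sigma * epsilon with epsilon ~ N(0, I), the target score at noise level sigma is s(x_t) = -(x_t - x_0) / sigma^2 = -epsilon / sigma. Training minimizes the sigma^2-weighted objective:
E[sigma^2 * ||s_theta(x_t, t) + (x_t - x_0)/sigma^2||^2] = E[||sigma * s_theta(x_t, t) + epsilon||^2]
This is why the loss needs only the predicted score, the noise epsilon, and sigma. For the VP-SDE, x_0 is replaced by alpha(t) * x_0, and the target is still -epsilon / sigma.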
Parameters
- score_pred - Predicted score s_theta(x_t, t)
- noise - The noise added (epsilon)
- sigma - Noise standard deviation at this timestep
Returns
Scalar loss.
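A minimal pure-Elixir sketch of the same objective, assuming plain nested lists (batch x dim) instead of Nx tensors, one sigma per sample, and a mean over all elements as the reduction. The actual reduction used by dsm_loss/3 is not stated here, and the module name DSMLossSketch is hypothetical.

```elixir
defmodule DSMLossSketch do
  # Hypothetical pure-Elixir sketch of denoising score matching on nested
  # lists (batch x dim); not the ScoreSDE implementation. The reduction
  # (mean over every element) is an assumption.

  # Mean of (sigma_i * s_ij + eps_ij)^2 over all samples i and features j;
  # sigma holds one noise level per sample.
  def dsm_loss(score_pred, noise, sigma) do
    squares =
      [score_pred, noise, sigma]
      |> Enum.zip()
      |> Enum.flat_map(fn {s_row, e_row, sig} ->
        Enum.zip_with(s_row, e_row, fn s, e ->
          # residual is zero when s equals the target score -e / sig
          r = sig * s + e
          r * r
        end)
      end)

    Enum.sum(squares) / length(squares)
  end
end
```

Passing score_pred = -noise / sigma gives a loss of zero, which is a quick sanity check for the sign convention.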
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a Score SDE model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count.
@spec recommended_defaults() :: keyword()
Get recommended defaults.
VE-SDE noise schedule.
sigma(t) = sigma_min * (sigma_max / sigma_min)^t
Options
- :sigma_min - Minimum sigma (default: 0.01)
- :sigma_max - Maximum sigma (default: 50.0)
Returns
Schedule map.
@spec ve_sigma(Nx.Tensor.t(), map()) :: Nx.Tensor.t()
Compute the noise level sigma at time t for VE-SDE.
sigma(t) = sigma_min * (sigma_max / sigma_min)^t
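The VE formulas can be checked with a small scalar sketch. VESketch is a hypothetical module, not the ScoreSDE API, and it works on floats and lists rather than tensors. It also derives the VE-SDE diffusion coefficient: with r = sigma_max / sigma_min, d[sigma^2(t)]/dt = 2 * ln(r) * sigma^2(t), so g(t) = sigma(t) * sqrt(2 * ln(r)).

```elixir
defmodule VESketch do
  # Hypothetical scalar sketch of the VE-SDE schedule (not the ScoreSDE API).
  # Defaults mirror the documented VE schedule options: sigma_min 0.01, sigma_max 50.0.

  # sigma(t) = sigma_min * (sigma_max / sigma_min)^t, for t in [0, 1]
  def sigma(t, sigma_min \\ 0.01, sigma_max \\ 50.0) do
    sigma_min * :math.pow(sigma_max / sigma_min, t)
  end

  # Diffusion coefficient g(t) = sqrt(d[sigma^2(t)]/dt).
  # sigma^2(t) = sigma_min^2 * r^(2t) with r = sigma_max / sigma_min, so
  # d[sigma^2]/dt = 2 * ln(r) * sigma^2(t) and g(t) = sigma(t) * sqrt(2 * ln(r)).
  def diffusion(t, sigma_min \\ 0.01, sigma_max \\ 50.0) do
    sigma(t, sigma_min, sigma_max) * :math.sqrt(2 * :math.log(sigma_max / sigma_min))
  end

  # VE perturbation kernel: x_t = x_0 + sigma(t) * eps, with eps ~ N(0, I)
  def perturb(x0, eps, t), do: Enum.zip_with(x0, eps, fn x, e -> x + sigma(t) * e end)
end
```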
@spec vp_marginal(Nx.Tensor.t(), map()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Compute the marginal distribution parameters at time t for VP-SDE.
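For VP-SDE: p(x_t | x_0) = N(x_t; alpha(t) x_0, sigma(t)^2 I), where alpha(t) = exp(-0.5 * integral(beta, 0, t)) and sigma(t) = sqrt(1 - alpha(t)^2). For the linear schedule, integral(beta, 0, t) = beta_min * t + 0.5 * t^2 * (beta_max - beta_min).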
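A scalar sketch of these marginal parameters for the linear beta schedule. VPMarginalSketch is a hypothetical name; returning {alpha, sigma} in that order is only this sketch's convention, whereas vp_marginal/2 takes a tensor and a schedule map.

```elixir
defmodule VPMarginalSketch do
  # Hypothetical scalar sketch (not the ScoreSDE API); defaults mirror the
  # documented VP schedule options beta_min = 0.1 and beta_max = 20.0.

  # Closed form of integral(beta, 0, t) for beta(t) = beta_min + t * (beta_max - beta_min)
  def beta_integral(t, bmin \\ 0.1, bmax \\ 20.0), do: bmin * t + 0.5 * t * t * (bmax - bmin)

  # {alpha(t), sigma(t)}: alpha = exp(-0.5 * integral), sigma = sqrt(1 - alpha^2)
  def marginal(t, bmin \\ 0.1, bmax \\ 20.0) do
    alpha = :math.exp(-0.5 * beta_integral(t, bmin, bmax))
    {alpha, :math.sqrt(1.0 - alpha * alpha)}
  end

  # Sample from p(x_t | x_0) given standard normal noise eps:
  # x_t = alpha(t) * x_0 + sigma(t) * eps
  def perturb(x0, eps, t) do
    {alpha, sigma} = marginal(t)
    Enum.zip_with(x0, eps, fn x, e -> alpha * x + sigma * e end)
  end
end
```

Because alpha(t)^2 + sigma(t)^2 = 1 at every t, a unit-variance x_0 stays at unit variance, which is where the "variance preserving" name comes from.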
VP-SDE noise schedule.
beta(t) = beta_min + t * (beta_max - beta_min)
Options
- :beta_min - Minimum beta (default: 0.1)
- :beta_max - Maximum beta (default: 20.0)
Returns
Schedule map with perturbation and marginal functions.
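To connect the VP schedule back to sampling, the sketch below integrates the probability flow ODE from the module overview with plain Euler steps from t = 1 down to a small t_eps. No trained s_theta is available here, so the score is assumed to be the exact score of a toy 1-D Gaussian data distribution N(0, v0). In that case the ODE solution is x(t) = x(1) * sqrt(v_t / v_1), where v_t is the marginal variance, which gives an analytic check. PFODESketch and all of its functions are hypothetical.

```elixir
defmodule PFODESketch do
  # Hypothetical sketch: Euler integration of the VP-SDE probability flow ODE
  #   dx = [-0.5 * beta(t) * x - 0.5 * beta(t) * s(x, t)] dt
  # from t = 1 down to t = t_eps (dt < 0). A trained s_theta is replaced by the
  # exact score of toy 1-D Gaussian data N(0, v0), an assumption made only so
  # the result can be checked analytically.

  @bmin 0.1
  @bmax 20.0

  def beta(t), do: @bmin + t * (@bmax - @bmin)

  # Marginal variance of x_t when x_0 ~ N(0, v0):
  # v_t = alpha^2 * v0 + (1 - alpha^2), alpha^2 = exp(-integral(beta, 0, t))
  def marginal_var(t, v0) do
    alpha2 = :math.exp(-(@bmin * t + 0.5 * t * t * (@bmax - @bmin)))
    alpha2 * v0 + (1.0 - alpha2)
  end

  # Exact score of N(0, v_t): grad_x log p_t(x) = -x / v_t
  def toy_score(x, t, v0), do: -x / marginal_var(t, v0)

  # Probability flow drift f(x, t) - 0.5 * g(t)^2 * s(x, t) with g^2 = beta
  def pf_drift(x, t, score_fn), do: -0.5 * beta(t) * x - 0.5 * beta(t) * score_fn.(x, t)

  # Plain Euler steps with a negative time step
  def sample(x1, score_fn, steps \\ 1000, t_eps \\ 1.0e-3) do
    dt = -(1.0 - t_eps) / steps

    Enum.reduce(0..(steps - 1), x1, fn i, x ->
      t = 1.0 + i * dt
      x + pf_drift(x, t, score_fn) * dt
    end)
  end
end
```

Starting from x(1) = 1.0 with v0 = 0.25, the result should land close to sqrt(v_t_eps / v_1), roughly 0.5. In a real sampler the score closure would instead call the trained network, and a higher-order ODE solver could replace the Euler steps.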