Edifice.Generative.SoFlow (Edifice v0.2.0)


SoFlow: Solution Flow Models for One-Step Generative Modeling.

Implements the SoFlow framework from "SoFlow: Solution Flow Models for One-Step Generative Modeling" (Luo et al., Princeton, Dec 2025). Instead of learning a velocity field and integrating an ODE at inference (like standard Flow Matching), SoFlow learns the ODE's solution function directly, enabling high-quality one-step generation.

Key Innovation: Solution Function

Standard Flow Matching learns v(x_t, t) and requires multi-step ODE integration at inference. SoFlow learns f(x_t, t, s) — the function that maps state at time t directly to state at time s.

Flow Matching (multi-step):
  x_0 --v(x,0)--> x_{dt} --v(x,dt)--> ... --v(x,1-dt)--> x_1

SoFlow (one-step):
  x_0 --f(x_0, 0, 1)--> x_1

Euler Parameterization

The solution function is parameterized as: f_theta(x_t, t, s) = x_t + (s - t) * F_theta(x_t, t, s)

This automatically satisfies the identity f(x_t, t, t) = x_t. The network F_theta predicts a "normalized velocity" conditioned on both current time t and target time s.
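To make the parameterization concrete, here is a small NumPy sketch (not the library's Elixir implementation; `solution_fn` and the stand-in `net` are illustrative names). It shows that the identity f(x_t, t, t) = x_t holds for any network output, because the (s - t) factor vanishes:

```python
import numpy as np

def solution_fn(x_t, t, s, net):
    # Euler parameterization: f(x_t, t, s) = x_t + (s - t) * F_theta(x_t, t, s)
    return x_t + (s - t) * net(x_t, t, s)

# Stand-in for the trained network F_theta: an arbitrary smooth function.
net = lambda x, t, s: np.sin(x) + (s - t)

x_t = np.array([0.5, -1.0, 2.0])

# Identity at t == s holds regardless of what the network outputs.
assert np.allclose(solution_fn(x_t, 0.3, 0.3, net), x_t)
```

With t = 0 and s = 1 the same formula reduces to the one-step rule used at inference: x_generated = x_noise + F_theta(x_noise, 0, 1).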

Training Loss

Combined from two components:

  1. Flow Matching Loss (L_FM): Ensures network derivatives match true velocity at the t = s boundary
  2. Solution Consistency Loss (L_SCM): Enforces self-consistency of the solution function across different time intervals, using an EMA target network

L = lambda * L_FM + (1 - lambda) * L_SCM

Architecture

Same as Flow Matching, but the velocity network takes two time inputs (current t and target s) instead of one:

Inputs: (x_t, t, s, observations)
      |
      v
[Time Embeddings] t_embed + s_embed + obs_embed + x_embed
      |
      v
[Residual MLP Blocks x num_layers]
      |
      v
F_theta(x_t, t, s) -- "normalized velocity"

One-step inference: x_generated = x_noise + F_theta(x_noise, 0, 1)

Comparison

Method              Steps   Distillation?   Quality
Flow Matching       20-50   No              High
Consistency Model   1       Optional        Medium
SoFlow              1-2     No              High

Usage

model = SoFlow.build(
  obs_size: 287,
  action_dim: 64,
  action_horizon: 8
)

# One-step generation
x_generated = SoFlow.one_step_sample(params, predict_fn, observations, noise)

# Two-step refinement
x_refined = SoFlow.multi_step_sample(params, predict_fn, observations, noise, steps: 2)

References

  • Luo et al., "SoFlow: Solution Flow Models for One-Step Generative Modeling", Dec 2025.

Summary

Types

  • build_opt() — Options for build/1.

Functions

  • build/1 — Build a SoFlow model for one-step generative modeling.
  • combined_loss/3 — Combined SoFlow training loss.
  • consistency_loss/2 — Solution Consistency loss component (L_SCM).
  • euler_parameterize/4 — Apply the Euler parameterization to get the solution function value.
  • flow_matching_loss/2 — Flow Matching loss component (L_FM).
  • interpolate/3 — Linear interpolation between noise and data.
  • multi_step_sample/5 — Multi-step generation for improved quality.
  • one_step_sample/4 — One-step generation using the solution function.
  • output_size/1 — Get the output size of a SoFlow model.
  • Get recommended defaults.

Types

build_opt()

@type build_opt() ::
  {:action_dim, pos_integer()}
  | {:action_horizon, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:obs_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a SoFlow model for one-step generative modeling.

The key difference from Flow Matching: this network takes two time inputs — current time t and target time s — enabling direct solution function learning.

Options

  • :obs_size - Size of observation/conditioning embedding (required)
  • :action_dim - Dimension of action/data space (required)
  • :action_horizon - Number of actions per sequence (default: 8)
  • :hidden_size - Hidden dimension (default: 256)
  • :num_layers - Number of MLP layers (default: 4)

Returns

An Axon model: (x_t, t, s, observations) -> F_theta(x_t, t, s)

combined_loss(fm_loss, scm_loss, lambda \\ 0.5)

@spec combined_loss(Nx.Tensor.t(), Nx.Tensor.t(), float()) :: Nx.Tensor.t()

Combined SoFlow training loss.

L = lambda * L_FM + (1 - lambda) * L_SCM

Parameters

  • fm_loss - Flow matching loss
  • scm_loss - Solution consistency loss
  • lambda - Mixing ratio (default: 0.5)
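The mixing is a plain convex combination, so lambda = 1 recovers pure flow matching and lambda = 0 pure solution consistency. A NumPy sketch of the arithmetic (illustrative, not the library's code):

```python
def combined_loss(fm_loss, scm_loss, lam=0.5):
    # L = lambda * L_FM + (1 - lambda) * L_SCM
    return lam * fm_loss + (1.0 - lam) * scm_loss

# Equal weighting averages the two components.
assert combined_loss(2.0, 4.0) == 3.0
# lam = 1.0 keeps only the flow matching term.
assert combined_loss(2.0, 4.0, lam=1.0) == 2.0
```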

consistency_loss(f_current, f_target)

@spec consistency_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Solution Consistency loss component (L_SCM).

Enforces that the solution function is self-consistent: starting from different points on the same trajectory should yield the same endpoint.

Uses a Taylor step to advance from t to l, then checks consistency: f(x_t, t, s) should equal f(x_t + v*(l-t), l, s)

The target uses an EMA (stop-gradient) network for stability.

Parameters

  • f_current - f_theta(x_t, t, s) = x_t + (s-t) * F_theta(x_t, t, s)
  • f_target - stop_grad(f_ema(x_stepped, l, s)) from EMA network
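To illustrate how the two sides of the check could be assembled, here is a NumPy sketch assuming a squared-error metric; the EMA network is stood in by a plain function (in training it would be a stop-gradient copy), and all names are illustrative:

```python
import numpy as np

def f(x, t, s, net):
    # Euler parameterization of the solution function.
    return x + (s - t) * net(x, t, s)

def consistency_loss(x_t, t, l, s, net, ema_net, v):
    # Current prediction: jump straight from t to s.
    f_current = f(x_t, t, s, net)
    # Taylor step with velocity v advances x_t from t to the midpoint l ...
    x_stepped = x_t + v * (l - t)
    # ... then the EMA network jumps from l to s (stop-gradient in training).
    f_target = f(x_stepped, l, s, ema_net)
    return np.mean((f_current - f_target) ** 2)

# Sanity check: with a constant-velocity field (network always outputs 1)
# both paths land on the same endpoint, so the loss is exactly zero.
net = lambda x, t, s: np.ones_like(x)
x = np.array([0.0, 1.0])
assert np.isclose(consistency_loss(x, 0.0, 0.5, 1.0, net, net, np.ones_like(x)), 0.0)
```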

euler_parameterize(x_t, f_theta, t, s)

@spec euler_parameterize(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) ::
  Nx.Tensor.t()

Apply the Euler parameterization to get the solution function value.

f(x_t, t, s) = x_t + (s - t) * F_theta(x_t, t, s)

Parameters

  • x_t - Current state [batch, horizon, dim]
  • f_theta - Network output (normalized velocity) [batch, horizon, dim]
  • t - Current time [batch]
  • s - Target time [batch]

flow_matching_loss(f_theta_at_t_t, velocity_target)

@spec flow_matching_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Flow Matching loss component (L_FM).

Ensures the network's behavior at the t = s boundary matches the true velocity. This grounds the solution function to the underlying ODE.

Parameters

  • f_theta_at_t_t - F_theta(x_t, t, t) (network output when s = t)
  • velocity_target - True velocity: alpha_t' x_0 + beta_t' x_1

For linear interpolation (alpha_t = 1-t, beta_t = t): velocity_target = x_1 - x_0
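A NumPy sketch of this component, assuming a squared-error metric (illustrative names, not the library's implementation). For the linear path the derivative of x_t = (1 - t) x_0 + t x_1 with respect to t is -x_0 + x_1, which is where the x_1 - x_0 target comes from:

```python
import numpy as np

def flow_matching_loss(f_theta_at_t_t, velocity_target):
    # MSE between the network output at s == t and the true velocity.
    return np.mean((f_theta_at_t_t - velocity_target) ** 2)

x_0 = np.array([0.0, 1.0])   # noise sample
x_1 = np.array([2.0, 3.0])   # data sample

# Linear path (alpha_t = 1 - t, beta_t = t): the target is x_1 - x_0.
target = x_1 - x_0
assert flow_matching_loss(target, target) == 0.0
```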

interpolate(x_0, x_1, t)

@spec interpolate(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Linear interpolation between noise and data.

x_t = (1 - t) * x_0 + t * x_1
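The endpoints behave as expected: t = 0 recovers the noise sample and t = 1 the data sample. A minimal NumPy sketch:

```python
import numpy as np

def interpolate(x_0, x_1, t):
    # x_t = (1 - t) * x_0 + t * x_1
    return (1.0 - t) * x_0 + t * x_1

x_0, x_1 = np.zeros(3), np.ones(3)
assert np.allclose(interpolate(x_0, x_1, 0.0), x_0)  # t = 0 -> noise
assert np.allclose(interpolate(x_0, x_1, 1.0), x_1)  # t = 1 -> data
```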

multi_step_sample(params, predict_fn, observations, noise, opts \\ [])

@spec multi_step_sample(
  map(),
  (map(), map() -> Nx.Tensor.t()),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  keyword()
) :: Nx.Tensor.t()

Multi-step generation for improved quality.

Divides [0, 1] into N equal segments and applies the solution function sequentially on each segment.

Parameters

  • params - Model parameters
  • predict_fn - The compiled prediction function
  • observations - Conditioning [batch, obs_size]
  • noise - Initial noise [batch, action_horizon, action_dim]
  • opts:
    • :steps - Number of steps (default: 2)
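The segment-by-segment schedule can be sketched as follows in NumPy (a stand-in `net` replaces the trained network; names are illustrative, not the library's code):

```python
import numpy as np

def f(x, t, s, net):
    # Euler parameterization of the solution function.
    return x + (s - t) * net(x, t, s)

def multi_step_sample(noise, net, steps=2):
    # Split [0, 1] into `steps` equal segments and apply the
    # solution function sequentially on each one.
    x = noise
    ts = np.linspace(0.0, 1.0, steps + 1)
    for t, s in zip(ts[:-1], ts[1:]):
        x = f(x, t, s, net)
    return x

# With a network that ignores its inputs, any number of steps
# lands on the same endpoint as a single step.
net = lambda x, t, s: np.full_like(x, 2.0)
noise = np.zeros(4)
assert np.allclose(multi_step_sample(noise, net, steps=1),
                   multi_step_sample(noise, net, steps=4))
```

More steps trade compute for quality by letting the network correct itself at intermediate times; the paper's regime of 1-2 steps is the intended use.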

one_step_sample(params, predict_fn, observations, noise)

@spec one_step_sample(
  map(),
  (map(), map() -> Nx.Tensor.t()),
  Nx.Tensor.t(),
  Nx.Tensor.t()
) ::
  Nx.Tensor.t()

One-step generation using the solution function.

x_generated = x_noise + F_theta(x_noise, 0, 1)

Parameters

  • params - Model parameters
  • predict_fn - The compiled prediction function
  • observations - Conditioning [batch, obs_size]
  • noise - Initial noise [batch, action_horizon, action_dim]

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a SoFlow model.