SoFlow: Solution Flow Models for One-Step Generative Modeling.
Implements the SoFlow framework from "SoFlow: Solution Flow Models for One-Step Generative Modeling" (Luo et al., Princeton, Dec 2025). Instead of learning a velocity field and integrating an ODE at inference (like standard Flow Matching), SoFlow learns the ODE's solution function directly, enabling high-quality one-step generation.
Key Innovation: Solution Function
Standard Flow Matching learns v(x_t, t) and requires multi-step ODE
integration at inference. SoFlow learns f(x_t, t, s) — the function
that maps state at time t directly to state at time s.
Flow Matching (multi-step):
    x_0 --v(x,0)--> x_{dt} --v(x,dt)--> ... --v(x,1-dt)--> x_1
SoFlow (one-step):
    x_0 --f(x_0, 0, 1)--> x_1
Euler Parameterization
The solution function is parameterized as:
f_theta(x_t, t, s) = x_t + (s - t) * F_theta(x_t, t, s)
This automatically satisfies the identity f(x_t, t, t) = x_t.
The network F_theta predicts a "normalized velocity" conditioned on
both current time t and target time s.
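The identity property can be checked with a minimal sketch in plain Elixir on a scalar state. `EulerParamSketch` and `toy_f_theta/3` are hypothetical names introduced only for illustration; the toy function stands in for the trained network and is not part of the SoFlow module.

```elixir
defmodule EulerParamSketch do
  # Hypothetical stand-in for the trained network F_theta; any function works here.
  def toy_f_theta(x_t, t, s), do: :math.sin(x_t) + t - 2.0 * s

  # f_theta(x_t, t, s) = x_t + (s - t) * F_theta(x_t, t, s)
  def solution(f_theta, x_t, t, s) do
    x_t + (s - t) * f_theta.(x_t, t, s)
  end
end

x_t = 0.7
f_theta = &EulerParamSketch.toy_f_theta/3

# With s == t the correction term vanishes, so the state comes back unchanged.
same = EulerParamSketch.solution(f_theta, x_t, 0.3, 0.3)

# With s != t the network output is scaled by the interval length (s - t).
moved = EulerParamSketch.solution(f_theta, x_t, 0.0, 1.0)

IO.inspect({same, moved})
```

Because the identity holds for any network output, it does not need to be learned; training only has to shape F_theta for s different from t.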
Training Loss
Combined from two components:
- Flow Matching Loss (L_FM): Ensures network derivatives match the true velocity at the t = s boundary
- Solution Consistency Loss (L_SCM): Enforces self-consistency of the solution function across different time intervals, using an EMA target network
L = lambda * L_FM + (1 - lambda) * L_SCM
Architecture
Same as Flow Matching, but the velocity network takes two time
inputs (current t and target s) instead of one:
Inputs: (x_t, t, s, observations)
      |
      v
[Time Embeddings] t_embed + s_embed + obs_embed + x_embed
      |
      v
[Residual MLP Blocks x num_layers]
      |
      v
F_theta(x_t, t, s) -- "normalized velocity"
One-step inference: x_generated = x_noise + F_theta(x_noise, 0, 1)
Comparison
| Method | Steps | Distillation? | Quality |
|---|---|---|---|
| Flow Matching | 20-50 | No | High |
| Consistency Model | 1 | Optional | Medium |
| SoFlow | 1-2 | No | High |
Usage
model = SoFlow.build(
  obs_size: 287,
  action_dim: 64,
  action_horizon: 8
)

# `params` and `predict_fn` are assumed to come from compiling and training
# the model (for example with Axon.build/2); they are not defined here.

# One-step generation
x_generated = SoFlow.one_step_sample(params, predict_fn, observations, noise)

# Two-step refinement
x_refined = SoFlow.multi_step_sample(params, predict_fn, observations, noise, steps: 2)
Summary
Functions
- build/1: Build a SoFlow model for one-step generative modeling.
- combined_loss/3: Combined SoFlow training loss.
- consistency_loss/2: Solution Consistency loss component (L_SCM).
- euler_parameterize/4: Apply the Euler parameterization to get the solution function value.
- flow_matching_loss/2: Flow Matching loss component (L_FM).
- interpolate/3: Linear interpolation between noise and data.
- multi_step_sample/5: Multi-step generation for improved quality.
- one_step_sample/4: One-step generation using the solution function.
- output_size/1: Get the output size of a SoFlow model.
- recommended_defaults/0: Get recommended defaults.
Types
@type build_opt() :: {:action_dim, pos_integer()} | {:action_horizon, pos_integer()} | {:hidden_size, pos_integer()} | {:num_layers, pos_integer()} | {:obs_size, pos_integer()}
Options for build/1.
Functions
Build a SoFlow model for one-step generative modeling.
The key difference from Flow Matching: this network takes two time
inputs — current time t and target time s — enabling direct solution
function learning.
Options
- :obs_size - Size of observation/conditioning embedding (required)
- :action_dim - Dimension of action/data space (required)
- :action_horizon - Number of actions per sequence (default: 8)
- :hidden_size - Hidden dimension (default: 256)
- :num_layers - Number of MLP layers (default: 4)
Returns
An Axon model: (x_t, t, s, observations) -> F_theta(x_t, t, s)
@spec combined_loss(Nx.Tensor.t(), Nx.Tensor.t(), float()) :: Nx.Tensor.t()
Combined SoFlow training loss.
L = lambda * L_FM + (1 - lambda) * L_SCM
Parameters
- fm_loss - Flow matching loss
- scm_loss - Solution consistency loss
- lambda - Mixing ratio (default: 0.5)
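The weighting can be illustrated with a small plain-Elixir sketch on scalar loss values. `CombinedLossSketch` is a hypothetical module, and using mean squared error for the two components is an assumption made for the example; the text above does not state which distance the components use.

```elixir
defmodule CombinedLossSketch do
  # Mean squared error over flat lists. MSE is an assumed stand-in for the
  # two loss components; the documented functions may use another distance.
  def mse(pred, target) do
    squared = for {p, y} <- Enum.zip(pred, target), do: (p - y) * (p - y)
    Enum.sum(squared) / length(squared)
  end

  # L = lambda * L_FM + (1 - lambda) * L_SCM
  def combine(fm_loss, scm_loss, lambda \\ 0.5) do
    lambda * fm_loss + (1.0 - lambda) * scm_loss
  end
end

fm_loss = CombinedLossSketch.mse([1.0, 2.0], [1.0, 4.0])   # (0 + 4) / 2
scm_loss = CombinedLossSketch.mse([0.0, 0.0], [1.0, 1.0])  # (1 + 1) / 2

balanced = CombinedLossSketch.combine(fm_loss, scm_loss)       # default lambda = 0.5
fm_only = CombinedLossSketch.combine(fm_loss, scm_loss, 1.0)   # only L_FM contributes
scm_only = CombinedLossSketch.combine(fm_loss, scm_loss, 0.0)  # only L_SCM contributes

IO.inspect({balanced, fm_only, scm_only})
```

Setting lambda to 1.0 or 0.0 recovers each component on its own, which can help when checking the two terms separately.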
@spec consistency_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Solution Consistency loss component (L_SCM).
Enforces that the solution function is self-consistent: starting from different points on the same trajectory should yield the same endpoint.
Uses a first-order Taylor step along the velocity v to advance from time t to time l, then checks consistency:
    f(x_t, t, s) should equal f(x_t + v*(l-t), l, s)
The target uses an EMA (stop-gradient) network for stability.
Parameters
- f_current - f_theta(x_t, t, s) = x_t + (s-t) * F_theta(x_t, t, s)
- f_target - stop_grad(f_ema(x_stepped, l, s)) from EMA network
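How the two arguments could be assembled is sketched below in plain Elixir on a scalar state. Everything here is illustrative: `ConsistencySketch` is a hypothetical module, the squared-error distance is an assumption, and the same function plays both the online and the EMA network, since stop-gradient and EMA weights have no counterpart in this toy setting.

```elixir
defmodule ConsistencySketch do
  # Euler parameterization: f(x, t, s) = x + (s - t) * F(x, t, s)
  def solution(f_net, x, t, s), do: x + (s - t) * f_net.(x, t, s)

  # Squared difference between online prediction and frozen target.
  # The choice of squared error is an assumption made for this sketch.
  def loss(f_current, f_target) do
    diff = f_current - f_target
    diff * diff
  end

  # Builds both sides of the consistency check for one (t, l, s) triple.
  def check(f_net, x_t, v, t, l, s) do
    f_current = solution(f_net, x_t, t, s)
    x_stepped = x_t + v * (l - t)              # Taylor step from t to l
    f_target = solution(f_net, x_stepped, l, s) # would be the EMA network in training
    loss(f_current, f_target)
  end
end

x_0 = -1.0                         # noise endpoint
x_1 = 2.0                          # data endpoint
v = x_1 - x_0                      # velocity of the linear path
{t, l, s} = {0.2, 0.3, 1.0}
x_t = (1.0 - t) * x_0 + t * x_1

# A network that outputs the true velocity is self-consistent.
exact_net = fn _x, _t, _s -> v end
# A deliberately wrong network violates the consistency condition.
wrong_net = fn x, _t, _s -> v + 0.5 * x end

exact_loss = ConsistencySketch.check(exact_net, x_t, v, t, l, s)
wrong_loss = ConsistencySketch.check(wrong_net, x_t, v, t, l, s)

IO.inspect({exact_loss, wrong_loss})
```

The exact network lands on the same endpoint from both starting points, so its loss is zero up to floating-point rounding, while the perturbed network is penalized.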
@spec euler_parameterize(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Apply the Euler parameterization to get the solution function value.
f(x_t, t, s) = x_t + (s - t) * F_theta(x_t, t, s)
Parameters
- x_t - Current state [batch, horizon, dim]
- f_theta - Network output (normalized velocity) [batch, horizon, dim]
- t - Current time [batch]
- s - Target time [batch]
@spec flow_matching_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Flow Matching loss component (L_FM).
Ensures the network's behavior at the t = s boundary matches the true
velocity. This grounds the solution function to the underlying ODE.
Parameters
- f_theta_at_t_t - F_theta(x_t, t, t) (network output when s = t)
- velocity_target - True velocity: alpha_t' x_0 + beta_t' x_1. For linear interpolation (alpha_t = 1-t, beta_t = t): velocity_target = x_1 - x_0
@spec interpolate(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Linear interpolation between noise and data.
x_t = (1 - t) * x_0 + t * x_1
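The formula and the velocity it implies can be sketched in plain Elixir on flat lists. `InterpolateSketch` is a hypothetical module used only for illustration; the documented function operates on Nx tensors instead.

```elixir
defmodule InterpolateSketch do
  # x_t = (1 - t) * x_0 + t * x_1, applied elementwise to flat lists.
  def interpolate(x_0, x_1, t) do
    for {a, b} <- Enum.zip(x_0, x_1), do: (1.0 - t) * a + t * b
  end

  # Time derivative of the linear path: d/dt x_t = x_1 - x_0, independent of t.
  def velocity(x_0, x_1) do
    for {a, b} <- Enum.zip(x_0, x_1), do: b - a
  end
end

x_0 = [0.0, -1.0]   # noise sample
x_1 = [2.0, 3.0]    # data sample

start = InterpolateSketch.interpolate(x_0, x_1, 0.0)     # equals x_0
midpoint = InterpolateSketch.interpolate(x_0, x_1, 0.5)  # halfway between the two
finish = InterpolateSketch.interpolate(x_0, x_1, 1.0)    # equals x_1
velocity = InterpolateSketch.velocity(x_0, x_1)

IO.inspect({start, midpoint, finish, velocity})
```

The path starts at the noise sample at t = 0 and ends at the data sample at t = 1, which matches the direction of one-step generation from time 0 to time 1.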
@spec multi_step_sample( map(), (map(), map() -> Nx.Tensor.t()), Nx.Tensor.t(), Nx.Tensor.t(), keyword() ) :: Nx.Tensor.t()
Multi-step generation for improved quality.
Divides [0, 1] into N equal segments and applies the solution function sequentially on each segment.
Parameters
- params - Model parameters
- predict_fn - The compiled prediction function
- observations - Conditioning [batch, obs_size]
- noise - Initial noise [batch, action_horizon, action_dim]
- opts - Options; :steps sets the number of steps (default: 2)
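The segment-by-segment procedure can be sketched in plain Elixir on a scalar state. `MultiStepSketch` and `exact_f` are hypothetical: the toy network encodes the closed-form solution of the ODE dx/dt = -x, chosen only because its solution function is known exactly, and conditioning on observations is left out.

```elixir
defmodule MultiStepSketch do
  # Euler parameterization: f(x, t, s) = x + (s - t) * F(x, t, s)
  def solution(f_net, x, t, s), do: x + (s - t) * f_net.(x, t, s)

  # Split [0, 1] into `steps` equal segments and apply the solution
  # function on each segment in order, feeding each result into the next.
  def sample(f_net, noise, steps) when is_integer(steps) and steps >= 1 do
    Enum.reduce(0..(steps - 1), noise, fn i, x ->
      t = i / steps
      s = (i + 1) / steps
      solution(f_net, x, t, s)
    end)
  end
end

# Toy F for dx/dt = -x: the exact solution is f(x, t, s) = x * exp(-(s - t)),
# so F(x, t, s) = x * (exp(-(s - t)) - 1) / (s - t). Only valid for s != t.
exact_f = fn x, t, s -> x * (:math.exp(-(s - t)) - 1.0) / (s - t) end

one_step = MultiStepSketch.sample(exact_f, 1.0, 1)
two_step = MultiStepSketch.sample(exact_f, 1.0, 2)
four_step = MultiStepSketch.sample(exact_f, 1.0, 4)

IO.inspect({one_step, two_step, four_step})
```

With an exact solution function every step count reaches the same endpoint, so extra steps only help when the learned network is imperfect. Setting `steps` to 1 reduces to the one-step case.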
@spec one_step_sample( map(), (map(), map() -> Nx.Tensor.t()), Nx.Tensor.t(), Nx.Tensor.t() ) :: Nx.Tensor.t()
One-step generation using the solution function.
x_generated = x_noise + F_theta(x_noise, 0, 1)
Parameters
- params - Model parameters
- predict_fn - The compiled prediction function
- observations - Conditioning [batch, obs_size]
- noise - Initial noise [batch, action_horizon, action_dim]
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a SoFlow model.
@spec recommended_defaults() :: keyword()
Get recommended defaults.