Flow Matching: Action generation via continuous normalizing flows.
Implements Conditional Flow Matching (CFM) from "Flow Matching for Generative Modeling" (Lipman et al., ICLR 2023). Learns a velocity field that transports samples from noise to data via an ODE.
Key Innovation: Optimal Transport Paths
Instead of a diffusion-style noise schedule, Flow Matching uses straight-line interpolation between a noise sample and a data sample (the conditional optimal transport path):
x_t = (1 - t) * x_0 + t * x_1 where x_0 ~ noise, x_1 ~ data
v_target = x_1 - x_0 (constant velocity along path)
Training minimizes: ||v_theta(x_t, t | obs) - v_target||^2
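The constant target velocity follows from differentiating the interpolation path with respect to t. The block below writes this out together with the objective as an expectation. It assumes t is drawn uniformly from [0, 1] and x_0 from a standard Gaussian, which is the common choice in the flow matching literature; this section does not state which distributions the module actually uses.

```latex
\begin{aligned}
x_t &= (1 - t)\,x_0 + t\,x_1, \qquad t \in [0, 1] \\
\frac{\mathrm{d}x_t}{\mathrm{d}t} &= x_1 - x_0 \\
\mathcal{L}(\theta) &= \mathbb{E}_{t,\,x_0,\,(x_1,\,\mathrm{obs})}
  \left\lVert v_\theta(x_t, t \mid \mathrm{obs}) - (x_1 - x_0) \right\rVert^2
\end{aligned}
```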
Comparison with Diffusion
| Feature | Diffusion | Flow Matching |
|---|---|---|
| Path | Stochastic (SDE) | Deterministic (ODE) |
| Schedule | Complex (beta schedule) | None needed |
| Training | Noise prediction | Velocity prediction |
| Inference | DDPM/DDIM sampling | ODE integration |
| Steps | 20-100+ typical | 10-20 often sufficient |
Architecture
       Observations [batch, obs_dim]
                   |
                   v
+-------------------------------------+
|         Observation Encoder         |
+-------------------------------------+
                   |
                   v  obs_embed
+-------------------------------------+
|          Velocity Network           |
|  Input:  (x_t, t, obs_embed)        |
|  Output: v_theta (velocity field)   |
+-------------------------------------+
                   |
                   v
Actions [batch, action_horizon, action_dim]
Training
# Forward process (create interpolated sample)
x_t = FlowMatching.interpolate(noise, actions, t)
target_velocity = FlowMatching.target_velocity(noise, actions)
# Predict velocity
pred_velocity = velocity_network(x_t, t, observations)
# MSE loss (argument order: predicted, then target)
loss = FlowMatching.velocity_loss(pred_velocity, target_velocity)
Inference (ODE Integration)
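# Start from noise (random_noise/0 and velocity_network/3 are placeholders)
num_steps = 20
dt = 1.0 / num_steps
x_0 = random_noise()
# Euler integration (or higher order): x_{t+dt} = x_t + dt * v
x_1 =
  Enum.reduce(0..(num_steps - 1), x_0, fn i, x_t ->
    v = velocity_network(x_t, i * dt, observations)
    Nx.add(x_t, Nx.multiply(dt, v))
  end)
# x_1 is the generated action
Usage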
# Build flow matching model
model = FlowMatching.build(
  obs_size: 287,
  action_dim: 64,
  action_horizon: 8
)
# Training
loss = FlowMatching.compute_loss(
  params, predict_fn, observations, actions, noise, t
)
# Inference (initial_noise: [batch, action_horizon, action_dim])
actions = FlowMatching.sample(
  params, predict_fn, observations, initial_noise,
  num_steps: 20, solver: :euler
)
References
- Flow Matching: https://arxiv.org/abs/2210.02747
- Conditional Flow Matching: https://arxiv.org/abs/2302.00482
- Rectified Flow: https://arxiv.org/abs/2209.03003
Summary
Functions
Build a Flow Matching model for action generation.
Compute the complete training loss.
Default action prediction horizon
Default hidden dimension
Default number of network layers
Default number of ODE integration steps
Default ODE solver
Get fast inference configuration.
Generate training pairs for rectified flow.
Interpolate between noise (x_0) and data (x_1) at time t.
Get the output size of a flow matching model.
Calculate approximate parameter count for a flow matching model.
Get high-quality configuration (more steps, better solver).
Get recommended defaults for action generation.
Compute rectified flow loss with distillation.
Sample actions by integrating the learned ODE.
Sample with classifier-free guidance.
Compute the target velocity for training.
Compute velocity matching loss (MSE).
Types
@type build_opt() :: {:action_dim, pos_integer()} | {:action_horizon, pos_integer()} | {:hidden_size, pos_integer()} | {:num_layers, pos_integer()} | {:obs_size, pos_integer()}
Options for build/1.
Functions
Build a Flow Matching model for action generation.
Options
- :obs_size - Size of observation embedding (required)
- :action_dim - Dimension of action space (required)
- :action_horizon - Number of actions per sequence (default: 8)
- :hidden_size - Hidden dimension (default: 256)
- :num_layers - Number of MLP layers (default: 4)
Returns
An Axon model that predicts velocity given (x_t, t, obs).
@spec compute_loss( map(), (map(), map() -> Nx.Tensor.t()), Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t() ) :: Nx.Tensor.t()
Compute the complete training loss.
Parameters
- params - Model parameters
- predict_fn - Velocity prediction function
- observations - Conditioning observations [batch, obs_size]
- actions - Target actions (x_1) [batch, action_horizon, action_dim]
- noise - Source noise (x_0) [batch, action_horizon, action_dim]
- t - Random timesteps [batch]
Returns
Scalar loss value.
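As a rough illustration of how these pieces could fit together, the sketch below reimplements the loss pipeline (interpolate, target velocity, MSE) in plain Nx. `CFMLossSketch` is a hypothetical module rather than the library's implementation. The `predict` argument is a simplified stand-in that takes `(x_t, t, observations)` directly instead of the `params`/`predict_fn` pair, and the Nx version constraint is only an assumption.

```elixir
# Assumes Nx can be fetched; the version constraint is illustrative.
Mix.install([{:nx, "~> 0.7"}])

defmodule CFMLossSketch do
  # Hypothetical helper module, not the library's implementation.

  # x_t = (1 - t) * x_0 + t * x_1, with t of shape [batch] reshaped to
  # [batch, 1, 1] so it broadcasts over [batch, action_horizon, action_dim].
  def interpolate(x_0, x_1, t) do
    {batch, _horizon, _dim} = Nx.shape(x_0)
    t_b = Nx.reshape(t, {batch, 1, 1})
    Nx.add(Nx.multiply(Nx.subtract(1.0, t_b), x_0), Nx.multiply(t_b, x_1))
  end

  # Mean squared error between predicted and target velocity.
  def velocity_loss(pred, target) do
    pred |> Nx.subtract(target) |> Nx.pow(2) |> Nx.mean()
  end

  # predict is any function (x_t, t, observations) -> velocity tensor.
  def loss(predict, observations, actions, noise, t) do
    x_t = interpolate(noise, actions, t)
    target = Nx.subtract(actions, noise)
    velocity_loss(predict.(x_t, t, observations), target)
  end
end

# Toy batch: 4 samples, horizon 8, action_dim 2, obs_size 3.
key = Nx.Random.key(0)
{actions, key} = Nx.Random.normal(key, 0.0, 1.0, shape: {4, 8, 2})
{noise, key} = Nx.Random.normal(key, 0.0, 1.0, shape: {4, 8, 2})
{t, _key} = Nx.Random.uniform(key, 0.0, 1.0, shape: {4})
observations = Nx.broadcast(0.0, {4, 3})

# An oracle that returns the true target velocity gives zero loss; a predictor
# that always outputs zeros gives the mean squared target velocity.
oracle = fn _x_t, _t, _obs -> Nx.subtract(actions, noise) end
zero = fn x_t, _t, _obs -> Nx.multiply(x_t, 0.0) end

IO.inspect(Nx.to_number(CFMLossSketch.loss(oracle, observations, actions, noise, t)))
IO.inspect(Nx.to_number(CFMLossSketch.loss(zero, observations, actions, noise, t)))
```

The reshape of `t` is the one detail that is easy to get wrong: without it, a `[batch]` tensor does not broadcast against `[batch, action_horizon, action_dim]` in the intended way.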
@spec default_action_horizon() :: pos_integer()
Default action prediction horizon
@spec default_num_layers() :: pos_integer()
Default number of network layers
@spec default_num_steps() :: pos_integer()
Default number of ODE integration steps
@spec default_solver() :: atom()
Default ODE solver
@spec fast_inference_defaults() :: keyword()
Get fast inference configuration.
@spec generate_rectified_pairs( map(), (map(), map() -> Nx.Tensor.t()), Nx.Tensor.t(), Nx.Tensor.t(), keyword() ) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Generate training pairs for rectified flow.
Samples (noise, generated_action) pairs from a trained model to create straighter paths.
@spec interpolate(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Interpolate between noise (x_0) and data (x_1) at time t.
Uses optimal transport (linear) interpolation: x_t = (1 - t) x_0 + t x_1
Parameters
- x_0 - Source (noise) [batch, action_horizon, action_dim]
- x_1 - Target (data/actions) [batch, action_horizon, action_dim]
- t - Time in [0, 1] [batch]
Returns
Interpolated x_t with same shape as inputs.
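A small worked example may help make the formula concrete. The numbers below are chosen purely for illustration, using a single two-dimensional action.

```latex
\begin{aligned}
x_0 &= (0,\ 2), \quad x_1 = (4,\ -2), \quad t = 0.25 \\
x_t &= 0.75\,(0,\ 2) + 0.25\,(4,\ -2) = (1,\ 1) \\
v_{\text{target}} &= x_1 - x_0 = (4,\ -4)
\end{aligned}
```

The same point is reached by moving from x_0 along the target velocity for a quarter of the way: (0, 2) + 0.25 (4, -4) = (1, 1).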
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a flow matching model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for a flow matching model.
@spec quality_defaults() :: keyword()
Get high-quality configuration (more steps, better solver).
@spec recommended_defaults() :: keyword()
Get recommended defaults for action generation.
@spec rectified_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Compute rectified flow loss with distillation.
Rectified flow straightens the ODE paths by distilling from a trained flow model, allowing even fewer integration steps.
This is a two-stage process:
1. Train standard flow matching
2. Generate (x_0, x_1) pairs by sampling, then retrain on straight paths
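In the notation used elsewhere on this page, the two stages can be written roughly as follows. This follows the reflow procedure from the Rectified Flow paper listed in the references. Whether `rectified_loss/2` computes exactly this objective is an assumption, and the symbol for the retrained network (v_phi) is introduced here only for illustration.

```latex
\begin{aligned}
\hat{x}_1 &= x_0 + \int_0^1 v_\theta(x_t, t \mid \mathrm{obs})\,\mathrm{d}t,
  \qquad x_0 \sim \text{noise} \\
\mathcal{L}_{\text{rect}}(\phi) &= \mathbb{E}_{t,\,x_0}
  \left\lVert v_\phi\big((1 - t)\,x_0 + t\,\hat{x}_1,\ t \mid \mathrm{obs}\big)
  - (\hat{x}_1 - x_0) \right\rVert^2
\end{aligned}
```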
@spec sample( map(), (map(), map() -> Nx.Tensor.t()), Nx.Tensor.t(), Nx.Tensor.t(), keyword() ) :: Nx.Tensor.t()
Sample actions by integrating the learned ODE.
Solves dx/dt = v_theta(x, t | obs) from t=0 (noise) to t=1 (action).
Parameters
- params - Model parameters
- predict_fn - Velocity prediction function
- observations - Conditioning [batch, obs_size]
- initial_noise - Starting noise (x_0) [batch, action_horizon, action_dim]
- opts - Options:
  - :num_steps - Integration steps (default: 20)
  - :solver - ODE solver: :euler, :midpoint, :rk4 (default: :euler)
Returns
Generated actions [batch, action_horizon, action_dim].
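To make the solver options concrete, here is a minimal sketch of fixed-step Euler, midpoint, and RK4 integration over Nx tensors. `ODESketch` and its `integrate/4` function are hypothetical names, and the velocity function takes only `(x, t)`, with the observation conditioning left out for brevity. The library's own `sample` may be structured differently.

```elixir
# Assumes Nx can be fetched; the version constraint is illustrative.
Mix.install([{:nx, "~> 0.7"}])

defmodule ODESketch do
  # Hypothetical module; velocity is any function (x, t) -> tensor.

  # Integrates dx/dt = velocity.(x, t) from t = 0 to t = 1 in equal steps.
  def integrate(velocity, x_0, num_steps, solver \\ :euler)
      when is_integer(num_steps) and num_steps >= 1 do
    dt = 1.0 / num_steps

    Enum.reduce(0..(num_steps - 1), x_0, fn i, x ->
      step(solver, velocity, x, i * dt, dt)
    end)
  end

  # Euler: one velocity evaluation per step.
  defp step(:euler, v, x, t, dt), do: axpy(dt, v.(x, t), x)

  # Midpoint: use the velocity evaluated at the half step.
  defp step(:midpoint, v, x, t, dt) do
    x_mid = axpy(dt / 2, v.(x, t), x)
    axpy(dt, v.(x_mid, t + dt / 2), x)
  end

  # Classic fourth-order Runge-Kutta: four evaluations per step.
  defp step(:rk4, v, x, t, dt) do
    k1 = v.(x, t)
    k2 = v.(axpy(dt / 2, k1, x), t + dt / 2)
    k3 = v.(axpy(dt / 2, k2, x), t + dt / 2)
    k4 = v.(axpy(dt, k3, x), t + dt)

    sum =
      k1
      |> Nx.add(Nx.multiply(2.0, k2))
      |> Nx.add(Nx.multiply(2.0, k3))
      |> Nx.add(k4)

    axpy(dt / 6, sum, x)
  end

  # Computes a * v + x.
  defp axpy(a, v, x), do: Nx.add(x, Nx.multiply(a, v))
end

x_0 = Nx.tensor([0.0, 1.0])
x_1 = Nx.tensor([2.0, -1.0])

# Straight path: the velocity is constant, so one Euler step lands on x_1.
straight = fn _x, _t -> Nx.subtract(x_1, x_0) end
IO.inspect(Nx.to_flat_list(ODESketch.integrate(straight, x_0, 1)))
# prints [2.0, -1.0]

# Curved field dx/dt = x with x(0) = 1; the exact value at t = 1 is e.
curved = fn x, _t -> x end
exact = :math.exp(1.0)

for solver <- [:euler, :midpoint, :rk4] do
  approx = ODESketch.integrate(curved, Nx.tensor(1.0), 20, solver) |> Nx.to_number()
  IO.inspect(abs(approx - exact), label: "#{solver} error")
end
```

The two toy fields illustrate the trade-off: when paths are straight, Euler with very few steps is already accurate, while a curved field benefits from a higher-order solver at the cost of more velocity evaluations per step.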
@spec sample_guided( map(), (map(), map() -> Nx.Tensor.t()), Nx.Tensor.t(), Nx.Tensor.t(), keyword() ) :: Nx.Tensor.t()
Sample with classifier-free guidance.
Combines conditional and unconditional velocity predictions: v_guided = v_uncond + guidance_scale * (v_cond - v_uncond). A scale of 1.0 recovers the conditional prediction; values above 1.0 extrapolate beyond it.
Parameters
- :guidance_scale - Strength of guidance (default: 1.0, no guidance)
- :uncond_observations - Unconditional observations (zeros or learned)
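The guidance formula above can be checked on a toy pair of velocity vectors. In this sketch, `GuidanceSketch.guided_velocity/3` is a hypothetical helper that applies the formula to two already-computed predictions. It does not show how the library obtains `v_cond` and `v_uncond`.

```elixir
# Assumes Nx can be fetched; the version constraint is illustrative.
Mix.install([{:nx, "~> 0.7"}])

defmodule GuidanceSketch do
  # Hypothetical helper: v_guided = v_uncond + scale * (v_cond - v_uncond)
  def guided_velocity(v_cond, v_uncond, scale) do
    Nx.add(v_uncond, Nx.multiply(scale, Nx.subtract(v_cond, v_uncond)))
  end
end

v_cond = Nx.tensor([1.0, 2.0])
v_uncond = Nx.tensor([0.0, 1.0])

# scale 0.0 -> unconditional, 1.0 -> conditional, 2.0 -> past the conditional
for scale <- [0.0, 1.0, 2.0] do
  guided = GuidanceSketch.guided_velocity(v_cond, v_uncond, scale)
  IO.inspect(Nx.to_flat_list(guided), label: "scale #{scale}")
end
# prints [0.0, 1.0], then [1.0, 2.0], then [2.0, 3.0]
```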
@spec target_velocity(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Compute the target velocity for training.
For optimal transport path, velocity is constant: v_target = x_1 - x_0
Parameters
- x_0 - Source (noise)
- x_1 - Target (data)
Returns
Target velocity (same shape as inputs).
@spec velocity_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Compute velocity matching loss (MSE).
L = ||v_theta(x_t, t) - v_target||^2
Parameters
- pred_velocity - Predicted velocity from network
- target_velocity - Ground truth velocity (x_1 - x_0)
Returns
Scalar loss value.