Edifice.Generative.FlowMatching (Edifice v0.2.0)


Flow Matching: Action generation via continuous normalizing flows.

Implements Conditional Flow Matching (CFM) from "Flow Matching for Generative Modeling" (Lipman et al., ICLR 2023). Learns a velocity field that transports samples from noise to data via an ODE.

Key Innovation: Optimal Transport Paths

Instead of diffusion's complex noise schedule, Flow Matching uses simple linear interpolation (optimal transport path):

x_t = (1 - t) * x_0 + t * x_1    where x_0 ~ noise, x_1 ~ data
v_target = x_1 - x_0             (constant velocity along path)

Training minimizes: ||v_theta(x_t, t | obs) - v_target||^2
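To see why the regression target is simply x_1 - x_0, differentiate the interpolation path with respect to t. The derivation below uses the symbols defined above; writing the loss as an expectation over t, x_0 and x_1 is the usual CFM formulation and is an assumption about how this module samples its training inputs.

```latex
\begin{aligned}
x_t &= (1 - t)\,x_0 + t\,x_1, \qquad t \in [0, 1] \\
\frac{d x_t}{dt} &= x_1 - x_0 = v_{\text{target}} \\
\mathcal{L}(\theta) &= \mathbb{E}_{t,\,x_0,\,x_1}\,\bigl\| v_\theta(x_t, t \mid \text{obs}) - v_{\text{target}} \bigr\|^2
\end{aligned}
```

The derivative does not depend on t, which is what "constant velocity along path" means.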

Comparison with Diffusion

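| Feature   | Diffusion               | Flow Matching          |
|-----------|-------------------------|------------------------|
| Path      | Stochastic (SDE)        | Deterministic (ODE)    |
| Schedule  | Complex (beta schedule) | None needed            |
| Training  | Noise prediction        | Velocity prediction    |
| Inference | DDPM/DDIM sampling      | ODE integration        |
| Steps     | 20-100+ typical         | 10-20 often sufficient |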

Architecture

Observations [batch, obs_dim]
      |
      v
+-------------------------------------+
|  Observation Encoder                 |
+-------------------------------------+
      |
      v obs_embed
+-------------------------------------+
|  Velocity Network                    |
|  Input: (x_t, t, obs_embed)         |
|  Output: v_theta (velocity field)    |
+-------------------------------------+
      |
      v
Actions [batch, action_horizon, action_dim]

Training

# Forward process (create interpolated sample)
x_t = FlowMatching.interpolate(noise, actions, t)
target_velocity = FlowMatching.target_velocity(noise, actions)

# Predict velocity
pred_velocity = velocity_network(x_t, t, observations)

# MSE loss
loss = FlowMatching.velocity_loss(target_velocity, pred_velocity)

Inference (ODE Integration)

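# Start from noise (x_0) at t = 0
# random_noise/0 and velocity_network/3 are placeholders
x_0 = random_noise()
dt = 1.0 / num_steps

# Euler integration (or a higher-order solver)
x_1 =
  Enum.reduce(0..(num_steps - 1), x_0, fn step, x_t ->
    t = step * dt
    v = velocity_network(x_t, t, observations)
    Nx.add(x_t, Nx.multiply(dt, v))
  end)

# x_1 is the generated action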
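As a concrete check of the Euler loop, here is a minimal sketch that integrates a toy velocity field with Nx. The module `EulerSketch` and the field `toy_velocity/3` are hypothetical stand-ins for a trained network, not part of Edifice. For a single target point, the straight-line flow has velocity (target - x) / (1 - t), and Euler steps follow it to the target.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule EulerSketch do
  # Hypothetical toy velocity field: for a single target point, the
  # straight-line flow has velocity (target - x) / (1 - t).
  def toy_velocity(x, t, target) do
    Nx.divide(Nx.subtract(target, x), 1.0 - t)
  end

  # Fixed-step Euler integration from t = 0 to t = 1.
  # The last evaluation happens at t = 1 - dt, so the division is safe.
  def integrate(x_0, target, num_steps) do
    dt = 1.0 / num_steps

    Enum.reduce(0..(num_steps - 1), x_0, fn i, x ->
      t = i * dt
      Nx.add(x, Nx.multiply(dt, toy_velocity(x, t, target)))
    end)
  end
end

x_0 = Nx.tensor([[0.5, -1.0], [2.0, 0.0]])
target = Nx.tensor([[1.0, 1.0], [-1.0, 3.0]])

x_1 = EulerSketch.integrate(x_0, target, 10)
IO.inspect(x_1, label: "x_1")
```

Because the path is a straight line, Euler integration introduces no discretization error here; a learned velocity field is generally curved, which is where step count and solver choice start to matter.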

Usage

# Build flow matching model
model = FlowMatching.build(
  obs_size: 287,
  action_dim: 64,
  action_horizon: 8
)

# Training
loss = FlowMatching.compute_loss(
  params, predict_fn, observations, actions, noise, t
)

# Inference (noise is the initial x_0 sample)
actions = FlowMatching.sample(
  params, predict_fn, observations, noise,
  num_steps: 20, solver: :euler
)
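The usage above takes `noise` and `t` as given. One way to draw them, sketched with `Nx.Random` under the assumption of standard Gaussian noise and uniformly distributed timesteps (the module does not state which distributions it expects):

```elixir
Mix.install([{:nx, "~> 0.7"}])

batch = 4
action_horizon = 8
action_dim = 64

key = Nx.Random.key(42)

# x_0: Gaussian noise with the same shape as the action sequence.
{noise, key} =
  Nx.Random.normal(key, 0.0, 1.0, shape: {batch, action_horizon, action_dim})

# t: one timestep per batch element, drawn uniformly from [0, 1).
{t, _key} = Nx.Random.uniform(key, 0.0, 1.0, shape: {batch})

IO.inspect(Nx.shape(noise), label: "noise shape")
IO.inspect(Nx.shape(t), label: "t shape")
```

Threading the returned key through each call keeps the two draws independent.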

References

Summary

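Types

build_opt()
Options for build/1.

Functions

build(opts \\ [])
Build a Flow Matching model for action generation.

default_action_horizon()
Default action prediction horizon.

default_hidden_size()
Default hidden dimension.

default_num_layers()
Default number of network layers.

default_num_steps()
Default number of ODE integration steps.

default_solver()
Default ODE solver.

fast_inference_defaults()
Get fast inference configuration.

interpolate(x_0, x_1, t)
Interpolate between noise (x_0) and data (x_1) at time t.

output_size(opts \\ [])
Get the output size of a flow matching model.

param_count(opts)
Calculate approximate parameter count for a flow matching model.

quality_defaults()
Get high-quality configuration (more steps, better solver).

Get recommended defaults for action generation.

rectified_loss(target_vel, pred_vel)
Compute rectified flow loss with distillation.

sample(params, predict_fn, observations, initial_noise, opts \\ [])
Sample actions by integrating the learned ODE.

target_velocity(x_0, x_1)
Compute the target velocity for training.

velocity_loss(target_vel, pred_vel)
Compute velocity matching loss (MSE).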

Types

build_opt()

@type build_opt() ::
  {:action_dim, pos_integer()}
  | {:action_horizon, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:obs_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Flow Matching model for action generation.

Options

  • :obs_size - Size of observation embedding (required)
  • :action_dim - Dimension of action space (required)
  • :action_horizon - Number of actions per sequence (default: 8)
  • :hidden_size - Hidden dimension (default: 256)
  • :num_layers - Number of MLP layers (default: 4)

Returns

An Axon model that predicts velocity given (x_t, t, obs).

compute_loss(params, predict_fn, observations, actions, noise, t)

@spec compute_loss(
  map(),
  (map(), map() -> Nx.Tensor.t()),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  Nx.Tensor.t()
) :: Nx.Tensor.t()

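Compute the complete training loss: interpolate x_t from noise and actions at time t, predict the velocity with predict_fn, and score it against the target velocity x_1 - x_0.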

Parameters

  • params - Model parameters
  • predict_fn - Velocity prediction function
  • observations - Conditioning observations [batch, obs_size]
  • actions - Target actions (x_1) [batch, action_horizon, action_dim]
  • noise - Source noise (x_0) [batch, action_horizon, action_dim]
  • t - Random timesteps [batch]

Returns

Scalar loss value.
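The following sketch shows how the pieces could fit together with the documented argument order. It is not the library's implementation: `LossSketch` is a hypothetical module, the input map keys passed to `predict_fn` are invented for illustration (the real model's input names may differ), and averaging over all elements is an assumption.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule LossSketch do
  # Mirrors the documented argument order; not the library's implementation.
  def compute_loss(params, predict_fn, observations, actions, noise, t) do
    # Reshape t from [batch] to [batch, 1, 1] so it broadcasts over the sequence.
    t_b = Nx.reshape(t, {:auto, 1, 1})
    x_t = Nx.add(Nx.multiply(Nx.subtract(1.0, t_b), noise), Nx.multiply(t_b, actions))
    target = Nx.subtract(actions, noise)

    # Hypothetical input keys; the real model's input names may differ.
    inputs = %{"noisy_actions" => x_t, "timestep" => t, "observations" => observations}
    pred = predict_fn.(params, inputs)

    # Mean squared error over all elements (an assumption about the reduction).
    Nx.mean(Nx.pow(Nx.subtract(pred, target), 2))
  end
end

actions = Nx.tensor([[[1.0, 2.0], [3.0, 4.0]]])
noise = Nx.tensor([[[0.0, 0.5], [-1.0, 1.0]]])
observations = Nx.tensor([[0.0, 0.0, 0.0]])
t = Nx.tensor([0.25])

# A stand-in predictor that always outputs zeros.
zero_fn = fn _params, inputs -> Nx.multiply(inputs["noisy_actions"], 0.0) end

loss = LossSketch.compute_loss(%{}, zero_fn, observations, actions, noise, t)
IO.inspect(Nx.to_number(loss), label: "loss")
```

With a zero predictor the loss equals the mean squared target velocity, which makes the result easy to verify by hand.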

default_action_horizon()

@spec default_action_horizon() :: pos_integer()

Returns the default action prediction horizon (8, as listed under build/1).

default_hidden_size()

@spec default_hidden_size() :: pos_integer()

Returns the default hidden dimension (256, as listed under build/1).

default_num_layers()

@spec default_num_layers() :: pos_integer()

Returns the default number of network layers (4, as listed under build/1).

default_num_steps()

@spec default_num_steps() :: pos_integer()

Returns the default number of ODE integration steps (20, as listed under sample/5).

default_solver()

@spec default_solver() :: atom()

Returns the default ODE solver (:euler, as listed under sample/5).

fast_inference_defaults()

@spec fast_inference_defaults() :: keyword()

Get fast inference configuration.

generate_rectified_pairs(params, predict_fn, observations, noise, opts \\ [])

@spec generate_rectified_pairs(
  map(),
  (map(), map() -> Nx.Tensor.t()),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  keyword()
) :: {Nx.Tensor.t(), Nx.Tensor.t()}

Generate training pairs for rectified flow.

Samples (noise, generated_action) pairs from a trained model to create straighter paths.
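A minimal sketch of the idea, assuming Euler integration and a velocity function of the form `fn x, t -> v end`. `ReflowPairsSketch` and the constant field are hypothetical stand-ins for a trained model, and the real function's options are not shown in this page.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule ReflowPairsSketch do
  # Integrate the ODE with Euler steps and return {noise, generated}.
  def generate_pairs(velocity_fn, noise, num_steps) do
    dt = 1.0 / num_steps

    generated =
      Enum.reduce(0..(num_steps - 1), noise, fn i, x ->
        Nx.add(x, Nx.multiply(dt, velocity_fn.(x, i * dt)))
      end)

    {noise, generated}
  end
end

# Hypothetical stand-in for a trained model: a constant velocity field.
velocity_fn = fn x, _t -> Nx.broadcast(2.0, Nx.shape(x)) end

noise = Nx.tensor([[0.0, 1.0], [-1.0, 0.5]])
{x_0, x_1} = ReflowPairsSketch.generate_pairs(velocity_fn, noise, 4)

# The target for the next training round is the straight-line velocity
# between each noise sample and the action the model generated from it.
reflow_target = Nx.subtract(x_1, x_0)
IO.inspect(reflow_target, label: "reflow target")
```

Each pair couples a noise sample with the model's own output for it, so retraining on these pairs fits a straight line between two points the model already connects.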

interpolate(x_0, x_1, t)

@spec interpolate(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Interpolate between noise (x_0) and data (x_1) at time t.

Uses optimal transport (linear) interpolation: x_t = (1 - t) * x_0 + t * x_1

Parameters

  • x_0 - Source (noise) [batch, action_horizon, action_dim]
  • x_1 - Target (data/actions) [batch, action_horizon, action_dim]
  • t - Time in [0, 1] [batch]

Returns

Interpolated x_t with same shape as inputs.
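Since t has shape [batch] while x_0 and x_1 have shape [batch, action_horizon, action_dim], t has to be reshaped before broadcasting. The sketch below shows one way to do that in Nx; `InterpolateSketch` is a hypothetical module, not the library's code.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule InterpolateSketch do
  def interpolate(x_0, x_1, t) do
    # [batch] -> [batch, 1, 1] so each sample's t scales its whole sequence.
    t_b = Nx.reshape(t, {:auto, 1, 1})
    Nx.add(Nx.multiply(Nx.subtract(1.0, t_b), x_0), Nx.multiply(t_b, x_1))
  end
end

# Shapes: [batch = 3, action_horizon = 1, action_dim = 2]
x_0 = Nx.tensor([[[0.0, 0.0]], [[2.0, 2.0]], [[0.0, 4.0]]])
x_1 = Nx.tensor([[[1.0, 1.0]], [[4.0, 0.0]], [[8.0, 0.0]]])
t = Nx.tensor([0.0, 1.0, 0.5])

x_t = InterpolateSketch.interpolate(x_0, x_1, t)
# Sample 0 (t = 0.0) stays at x_0, sample 1 (t = 1.0) lands on x_1,
# and sample 2 (t = 0.5) is the midpoint [4.0, 2.0].
IO.inspect(x_t, label: "x_t")
```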

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a flow matching model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for a flow matching model.

quality_defaults()

@spec quality_defaults() :: keyword()

Get high-quality configuration (more steps, better solver).

rectified_loss(target_vel, pred_vel)

@spec rectified_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Compute rectified flow loss with distillation.

Rectified flow straightens the ODE paths by distilling from a trained flow model, allowing even fewer integration steps.

This is a two-stage process:

  1. Train standard flow matching
  2. Generate (x_0, x_1) pairs by sampling, then retrain on straight paths
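In symbols, the two stages can be written as follows. The notation (theta_1, theta_2, and hats for generated quantities) is introduced here for illustration and follows the rectified flow literature rather than anything defined by this module.

```latex
\begin{aligned}
\text{Stage 1:}\quad & \theta_1 = \arg\min_\theta\; \mathbb{E}\,\bigl\| v_\theta(x_t, t) - (x_1 - x_0) \bigr\|^2,
  \qquad x_0 \sim \text{noise},\; x_1 \sim \text{data} \\
\text{Stage 2:}\quad & \hat{x}_1 = x_0 + \int_0^1 v_{\theta_1}(x_s, s)\,ds, \qquad
  \theta_2 = \arg\min_\theta\; \mathbb{E}\,\bigl\| v_\theta(\hat{x}_t, t) - (\hat{x}_1 - x_0) \bigr\|^2 \\
& \text{where } \hat{x}_t = (1 - t)\,x_0 + t\,\hat{x}_1
\end{aligned}
```

The loss has the same form in both stages; only the pairing of x_0 with its endpoint changes.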

sample(params, predict_fn, observations, initial_noise, opts \\ [])

@spec sample(
  map(),
  (map(), map() -> Nx.Tensor.t()),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  keyword()
) :: Nx.Tensor.t()

Sample actions by integrating the learned ODE.

Solves dx/dt = v_theta(x, t) from t=0 to t=1.

Parameters

  • params - Model parameters
  • predict_fn - Velocity prediction function
  • observations - Conditioning [batch, obs_size]
  • initial_noise - Starting noise (x_0) [batch, action_horizon, action_dim]
  • opts - Options:
    • :num_steps - Integration steps (default: 20)
    • :solver - ODE solver: :euler, :midpoint, :rk4 (default: :euler)

Returns

Generated actions [batch, action_horizon, action_dim].
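To illustrate how the three solver options differ, here is a dependency-free sketch of the textbook Euler, midpoint and classical RK4 steps on a scalar ODE. It uses plain floats instead of Nx tensors, and `SolverSketch` is a hypothetical module; how Edifice implements each solver internally is not shown in this page.

```elixir
defmodule SolverSketch do
  # Each step advances x from t to t + dt for dx/dt = f.(x, t).
  def step(:euler, f, x, t, dt), do: x + dt * f.(x, t)

  def step(:midpoint, f, x, t, dt) do
    k1 = f.(x, t)
    x + dt * f.(x + 0.5 * dt * k1, t + 0.5 * dt)
  end

  def step(:rk4, f, x, t, dt) do
    k1 = f.(x, t)
    k2 = f.(x + 0.5 * dt * k1, t + 0.5 * dt)
    k3 = f.(x + 0.5 * dt * k2, t + 0.5 * dt)
    k4 = f.(x + dt * k3, t + dt)
    x + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
  end

  # Integrate from t = 0 to t = 1 in num_steps equal steps.
  def integrate(solver, f, x_0, num_steps) do
    dt = 1.0 / num_steps
    Enum.reduce(0..(num_steps - 1), x_0, fn i, x -> step(solver, f, x, i * dt, dt) end)
  end
end

# Test problem dx/dt = x with x(0) = 1, so the exact answer at t = 1 is e.
f = fn x, _t -> x end
exact = :math.exp(1.0)

errors =
  for solver <- [:euler, :midpoint, :rk4] do
    {solver, abs(SolverSketch.integrate(solver, f, 1.0, 20) - exact)}
  end

IO.inspect(errors, label: "absolute error after 20 steps")
```

Per step, Euler needs one velocity evaluation, midpoint two, and RK4 four, so a higher-order solver trades network calls per step for accuracy per step.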

sample_guided(params, predict_fn, observations, initial_noise, opts \\ [])

@spec sample_guided(
  map(),
  (map(), map() -> Nx.Tensor.t()),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  keyword()
) :: Nx.Tensor.t()

Sample with classifier-free guidance.

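Combines conditional and unconditional velocity predictions:

v_guided = v_uncond + guidance_scale * (v_cond - v_uncond)

A guidance_scale of 1.0 reduces to the conditional prediction; larger values extrapolate away from the unconditional one.

Parameters

  • opts - Options:
    • :guidance_scale - Strength of guidance (default: 1.0, no guidance)
    • :uncond_observations - Unconditional observations (zeros or learned)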
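The guidance formula can be checked in isolation. This sketch applies it to two fixed velocity tensors; `GuidanceSketch` is a hypothetical helper, and in practice v_cond and v_uncond would come from two calls to the velocity network.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule GuidanceSketch do
  # v_guided = v_uncond + guidance_scale * (v_cond - v_uncond)
  def guide(v_cond, v_uncond, guidance_scale) do
    Nx.add(v_uncond, Nx.multiply(guidance_scale, Nx.subtract(v_cond, v_uncond)))
  end
end

v_cond = Nx.tensor([1.0, 2.0])
v_uncond = Nx.tensor([0.0, 1.0])

# Scale 1.0 returns the conditional prediction unchanged.
no_guidance = GuidanceSketch.guide(v_cond, v_uncond, 1.0)

# Scale 2.0 moves twice as far from v_uncond as v_cond does.
amplified = GuidanceSketch.guide(v_cond, v_uncond, 2.0)

IO.inspect(no_guidance, label: "scale 1.0")
IO.inspect(amplified, label: "scale 2.0")
```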

target_velocity(x_0, x_1)

@spec target_velocity(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Compute the target velocity for training.

For optimal transport path, velocity is constant: v_target = x_1 - x_0

Parameters

  • x_0 - Source (noise)
  • x_1 - Target (data)

Returns

Target velocity (same shape as inputs).

velocity_loss(target_vel, pred_vel)

@spec velocity_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Compute velocity matching loss (MSE).

L = ||v_theta(x_t, t) - v_target||^2

Parameters

  • target_vel - Ground truth velocity (x_1 - x_0)
  • pred_vel - Predicted velocity from network

Returns

Scalar loss value.
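A minimal Nx sketch of an MSE velocity loss, following the documented argument order. `VelocityLossSketch` is a hypothetical module, and taking the mean over every element (rather than, say, summing per sample first) is an assumption about the reduction.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule VelocityLossSketch do
  # Mean of squared differences over all elements; argument order follows
  # the documented velocity_loss(target_vel, pred_vel).
  def velocity_loss(target_vel, pred_vel) do
    Nx.mean(Nx.pow(Nx.subtract(pred_vel, target_vel), 2))
  end
end

# Shapes: [batch = 1, action_horizon = 2, action_dim = 2]
target_vel = Nx.tensor([[[1.0, 2.0], [3.0, 4.0]]])
pred_vel = Nx.tensor([[[1.0, 0.0], [3.0, 2.0]]])

loss = VelocityLossSketch.velocity_loss(target_vel, pred_vel)
IO.inspect(Nx.to_number(loss), label: "loss")
```

Two of the four elements differ by 2.0, so the squared errors are 0, 4, 0 and 4, and their mean is 2.0.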