# `Edifice.Generative.FlowMatching`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/generative/flow_matching.ex#L1)

Flow Matching: Action generation via continuous normalizing flows.

Implements Conditional Flow Matching (CFM) from "Flow Matching for Generative
Modeling" (Lipman et al., ICLR 2023). Learns a velocity field that transports
samples from noise to data via an ODE.

## Key Innovation: Optimal Transport Paths

Instead of diffusion's complex noise schedule, Flow Matching uses simple
linear interpolation (optimal transport path):

```
x_t = (1 - t) * x_0 + t * x_1    where x_0 ~ noise, x_1 ~ data
v_target = x_1 - x_0             (constant velocity along path)
```

Training minimizes: ||v_theta(x_t, t | obs) - v_target||^2
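
To make the path and its target concrete, here is a minimal sketch in plain Elixir. Flat lists stand in for Nx tensors so the snippet runs without dependencies; the module name `OTPathSketch` and its functions are illustrative assumptions, not part of this library.

```elixir
defmodule OTPathSketch do
  # x_t = (1 - t) * x_0 + t * x_1, applied elementwise to two flat lists.
  def interpolate(x0, x1, t) do
    Enum.zip_with(x0, x1, fn a, b -> (1 - t) * a + t * b end)
  end

  # v_target = x_1 - x_0; note that t does not appear anywhere.
  def target_velocity(x0, x1) do
    Enum.zip_with(x0, x1, fn a, b -> b - a end)
  end
end

x0 = [0.5, -1.0]  # stands in for a noise sample
x1 = [2.0, 3.0]   # stands in for a data sample

OTPathSketch.interpolate(x0, x1, 0.0)  # the path starts at x0
OTPathSketch.interpolate(x0, x1, 1.0)  # and ends at x1
OTPathSketch.target_velocity(x0, x1)   # [1.5, 4.0]
```

The endpoints of the path are the noise and data samples themselves, which is why no separate noise schedule has to be tuned.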

## Comparison with Diffusion

| Feature | Diffusion | Flow Matching |
|---------|-----------|---------------|
| Path | Stochastic (SDE) | Deterministic (ODE) |
| Schedule | Complex (beta schedule) | None needed |
| Training | Noise prediction | Velocity prediction |
| Inference | DDPM/DDIM sampling | ODE integration |
| Steps | 20-100+ typical | 10-20 often sufficient |

## Architecture

```
Observations [batch, obs_dim]
      |
      v
+-------------------------------------+
|  Observation Encoder                 |
+-------------------------------------+
      |
      v obs_embed
+-------------------------------------+
|  Velocity Network                    |
|  Input: (x_t, t, obs_embed)         |
|  Output: v_theta (velocity field)    |
+-------------------------------------+
      |
      v
Actions [batch, action_horizon, action_dim]
```

## Training

```elixir
# Forward process (create interpolated sample)
x_t = FlowMatching.interpolate(noise, actions, t)
target_velocity = FlowMatching.target_velocity(noise, actions)

# Predict velocity (`velocity_network/3` is a placeholder for the model call)
pred_velocity = velocity_network(x_t, t, observations)

# MSE loss (predicted velocity first, then target)
loss = FlowMatching.velocity_loss(pred_velocity, target_velocity)
```

## Inference (ODE Integration)

```elixir
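# Start from noise (`random_noise/0` and `velocity_network/3` are placeholders)
x_0 = random_noise()
num_steps = 20
dt = 1.0 / num_steps

# Euler integration from t = 0 to t = 1 (higher-order solvers also work)
x_1 =
  Enum.reduce(0..(num_steps - 1), x_0, fn step, x_t ->
    v = velocity_network(x_t, step * dt, observations)
    Nx.add(x_t, Nx.multiply(v, dt))
  end)

# x_1 is the generated action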
```

## Usage

    # Build flow matching model
    model = FlowMatching.build(
      obs_size: 287,
      action_dim: 64,
      action_horizon: 8
    )

    # Training
    loss = FlowMatching.compute_loss(
      params, predict_fn, observations, actions, noise, t
    )

    # Inference
    actions = FlowMatching.sample(
      params, predict_fn, observations, initial_noise,
      num_steps: 20, solver: :euler
    )

## References
- Flow Matching: https://arxiv.org/abs/2210.02747
- Conditional Flow Matching: https://arxiv.org/abs/2302.00482
- Rectified Flow: https://arxiv.org/abs/2209.03003

# `build_opt`

```elixir
@type build_opt() ::
  {:action_dim, pos_integer()}
  | {:action_horizon, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:obs_size, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a Flow Matching model for action generation.

## Options
  - `:obs_size` - Size of observation embedding (required)
  - `:action_dim` - Dimension of action space (required)
  - `:action_horizon` - Number of actions per sequence (default: 8)
  - `:hidden_size` - Hidden dimension (default: 256)
  - `:num_layers` - Number of MLP layers (default: 4)

## Returns
  An Axon model that predicts velocity given (x_t, t, obs).

# `compute_loss`

```elixir
@spec compute_loss(
  map(),
  (map(), map() -> Nx.Tensor.t()),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  Nx.Tensor.t()
) :: Nx.Tensor.t()
```

Compute the complete training loss.

## Parameters
  - `params` - Model parameters
  - `predict_fn` - Velocity prediction function
  - `observations` - Conditioning observations [batch, obs_size]
  - `actions` - Target actions (x_1) [batch, action_horizon, action_dim]
  - `noise` - Source noise (x_0) [batch, action_horizon, action_dim]
  - `t` - Random timesteps [batch]

## Returns
  Scalar loss value.
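
The `noise` and `t` arguments are supplied by the caller. A hedged sketch of preparing them follows, in plain Elixir with nested lists in place of tensors. Drawing `t` uniformly from [0, 1) and the noise from a standard normal is the usual flow matching setup, but this page does not state which distributions the module expects, so treat both as assumptions. `TrainingInputsSketch` is a hypothetical name.

```elixir
defmodule TrainingInputsSketch do
  # One time value per batch element, drawn uniformly from [0, 1).
  def sample_times(batch), do: for(_ <- 1..batch, do: :rand.uniform())

  # Standard-normal noise laid out as [batch][action_horizon][action_dim].
  def sample_noise(batch, horizon, dim) do
    for _ <- 1..batch do
      for _ <- 1..horizon do
        for _ <- 1..dim, do: :rand.normal()
      end
    end
  end
end

t = TrainingInputsSketch.sample_times(4)
noise = TrainingInputsSketch.sample_noise(4, 8, 3)
```

With Nx, the same values would normally be generated as tensors directly; the list form is used here only to keep the example dependency-free.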

# `default_action_horizon`

```elixir
@spec default_action_horizon() :: pos_integer()
```

Default action prediction horizon (8, as listed under `build/1`).

# `default_hidden_size`

```elixir
@spec default_hidden_size() :: pos_integer()
```

Default hidden dimension (256, as listed under `build/1`).

# `default_num_layers`

```elixir
@spec default_num_layers() :: pos_integer()
```

Default number of network layers (4, as listed under `build/1`).

# `default_num_steps`

```elixir
@spec default_num_steps() :: pos_integer()
```

Default number of ODE integration steps (20, as listed under `sample/5`).

# `default_solver`

```elixir
@spec default_solver() :: atom()
```

Default ODE solver (`:euler`, as listed under `sample/5`).

# `fast_inference_defaults`

```elixir
@spec fast_inference_defaults() :: keyword()
```

Get fast inference configuration.

# `generate_rectified_pairs`

```elixir
@spec generate_rectified_pairs(
  map(),
  (map(), map() -> Nx.Tensor.t()),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  keyword()
) :: {Nx.Tensor.t(), Nx.Tensor.t()}
```

Generate training pairs for rectified flow.

Samples (noise, generated_action) pairs from a trained model
to create straighter paths.

# `interpolate`

```elixir
@spec interpolate(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Interpolate between noise (x_0) and data (x_1) at time t.

Uses optimal transport (linear) interpolation:
x_t = (1 - t) * x_0 + t * x_1

## Parameters
  - `x_0` - Source (noise) [batch, action_horizon, action_dim]
  - `x_1` - Target (data/actions) [batch, action_horizon, action_dim]
  - `t` - Time in [0, 1] [batch]

## Returns
  Interpolated x_t with same shape as inputs.
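
Because `t` has shape [batch] while `x_0` and `x_1` have shape [batch, action_horizon, action_dim], each sample is interpolated with its own time value. The sketch below shows that per-sample behaviour in plain Elixir, with each sample flattened to a single list for brevity. `BatchInterpolateSketch` is a hypothetical name, and how the module itself broadcasts `t` is not shown on this page.

```elixir
defmodule BatchInterpolateSketch do
  # x0_batch and x1_batch are lists of samples (each a flat list);
  # ts holds one time value per sample.
  def interpolate(x0_batch, x1_batch, ts) do
    Enum.zip_with([x0_batch, x1_batch, ts], fn [x0, x1, t] ->
      Enum.zip_with(x0, x1, fn a, b -> (1 - t) * a + t * b end)
    end)
  end
end

x0 = [[0.0, 0.0], [1.0, 1.0]]
x1 = [[2.0, 4.0], [3.0, 5.0]]

# First sample sits halfway along its path, second sample is still at x_0.
BatchInterpolateSketch.interpolate(x0, x1, [0.5, 0.0])
```

In Nx, a [batch] tensor lines up with the last axis under the default broadcasting rules, so `t` typically has to be reshaped to [batch, 1, 1] before it can multiply a [batch, action_horizon, action_dim] tensor.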

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of a flow matching model.

# `param_count`

```elixir
@spec param_count(keyword()) :: non_neg_integer()
```

Calculate approximate parameter count for a flow matching model.

# `quality_defaults`

```elixir
@spec quality_defaults() :: keyword()
```

Get high-quality configuration (more steps, better solver).

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Get recommended defaults for action generation.

# `rectified_loss`

```elixir
@spec rectified_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Compute rectified flow loss with distillation.

Rectified flow straightens the ODE paths by distilling from a trained
flow model, allowing even fewer integration steps.

This is a two-stage process:
1. Train standard flow matching
2. Generate (x_0, x_1) pairs by sampling, then retrain on straight paths
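
Why straighter paths need fewer steps can be seen with a toy scalar ODE. The sketch below is plain Elixir with hypothetical names (`StraightPathSketch`, `straight`, `curved`), not this module's API. It compares Euler integration on a constant velocity field with a field whose velocity depends on position.

```elixir
defmodule StraightPathSketch do
  # Fixed-step Euler integration of dx/dt = v.(x, t) from t = 0 to t = 1.
  def euler(x0, v, num_steps) do
    dt = 1.0 / num_steps

    Enum.reduce(0..(num_steps - 1), x0, fn step, x ->
      x + dt * v.(x, step * dt)
    end)
  end
end

# Straight path: constant velocity, so the exact endpoint is x0 + 2.0.
straight = fn _x, _t -> 2.0 end
# Curved path: dx/dt = x, whose exact endpoint is x0 * e.
curved = fn x, _t -> x end

StraightPathSketch.euler(1.0, straight, 1)  # 3.0, exact after a single step
StraightPathSketch.euler(1.0, curved, 1)    # 2.0, well short of e (about 2.718)
StraightPathSketch.euler(1.0, curved, 100)  # much closer to e
```

A single Euler step is exact when the velocity is constant along the path, which is the property the second training stage aims for.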

# `sample`

```elixir
@spec sample(
  map(),
  (map(), map() -> Nx.Tensor.t()),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  keyword()
) ::
  Nx.Tensor.t()
```

Sample actions by integrating the learned ODE.

Solves dx/dt = v_theta(x, t | obs) from t=0 to t=1.

## Parameters
  - `params` - Model parameters
  - `predict_fn` - Velocity prediction function
  - `observations` - Conditioning [batch, obs_size]
  - `initial_noise` - Starting noise (x_0) [batch, action_horizon, action_dim]
  - `opts` - Options:
    - `:num_steps` - Integration steps (default: 20)
    - `:solver` - ODE solver: :euler, :midpoint, :rk4 (default: :euler)

## Returns
  Generated actions [batch, action_horizon, action_dim].
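
The difference between the `:euler` and `:midpoint` solvers can be illustrated on scalars. This is a sketch under assumptions: `SolverSketch` is a hypothetical module, the option names mirror the ones listed above, and the library's actual solver code may differ.

```elixir
defmodule SolverSketch do
  # Euler: one velocity evaluation per step.
  def step(:euler, x, t, dt, v), do: x + dt * v.(x, t)

  # Midpoint: evaluate the velocity at a half step, then take the full step with it.
  def step(:midpoint, x, t, dt, v) do
    x_mid = x + 0.5 * dt * v.(x, t)
    x + dt * v.(x_mid, t + 0.5 * dt)
  end

  # Integrate dx/dt = v.(x, t) from t = 0 to t = 1 with the chosen solver.
  def integrate(x0, v, opts \\ []) do
    num_steps = Keyword.get(opts, :num_steps, 20)
    solver = Keyword.get(opts, :solver, :euler)
    dt = 1.0 / num_steps

    Enum.reduce(0..(num_steps - 1), x0, fn i, x ->
      step(solver, x, i * dt, dt, v)
    end)
  end
end

v = fn x, _t -> x end  # toy field with exact solution x0 * e at t = 1

SolverSketch.integrate(1.0, v, num_steps: 20, solver: :euler)
SolverSketch.integrate(1.0, v, num_steps: 20, solver: :midpoint)
```

For the same number of steps, the midpoint variant costs two velocity evaluations per step instead of one and lands closer to the exact endpoint on this toy field.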

# `sample_guided`

```elixir
@spec sample_guided(
  map(),
  (map(), map() -> Nx.Tensor.t()),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  keyword()
) :: Nx.Tensor.t()
```

Sample with classifier-free guidance.

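Combines the conditional and unconditional predictions, interpolating between
them for scales in [0, 1] and extrapolating past the conditional prediction
for scales above 1: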
v_guided = v_uncond + guidance_scale * (v_cond - v_uncond)

## Parameters
  - `:guidance_scale` - Strength of guidance (default: 1.0, which reduces to the plain conditional prediction)
  - `:uncond_observations` - Unconditional observations (zeros or learned)
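
The guidance formula can be checked on small vectors. This sketch uses plain lists, and `GuidanceSketch` is a hypothetical helper rather than part of this module.

```elixir
defmodule GuidanceSketch do
  # v_guided = v_uncond + scale * (v_cond - v_uncond), elementwise.
  def guide(v_cond, v_uncond, scale) do
    Enum.zip_with(v_cond, v_uncond, fn c, u -> u + scale * (c - u) end)
  end
end

v_cond = [1.0, 2.0]
v_uncond = [0.0, 1.0]

GuidanceSketch.guide(v_cond, v_uncond, 0.0)  # equals v_uncond
GuidanceSketch.guide(v_cond, v_uncond, 1.0)  # equals v_cond
GuidanceSketch.guide(v_cond, v_uncond, 2.0)  # [2.0, 3.0], pushed past v_cond
```

Each guided step needs two velocity predictions, one per set of observations.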

# `target_velocity`

```elixir
@spec target_velocity(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Compute the target velocity for training.

For optimal transport path, velocity is constant:
v_target = x_1 - x_0

## Parameters
  - `x_0` - Source (noise)
  - `x_1` - Target (data)

## Returns
  Target velocity (same shape as inputs).
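
The constant target follows from differentiating the interpolation path with respect to t, written out here in LaTeX for reference:

```latex
\begin{aligned}
x_t &= (1 - t)\,x_0 + t\,x_1 \\
\frac{d x_t}{d t} &= -x_0 + x_1 = x_1 - x_0
\end{aligned}
```

The derivative does not depend on t, so every point along a given path shares the same target velocity.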

# `velocity_loss`

```elixir
@spec velocity_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Compute velocity matching loss (MSE).

L = ||v_theta(x_t, t) - v_target||^2

## Parameters
  - `pred_velocity` - Predicted velocity from network
  - `target_velocity` - Ground truth velocity (x_1 - x_0)

## Returns
  Scalar loss value.
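
As a worked example of the reduction to a scalar, the sketch below computes a mean squared error over flat lists. Averaging over all elements is an assumption about how the mean is taken; `VelocityLossSketch` is a hypothetical name.

```elixir
defmodule VelocityLossSketch do
  # Mean of the squared elementwise differences between prediction and target.
  def mse(pred, target) do
    squared = Enum.zip_with(pred, target, fn p, y -> (p - y) * (p - y) end)
    Enum.sum(squared) / length(squared)
  end
end

# Only the middle element differs (by 2.0), so the loss is (0 + 4 + 0) / 3.
VelocityLossSketch.mse([1.0, 2.0, 3.0], [1.0, 0.0, 3.0])
```

The loss is zero exactly when the predicted velocity matches the target everywhere.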

