Nx.Vulkan.Fast (nx_vulkan v0.1.0)

Named fused kernels for MCMC hot paths.

Each function emits an Nx.Defn.Expr.optional/3 IR node whose name matches a callback on Nx.Vulkan.Backend. Under Nx.Vulkan the evaluator dispatches one fused shader; under any other backend the defn fallback runs and produces a mathematically equivalent result. This is the same pattern as Emily.Fast.

Why this exists

We previously built Nx.Vulkan.Compiler to walk the IR and detect fusable patterns automatically. That works for narrow cases but doesn't scale: each new pattern requires more compiler code, a missed match (false negative) gives no signal, and the matched shapes drift from real exmc usage. Naming the kernels at call sites makes the intent explicit and the dispatch deterministic. The fallback keeps results correct on other backends (EXLA, BinaryBackend, EMLX).

How to use

Inside a defn or any Nx.Defn.jit-traced function:

defn leapfrog_step(q, eps, p, grad) do
  q_new = Nx.Vulkan.Fast.leapfrog_position(q, eps, p)
  p_new = Nx.Vulkan.Fast.momentum_step(p, eps, grad)
  {q_new, p_new}
end

Under Nx.Vulkan.Backend each Fast call collapses to one Nx.Vulkan.fused_chain_4 dispatch (single shader). Under any other backend the defn fallback runs the composed primitives.

Adding kernels

Each kernel is two functions:

  1. The public entry — emits Nx.Defn.Expr.optional/3.
  2. A private _fallback — defn-style composed Nx ops.

Each kernel also needs one matching callback in Nx.Vulkan.Backend. The total is about 30 lines per kernel, compared with about 100 lines plus tests for an IR-detector pattern.
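
The two-function shape can be sketched as below. This is a hypothetical kernel: `MyFast.scaled_add/3` is not part of nx_vulkan, and the sketch assumes `Nx.Defn.Expr.optional/3` takes the callback name, the argument list, and the fallback function, in that order. The matching backend callback is left out, so on a backend that does not define it the fallback should run.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule MyFast do
  # Hypothetical kernel for illustration; not part of Nx.Vulkan.Fast.

  # 1. Public entry: emits an optional IR node named :scaled_add.
  #    In this sketch it must be called while tracing (inside defn or
  #    Nx.Defn.jit), since the arguments need to be expression tensors.
  def scaled_add(a, s, b) do
    Nx.Defn.Expr.optional(:scaled_add, [a, s, b], &scaled_add_fallback/3)
  end

  # 2. Private fallback: the same math written as composed Nx ops.
  defp scaled_add_fallback(a, s, b) do
    Nx.add(a, Nx.multiply(s, b))
  end
end

# Trace and run through the default evaluator.
step = Nx.Defn.jit(&MyFast.scaled_add/3)
result = step.(Nx.tensor([1.0, 2.0]), Nx.tensor([0.5, 0.5]), Nx.tensor([2.0, 4.0]))
IO.inspect(Nx.to_flat_list(result))
```

Keeping the fallback private means callers only ever see the named entry point, so a backend can start or stop providing the fused callback without any call site changing.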

Functions

inv_mass_apply(p, inv_mass)

@spec inv_mass_apply(Nx.t(), Nx.t()) :: Nx.t()

Apply diagonal mass-matrix inverse: p * inv_mass. Trivial as a fused kernel (one binary op), but named for symmetry — a future shader could combine it with adjacent ops in the leapfrog without changing call sites.

kinetic_energy(p, inv_mass)

@spec kinetic_energy(Nx.t(), Nx.t()) :: Nx.t()

Kinetic energy: 0.5 * sum(p² * inv_mass). Reduces to a scalar. Used in NUTS for the joint log-probability: joint_logp = log_prob - kinetic_energy(p, inv_mass).

Under Nx.Vulkan dispatches kinetic_energy.spv (one shader: square + multiply + per-workgroup reduce + 0.5 multiplier baked in). The shader produces partial sums; the backend callback sums them on the host and returns a scalar.
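
A plain-Nx sketch of the same arithmetic can help when checking the shader path against another backend. The module name `KineticEnergyRef` and the chunk size of 4 are assumptions made for illustration; the chunked version only imitates the partial-sum-then-host-sum structure described above and is not the shader itself.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule KineticEnergyRef do
  import Nx.Defn

  # Reference formula: 0.5 * sum(p^2 * inv_mass) as a single reduction.
  defn direct(p, inv_mass) do
    0.5 * Nx.sum(p * p * inv_mass)
  end

  # Two-stage variant: one partial sum per chunk of 4 elements (standing
  # in for a workgroup), with the 0.5 factor applied before reducing,
  # followed by a final sum over the partials.
  # Assumes the vector length is a multiple of 4.
  defn two_stage(p, inv_mass) do
    partials =
      (0.5 * p * p * inv_mass)
      |> Nx.reshape({:auto, 4})
      |> Nx.sum(axes: [1])

    Nx.sum(partials)
  end
end

p = Nx.tensor([1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0])
inv_mass = Nx.tensor([1.0, 0.5, 1.0, 0.5, 1.0, 0.5, 1.0, 0.5])

IO.inspect(Nx.to_number(KineticEnergyRef.direct(p, inv_mass)))    # 20.0
IO.inspect(Nx.to_number(KineticEnergyRef.two_stage(p, inv_mass))) # 20.0
```

With inputs that are not exactly representable, the two orderings of the reduction can differ by floating-point rounding, so a comparison against the fused kernel is better done with a tolerance.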

leapfrog_momentum_half(p, half_eps, grad)

@spec leapfrog_momentum_half(Nx.t(), Nx.t(), Nx.t()) :: Nx.t()

Half-step momentum update: p + half_eps * grad. Used at the start and end of every leapfrog iteration in the standard symplectic integrator. half_eps is eps / 2 precomputed by the caller.
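
To show where the two half-steps sit, here is a minimal sketch of one leapfrog iteration written with plain Nx ops. The module `LeapfrogRef` is hypothetical, and the target is assumed to be a standard normal so that the gradient of the log-density at q is simply -q; a real sampler would supply its own gradient.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule LeapfrogRef do
  import Nx.Defn

  # One leapfrog iteration: half-step momentum, full-step position,
  # half-step momentum. Uses the formulas p + half_eps * grad and
  # q + eps * p from this page.
  defn step(q, p, eps) do
    half_eps = eps / 2

    p_half = p + half_eps * grad_logp(q)
    q_new = q + eps * p_half
    p_new = p_half + half_eps * grad_logp(q_new)

    {q_new, p_new}
  end

  # Assumed target: standard normal, so grad log p(q) = -q.
  defnp grad_logp(q), do: -q
end

{q_new, p_new} =
  LeapfrogRef.step(Nx.tensor([1.0, 0.0]), Nx.tensor([0.0, 1.0]), Nx.tensor(0.5))

IO.inspect(Nx.to_flat_list(q_new)) # [0.875, 0.5]
IO.inspect(Nx.to_flat_list(p_new)) # [-0.46875, 0.875]
```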

leapfrog_position(q, eps, p)

@spec leapfrog_position(Nx.t(), Nx.t(), Nx.t()) :: Nx.t()

Position update: q + eps * p. The dominant elementwise body in every NUTS leapfrog.

Examples

iex> q = Nx.tensor([1.0, 2.0])
iex> eps = Nx.tensor([0.5, 0.5])
iex> p = Nx.tensor([2.0, 4.0])
iex> Nx.Vulkan.Fast.leapfrog_position(q, eps, p) |> Nx.to_flat_list()
[2.0, 4.0]

momentum_step(p, eps, grad)

@spec momentum_step(Nx.t(), Nx.t(), Nx.t()) :: Nx.t()

Full-step momentum update: p + eps * grad. Same shape as the half-step but kept distinct to signal the caller's intent.
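
One situation where a full step arises naturally: when several leapfrog iterations run back to back, the trailing half-step of one iteration and the leading half-step of the next are evaluated at the same position and therefore use the same gradient. Writing g for that shared gradient and ε for eps, the two half-steps add up to a single full step:

```latex
\left(p + \tfrac{\varepsilon}{2}\, g\right) + \tfrac{\varepsilon}{2}\, g \;=\; p + \varepsilon\, g
```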

normal_logpdf(x, mu, sigma)

@spec normal_logpdf(Nx.t(), Nx.t(), Nx.t()) :: Nx.t()

Normal log-density: -0.5*((x-mu)/sigma)² - log(sigma) - 0.5*log(2π). Output shape matches x. The MCMC distribution-density hot path.

Under Nx.Vulkan dispatches normal_logpdf.spv (one fused shader) instead of the five separate Nx ops the fallback emits.
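
For cross-checking the fused shader, the formula above can be written directly with Nx ops. This is a sketch under assumptions: the module name `NormalLogpdfRef` is made up for the example, and the op sequence here is not claimed to match the library's own fallback.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule NormalLogpdfRef do
  import Nx.Defn

  # Elementwise -0.5*((x-mu)/sigma)^2 - log(sigma) - 0.5*log(2*pi).
  defn logpdf(x, mu, sigma) do
    z = (x - mu) / sigma
    -0.5 * z * z - Nx.log(sigma) - 0.5 * Nx.log(2 * Nx.Constants.pi())
  end
end

x = Nx.tensor([0.0, 1.0])
out = NormalLogpdfRef.logpdf(x, Nx.tensor(0.0), Nx.tensor(1.0))

IO.inspect(Nx.shape(out)) # {2}, the same shape as x
IO.inspect(Nx.to_flat_list(out))
```

At x = mu with sigma = 1 the first two terms vanish, so the first element is -0.5*log(2π), roughly -0.9189, which makes a convenient sanity check.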