Named fused kernels for MCMC hot paths.
Each function emits an Nx.Defn.Expr.optional/3 IR node whose name
matches a callback on Nx.Vulkan.Backend. Under Nx.Vulkan the
evaluator dispatches one fused shader; under any other backend the
defn fallback runs and produces a mathematically equivalent result.
Same pattern as Emily.Fast.
Why this exists
We previously built Nx.Vulkan.Compiler to walk the IR and detect
fusable patterns automatically. That works for narrow cases but
doesn't scale: each new pattern is more compiler code, false
negatives are silent, and the matched shapes drift from real exmc
usage. Naming the kernels at call sites makes the intent explicit
and the dispatch deterministic. The fallback ensures cross-backend
correctness (EXLA, BinaryBackend, EMLX).
How to use
Inside a defn or any Nx.Defn.jit-traced function:
    defn leapfrog_step(q, eps, p, grad) do
      q_new = Nx.Vulkan.Fast.leapfrog_position(q, eps, p)
      p_new = Nx.Vulkan.Fast.momentum_step(p, eps, grad)
      {q_new, p_new}
    end

Under Nx.Vulkan.Backend each Fast call collapses to one
Nx.Vulkan.fused_chain_4 dispatch (single shader). Under any other
backend the defn fallback runs the composed primitives.
Adding kernels
Each kernel is two functions:
- The public entry: emits Nx.Defn.Expr.optional/3.
- A private _fallback: defn-style composed Nx ops.
Plus one matching callback in Nx.Vulkan.Backend. Total ~30 lines
per kernel; compare to ~100 lines + tests for an IR-detector
pattern.
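The dispatch rule behind this pattern can be modeled in plain Elixir without Nx. The sketch below is a toy and not how Nx.Defn.Expr.optional/3 is implemented: ToyFast, ToyFusedBackend, and ToyPlainBackend are hypothetical modules, and tensors are replaced by lists of floats. It only illustrates the rule that a backend exporting a function with the kernel's name gets called, while any other backend gets the fallback.

```elixir
defmodule ToyFast do
  # Toy stand-in for the optional dispatch: call the backend's function of
  # the same name and arity if it exports one, otherwise run the fallback.
  def optional(backend, name, args, fallback) do
    if Code.ensure_loaded?(backend) and function_exported?(backend, name, length(args)) do
      apply(backend, name, args)
    else
      apply(fallback, args)
    end
  end

  # Public entry: names the kernel and supplies its fallback.
  def leapfrog_position(backend, q, eps, p) do
    optional(backend, :leapfrog_position, [q, eps, p], &leapfrog_position_fallback/3)
  end

  # Fallback: q + eps * p composed from primitive elementwise operations.
  defp leapfrog_position_fallback(q, eps, p) do
    [q, eps, p]
    |> Enum.zip()
    |> Enum.map(fn {qi, ei, pi} -> qi + ei * pi end)
  end
end

defmodule ToyFusedBackend do
  # Matching callback: same name and arity as the kernel. It sends a
  # message to the caller so the path taken is observable.
  def leapfrog_position(q, eps, p) do
    send(self(), {:fused_dispatch, :leapfrog_position})

    [q, eps, p]
    |> Enum.zip()
    |> Enum.map(fn {qi, ei, pi} -> qi + ei * pi end)
  end
end

defmodule ToyPlainBackend do
  # Exports no kernel callbacks, so every kernel call takes the fallback.
end

q = [1.0, 2.0]
eps = [0.5, 0.5]
p = [2.0, 4.0]

fused = ToyFast.leapfrog_position(ToyFusedBackend, q, eps, p)
plain = ToyFast.leapfrog_position(ToyPlainBackend, q, eps, p)
IO.inspect({fused, plain})
```

Both calls return the same values; only the fused backend leaves a {:fused_dispatch, :leapfrog_position} message in the caller's mailbox. That is the property the real pattern relies on: call sites stay identical across backends and the results agree.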
Functions
Apply diagonal mass-matrix inverse: p * inv_mass. Trivial as a
fused kernel (one binary op), but named for symmetry — a future
shader could combine it with adjacent ops in the leapfrog without
changing call sites.
Kinetic energy: 0.5 * sum(p² * inv_mass). Reduces to a scalar.
Used in NUTS for the joint log-probability:
joint_logp = log_prob - kinetic_energy(p, inv_mass).
Under Nx.Vulkan this dispatches kinetic_energy.spv, a single shader
that squares, multiplies, reduces per workgroup, and applies the 0.5
multiplier. The shader produces partial sums; the backend callback
sums them on the host and returns a scalar.
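The two-stage reduction described above can be checked against a plain-Elixir reference. This is a sketch under assumptions, not the shader or the backend callback: KineticEnergyRef is a hypothetical module, lists stand in for tensors, each chunk stands in for one workgroup, and the workgroup size of 2 and the log_prob value are chosen only to keep the arithmetic easy to follow.

```elixir
defmodule KineticEnergyRef do
  # Stage 1 (shader analogue): each chunk stands in for one workgroup and
  # yields a partial sum of 0.5 * p_i^2 * inv_mass_i.
  def partial_sums(p, inv_mass, workgroup_size) do
    p
    |> Enum.zip(inv_mass)
    |> Enum.chunk_every(workgroup_size)
    |> Enum.map(fn chunk ->
      Enum.reduce(chunk, 0.0, fn {pi, mi}, acc -> acc + 0.5 * pi * pi * mi end)
    end)
  end

  # Stage 2 (host analogue): add the partial sums into one scalar.
  def kinetic_energy(p, inv_mass, workgroup_size \\ 2) do
    p
    |> partial_sums(inv_mass, workgroup_size)
    |> Enum.sum()
  end
end

p = [1.0, 2.0, 3.0, 4.0]
inv_mass = [1.0, 0.5, 2.0, 0.25]

partials = KineticEnergyRef.partial_sums(p, inv_mass, 2)
energy = KineticEnergyRef.kinetic_energy(p, inv_mass)

# Illustrative log_prob value (an assumption), combined as in the text.
log_prob = -3.0
joint_logp = log_prob - energy

IO.inspect({partials, energy, joint_logp})
# → {[1.5, 11.0], 12.5, -15.5}
```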
Half-step momentum update: p + half_eps * grad. Used at the start
and end of every leapfrog iteration in the standard symplectic
integrator. half_eps is eps / 2 precomputed by the caller.
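To show where the two half-steps sit in one leapfrog iteration, here is a minimal plain-Elixir sketch. It makes several assumptions that are not part of this module: LeapfrogRef is a hypothetical name, lists replace tensors, eps is a scalar rather than a tensor, the mass matrix is the identity, and the target is a standard normal so that the gradient of the log-density is -q.

```elixir
defmodule LeapfrogRef do
  # One leapfrog iteration: half-step momentum, full-step position,
  # half-step momentum. grad_fun returns the gradient of the log-density.
  def step(q, p, eps, grad_fun) do
    half_eps = eps / 2
    p_half = axpy(p, half_eps, grad_fun.(q))
    q_new = axpy(q, eps, p_half)
    p_new = axpy(p_half, half_eps, grad_fun.(q_new))
    {q_new, p_new}
  end

  # Elementwise a + s * b over two lists, with a scalar step s.
  defp axpy(a, s, b) do
    a
    |> Enum.zip(b)
    |> Enum.map(fn {ai, bi} -> ai + s * bi end)
  end
end

# Assumed target: standard normal, so grad log p(q) = -q.
grad = fn q -> Enum.map(q, fn qi -> -qi end) end

{q_new, p_new} = LeapfrogRef.step([1.0], [0.0], 0.5, grad)
IO.inspect({q_new, p_new})
# → {[0.875], [-0.46875]}
```

The first and last axpy calls have the shape of the half-step momentum update (p + half_eps * grad), and the middle one has the shape of the position update (q + eps * p).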
Position update: q + eps * p. The dominant elementwise body in
every NUTS leapfrog.
Examples
iex> q = Nx.tensor([1.0, 2.0])
iex> eps = Nx.tensor([0.5, 0.5])
iex> p = Nx.tensor([2.0, 4.0])
iex> Nx.Vulkan.Fast.leapfrog_position(q, eps, p) |> Nx.to_flat_list()
[2.0, 4.0]
Full-step momentum update: p + eps * grad. Same shape as the
half-step but kept distinct to signal the caller's intent.
Normal log-density: -0.5*((x-mu)/sigma)² - log(sigma) - 0.5*log(2π).
Output shape matches x. The MCMC distribution-density hot path.
Under Nx.Vulkan this dispatches normal_logpdf.spv (one fused shader)
instead of the five separate Nx ops the fallback emits.
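A plain-Elixir reference for the formula above can serve as a cross-check for either path. This is a sketch, not the fallback itself: NormalLogpdfRef is a hypothetical module, x is a list of floats, and mu and sigma are taken as scalars for simplicity.

```elixir
defmodule NormalLogpdfRef do
  # Elementwise normal log-density over a list of x values:
  # -0.5*((x-mu)/sigma)^2 - log(sigma) - 0.5*log(2*pi).
  def logpdf(xs, mu, sigma) do
    # The x-independent terms are computed once and reused per element.
    log_norm = :math.log(sigma) + 0.5 * :math.log(2 * :math.pi())

    Enum.map(xs, fn x ->
      z = (x - mu) / sigma
      -0.5 * z * z - log_norm
    end)
  end
end

# At the mean of a standard normal the density is 1/sqrt(2*pi).
[at_mean] = NormalLogpdfRef.logpdf([0.0], 0.0, 1.0)

# One unit from the mean with sigma = 2.
[off_mean] = NormalLogpdfRef.logpdf([1.0], 0.0, 2.0)

IO.inspect({at_mean, off_mean})
```

The first value should be close to -0.918939 (that is, -0.5*log(2π)) and the second close to -1.737086; as in the kernel, the output has one entry per entry of x.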