# `Nx.Vulkan.Fast`
[🔗](https://github.com/borodark/nx_vulkan/blob/main/lib/nx_vulkan/fast.ex#L1)

Named fused kernels for MCMC hot paths.

Each function emits an `Nx.Defn.Expr.optional/3` IR node whose name
matches a callback on `Nx.Vulkan.Backend`. Under Nx.Vulkan the
evaluator dispatches a single fused shader; under any other backend
the defn fallback runs and produces a mathematically equivalent
result. This is the same pattern as `Emily.Fast`.

## Why this exists

We previously built `Nx.Vulkan.Compiler` to walk the IR and detect
fusable patterns automatically. That works for narrow cases but
doesn't scale: each new pattern requires more compiler code, false
negatives are silent, and the matched shapes drift from real exmc
usage. Naming the kernels at call sites makes the intent explicit
and the dispatch deterministic. The fallback keeps results correct
on every other backend (EXLA, BinaryBackend, EMLX).

## How to use

Inside a `defn` or any function traced by `Nx.Defn.jit/2`:

    defn leapfrog_step(q, eps, p, grad) do
      q_new = Nx.Vulkan.Fast.leapfrog_position(q, eps, p)
      p_new = Nx.Vulkan.Fast.momentum_step(p, eps, grad)
      {q_new, p_new}
    end

Under Nx.Vulkan.Backend each Fast call collapses to one
`Nx.Vulkan.fused_chain_4` dispatch (single shader). Under any other
backend the defn fallback runs the composed primitives.

## Adding kernels

Each kernel is two functions:

  1. The public entry — emits `Nx.Defn.Expr.optional/3`.
  2. A private `_fallback` — defn-style composed Nx ops.

Each kernel also needs one matching callback in `Nx.Vulkan.Backend`.
That totals roughly 30 lines per kernel, compared with roughly 100
lines plus tests for an IR-detector pattern.

# `inv_mass_apply`

```elixir
@spec inv_mass_apply(Nx.t(), Nx.t()) :: Nx.t()
```

Applies the inverse of the diagonal mass matrix: `p * inv_mass`. As a
fused kernel it is trivial (one binary op), but it is named for
symmetry: a future shader could combine it with adjacent ops in the
leapfrog without changing call sites.

# `kinetic_energy`

```elixir
@spec kinetic_energy(Nx.t(), Nx.t()) :: Nx.t()
```

Kinetic energy: `0.5 * sum(p² * inv_mass)`. Reduces to a scalar.
Used in NUTS for the joint log-probability:
`joint_logp = log_prob - kinetic_energy(p, inv_mass)`.

Under Nx.Vulkan dispatches `kinetic_energy.spv` (one shader: square +
multiply + per-workgroup reduce + 0.5 multiplier baked in). The
shader produces partial sums; the backend callback sums them on the
host and returns a scalar.
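
The two-stage reduction described above can be sketched in plain Elixir on lists. This is only a reference for checking values, not the library's shader or backend callback: the module name `KineticEnergyRef`, its function names, and the chunk size of 256 (standing in for a workgroup size) are all assumptions made for illustration.

```elixir
defmodule KineticEnergyRef do
  # Plain-Elixir reference on lists; NOT the library's shader or callback.
  # @chunk stands in for a workgroup size and is a hypothetical value.
  @chunk 256

  # Stage 1 (the shader's role): one partial sum per chunk,
  # with the 0.5 multiplier applied to every element.
  def partial_sums(p, inv_mass, chunk \\ @chunk) do
    p
    |> Enum.zip(inv_mass)
    |> Enum.chunk_every(chunk)
    |> Enum.map(fn pairs ->
      Enum.reduce(pairs, 0.0, fn {p_i, m_i}, acc ->
        acc + 0.5 * p_i * p_i * m_i
      end)
    end)
  end

  # Stage 2 (the host's role): reduce the partial sums to one scalar.
  def kinetic_energy(p, inv_mass, chunk \\ @chunk) do
    p |> partial_sums(inv_mass, chunk) |> Enum.sum()
  end

  # joint_logp = log_prob - kinetic_energy(p, inv_mass)
  def joint_logp(log_prob, p, inv_mass) do
    log_prob - kinetic_energy(p, inv_mass)
  end
end
```

For `p = [1.0, 2.0]` and `inv_mass = [1.0, 1.0]`, the per-element terms are `0.5` and `2.0`, so `KineticEnergyRef.kinetic_energy/2` returns `2.5` whatever the chunk size.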

# `leapfrog_momentum_half`

```elixir
@spec leapfrog_momentum_half(Nx.t(), Nx.t(), Nx.t()) :: Nx.t()
```

Half-step momentum update: `p + half_eps * grad`. Used at the start
and end of every leapfrog iteration in the standard symplectic
integrator. `half_eps` is `eps / 2` precomputed by the caller.
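
To show where the two half-steps sit in one iteration, here is a minimal plain-Elixir sketch on lists. It assumes a scalar `eps`, a unit mass matrix, and a caller-supplied `grad_fn` that returns the gradient of the log-density; the module `LeapfrogRef` and its function names are hypothetical and are not part of `Nx.Vulkan.Fast`.

```elixir
defmodule LeapfrogRef do
  # Plain-Elixir reference on lists; NOT the library's implementation.
  # Assumes scalar eps and a unit mass matrix.

  # Half-step momentum update: p + half_eps * grad, elementwise.
  def momentum_half(p, half_eps, grad) do
    Enum.zip_with(p, grad, fn p_i, g_i -> p_i + half_eps * g_i end)
  end

  # Position update: q + eps * p, elementwise.
  def position(q, eps, p) do
    Enum.zip_with(q, p, fn q_i, p_i -> q_i + eps * p_i end)
  end

  # One leapfrog iteration: half-step momentum, full-step position,
  # half-step momentum. grad_fn is a hypothetical callback mapping a
  # position list to the gradient of the log-density at that position.
  def step(q, p, eps, grad_fn) do
    half_eps = eps / 2
    p_half = momentum_half(p, half_eps, grad_fn.(q))
    q_new = position(q, eps, p_half)
    p_new = momentum_half(p_half, half_eps, grad_fn.(q_new))
    {q_new, p_new}
  end
end
```

With a standard normal target the gradient is `-q`, so it can be passed as `fn q -> Enum.map(q, fn x -> -x end) end`. The caller computes `half_eps` once per step, which matches the convention that `half_eps` arrives precomputed.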

# `leapfrog_position`

```elixir
@spec leapfrog_position(Nx.t(), Nx.t(), Nx.t()) :: Nx.t()
```

Position update: `q + eps * p`. The dominant elementwise body in
every NUTS leapfrog.

## Examples

    iex> q = Nx.tensor([1.0, 2.0])
    iex> eps = Nx.tensor([0.5, 0.5])
    iex> p = Nx.tensor([2.0, 4.0])
    iex> Nx.Vulkan.Fast.leapfrog_position(q, eps, p) |> Nx.to_flat_list()
    [2.0, 4.0]

# `momentum_step`

```elixir
@spec momentum_step(Nx.t(), Nx.t(), Nx.t()) :: Nx.t()
```

Full-step momentum update: `p + eps * grad`. The expression has the
same form as the half-step, but the function is kept distinct to
signal the caller's intent.

# `normal_logpdf`

```elixir
@spec normal_logpdf(Nx.t(), Nx.t(), Nx.t()) :: Nx.t()
```

Normal log-density: `-0.5*((x-mu)/sigma)² - log(sigma) - 0.5*log(2π)`.
Output shape matches `x`. The MCMC distribution-density hot path.

Under Nx.Vulkan dispatches `normal_logpdf.spv` (one fused shader)
instead of the five separate Nx ops the fallback emits.
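
As a check on the formula, the same density can be written for scalars in plain Elixir using only `:math`. This is a reference sketch, not the library's shader or fallback, and the module name `NormalLogpdfRef` is an assumption made for illustration.

```elixir
defmodule NormalLogpdfRef do
  # Scalar reference; NOT the library's shader or fallback.
  # Computes -0.5*((x-mu)/sigma)^2 - log(sigma) - 0.5*log(2*pi).
  def normal_logpdf(x, mu, sigma) do
    z = (x - mu) / sigma
    -0.5 * z * z - :math.log(sigma) - 0.5 * :math.log(2 * :math.pi())
  end
end
```

At `x = mu` with `sigma = 1.0` the first two terms vanish, leaving `-0.5 * log(2π)`, approximately `-0.9189`.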

