# `Nx.Vulkan`
[🔗](https://github.com/borodark/nx_vulkan/blob/main/lib/nx_vulkan.ex#L1)

Nx tensor backend on Vulkan compute.

Wraps Spirit's Vulkan compute kernels (elementwise, reductions,
matmul, random) as an `Nx.Backend`. Works on FreeBSD with NVIDIA
hardware where CUDA does not. Same backend code runs on Linux,
macOS (via MoltenVK), and any Vulkan-capable GPU.

## Status

v0.0.1 — bootstrap. The plan in `PLAN.md` lays out the 10-milestone
path to v0.1. This release just initializes Vulkan and reports
which physical device was selected; tensor types and operators
land in subsequent commits.

## Usage (target)

    iex> Nx.Vulkan.init()
    :ok

    iex> Nx.tensor([1.0, 2.0, 3.0], backend: Nx.Vulkan.Backend)
    ...

    iex> Nx.global_default_backend(Nx.Vulkan.Backend)

## Files

  * `lib/nx_vulkan.ex`           - this module (top-level API)
  * `lib/nx_vulkan/native.ex`    - Rustler NIF binding (skeleton)
  * `lib/nx_vulkan/backend.ex`   - `Nx.Backend` impl (TBD)
  * `native/nx_vulkan_native/`   - Rust NIF crate
  * `c_src/`                     - extern "C" shim into spirit's
                                   C++ Vulkan backend

# `abs`

Elementwise `abs` of a GPU tensor.

# `add`

Elementwise `add` of two GPU tensors of equal length.
Returns `{:ok, tensor}` or `{:error, reason}`.

# `apply_binary_broadcast`

Dispatch the broadcast variant of an elementwise binary op. `op` is
one of the binary atom keys in `@ops_binary`; tensors may have at most
4 dimensions. A stride of 0 on an axis means the operand is broadcast
along that axis. `out_shape`, `a_strides`, and `b_strides` are lists;
the helper pads each to length 4.

Use `Nx.Vulkan.broadcast_strides/2` to compute strides from a source
shape against the output shape.

# `apply_binary_broadcast_f64`

f64 broadcast elementwise binary op; dispatches `elementwise_binary_broadcast_f64.spv`.

# `apply_binary_f64`

f64 elementwise binary op; dispatches `elementwise_binary_f64.spv`.

# `apply_unary_f64`

f64 elementwise unary op; dispatches `elementwise_unary_f64.spv`.

# `broadcast_strides`

Per-axis strides for broadcasting `src_shape` to `out_shape`.

Returns a length-4 list (zero-padded). Stride is 0 on a broadcast axis
(size 1 in `src` but >1 in `out`); otherwise it's the row-major
product of trailing source dims.

    iex> Nx.Vulkan.broadcast_strides({1, 4}, {3, 4})
    [0, 1, 0, 0]
    iex> Nx.Vulkan.broadcast_strides({2, 1}, {2, 4})
    [1, 0, 0, 0]
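
The rule above can be sketched on the host in plain Elixir. `BroadcastStridesSketch` is a hypothetical module written for illustration, not the library's implementation, and it assumes `src_shape` and `out_shape` have the same rank (at most 4).

```elixir
defmodule BroadcastStridesSketch do
  # Hypothetical re-implementation for illustration only.
  # Assumes src_shape and out_shape have the same rank (<= 4).
  def strides(src_shape, out_shape) do
    src = Tuple.to_list(src_shape)
    out = Tuple.to_list(out_shape)

    # Row-major stride of each axis = product of the trailing source dims.
    {reversed, _total} =
      src
      |> Enum.reverse()
      |> Enum.map_reduce(1, fn dim, acc -> {acc, acc * dim} end)

    row_major = Enum.reverse(reversed)

    strides =
      [src, out, row_major]
      |> Enum.zip()
      |> Enum.map(fn
        # Size 1 in the source but > 1 in the output: broadcast axis.
        {1, o, _stride} when o > 1 -> 0
        {_s, _o, stride} -> stride
      end)

    # Zero-pad to length 4.
    strides ++ List.duplicate(0, 4 - length(strides))
  end
end

BroadcastStridesSketch.strides({1, 4}, {3, 4})
# => [0, 1, 0, 0]
BroadcastStridesSketch.strides({2, 1}, {2, 4})
# => [1, 0, 0, 0]
```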

# `byte_size`

Size of an uploaded tensor, in bytes.

# `cast_f32_to_f64`

Cast f32 tensor → f64 (allocates an output with 8 bytes per element).

# `cast_f64_to_f32`

Cast f64 tensor → f32 (allocates an output with 4 bytes per element).

# `ceil`

Elementwise `ceil` of a GPU tensor.

# `clip`

Clip every element to `[low, high]`. Implemented as
`max(low, min(high, a))` with broadcast scalar tensors. The scalars
are currently materialized as N-element buffers; a single-shader
`clip.comp` could replace this once broadcast support matures.
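
As a host-side illustration of the formula, the sketch below clips a plain Elixir list. `ClipSketch` is a hypothetical module, not part of `Nx.Vulkan`; it mirrors the described approach by expanding the scalars to full-length lists first.

```elixir
defmodule ClipSketch do
  # Hypothetical host-side reference; not part of Nx.Vulkan.
  def clip(values, low, high) do
    n = length(values)

    # Mirror the described GPU path: expand each scalar to n elements.
    lows = List.duplicate(low, n)
    highs = List.duplicate(high, n)

    [values, lows, highs]
    |> Enum.zip()
    |> Enum.map(fn {a, lo, hi} -> max(lo, min(hi, a)) end)
  end
end

ClipSketch.clip([-1.0, 0.5, 2.0], 0.0, 1.0)
# => [0.0, 0.5, 1.0]
```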

# `device_name`

Returns the name of the selected physical device, or `nil` if
`init/0` has not been called.

# `divide`

Elementwise `divide` of two GPU tensors of equal length.
Returns `{:ok, tensor}` or `{:error, reason}`.

# `download_binary`

Download a tensor as a raw binary; the caller unpacks it.

# `download_f32`

Download a GPU buffer back into a list of f32 values. `n_elements`
must match what was uploaded.

Non-finite values (NaN, +Inf, -Inf) are returned as the atoms
`:nan`, `:infinity`, `:neg_infinity`. Erlang's float pattern
`<<x::float-32-native>>` rejects these bit patterns; we decode
the raw 32-bit pattern and check the IEEE 754 exponent/mantissa
to recover them.
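
The decoding described above might look roughly like the following. This is a sketch under assumptions: `F32DecodeSketch` and its function names are hypothetical, and the library's actual helper may differ. An f32 is non-finite when all 8 exponent bits are set; a zero mantissa then means infinity, anything else NaN.

```elixir
defmodule F32DecodeSketch do
  import Bitwise

  # Hypothetical helper: decode one native-endian f32 (4 bytes).
  def decode(<<bits::unsigned-integer-32-native>> = bin) do
    sign = bits >>> 31
    exponent = (bits >>> 23) &&& 0xFF
    mantissa = bits &&& 0x7FFFFF

    case {exponent, mantissa, sign} do
      {0xFF, 0, 0} -> :infinity
      {0xFF, 0, 1} -> :neg_infinity
      {0xFF, _, _} -> :nan
      _ ->
        # Finite value: the ordinary float pattern is safe here.
        <<x::float-32-native>> = bin
        x
    end
  end

  # Decode a whole buffer, 4 bytes at a time.
  def decode_all(binary) do
    for <<chunk::binary-size(4) <- binary>>, do: decode(chunk)
  end
end

F32DecodeSketch.decode(<<0x7F800000::unsigned-integer-32-native>>)
# => :infinity
```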

# `equal`

Elementwise `equal` of two GPU tensors of equal length.
Returns `{:ok, tensor}` or `{:error, reason}`.

# `erf`

Elementwise `erf` of a GPU tensor.

# `exp`

Elementwise `exp` of a GPU tensor.

# `expm1`

Elementwise `expm1` of a GPU tensor.

# `floor`

Elementwise `floor` of a GPU tensor.

# `fused_chain`

Run a chain of up to 8 elementwise ops in a single shader dispatch.

Replaces N separate dispatches with one. Each binary step combines the
running register with `b`; each unary step transforms the register only.

    iex> {:ok, a} = Nx.Vulkan.upload_f32([1.0, 2.0, 3.0])
    iex> {:ok, b} = Nx.Vulkan.upload_f32([0.5, 0.5, 0.5])
    iex> # (a * b) + b → exp
    iex> {:ok, c} = Nx.Vulkan.fused_chain(a, b, [:multiply, :add, :exp])
    iex> {:ok, vals} = Nx.Vulkan.download_f32(c, 3)
    iex> vals  # exp((a*b)+b) = exp(1.0), exp(1.5), exp(2.0)
    [2.71828..., 4.48168..., 7.38905...]

Op atoms supported:

  * Binary (combine register with `b`): `:add`, `:multiply`, `:subtract`,
    `:divide`, `:pow`, `:max`, `:min`
  * Unary (transform register): `:exp`, `:log`, `:sqrt`, `:abs`,
    `:negate`, `:sigmoid`, `:tanh`, `:relu`, `:ceil`, `:floor`,
    `:sign`, `:reciprocal`, `:square`

Note: `:erf` (113) and `:expm1` (114) work in the chain as of
spirit `161296d1`, which handles them as cases 13 and 14 in
`apply_unary`. Earlier versions of the fused shader passed them
through unchanged.

Chains longer than 8 ops should be split: dispatch `fused_chain` twice,
using the running tensor as `a` for the second call.
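
One way to generalize the split to chains of any length is to chunk the op list and thread the running tensor through each dispatch. In this sketch `ChainSplitSketch` is hypothetical, and `dispatch` is a 3-arity function standing in for `Nx.Vulkan.fused_chain/3`; it is assumed to return `{:ok, tensor}` or `{:error, reason}`.

```elixir
defmodule ChainSplitSketch do
  @max_ops 8

  # `dispatch` stands in for Nx.Vulkan.fused_chain/3 (assumption: it
  # returns {:ok, tensor} or {:error, reason}).
  def run(a, b, ops, dispatch) do
    ops
    |> Enum.chunk_every(@max_ops)
    |> Enum.reduce_while({:ok, a}, fn chunk, {:ok, running} ->
      case dispatch.(running, b, chunk) do
        {:ok, next} -> {:cont, {:ok, next}}
        {:error, _reason} = err -> {:halt, err}
      end
    end)
  end
end

# Stub dispatch that records how many ops each call received.
stub = fn running, _b, chunk -> {:ok, running ++ [length(chunk)]} end
ChainSplitSketch.run([], nil, List.duplicate(:exp, 11), stub)
# => {:ok, [8, 3]}
```

With the real backend, the call would presumably be `ChainSplitSketch.run(a, b, ops, &Nx.Vulkan.fused_chain/3)`.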

# `fused_chain_4`

4-input fused chain. `ops_with_buf` items are either `{op_atom, idx}`
for binary (idx ∈ {1, 2, 3} for b/c/d) or plain `op_atom` for unary.
All 4 buffers must be the same byte size; up to 8 ops.

# `greater`

Elementwise `greater` of two GPU tensors of equal length.
Returns `{:ok, tensor}` or `{:error, reason}`.

# `has_f64?`

Returns `true` if the selected device supports f64 (double precision).

# `init`

Initialize the Vulkan compute context. Call once at startup.
Returns `:ok` on success, `{:error, reason}` if no Vulkan-capable
device is found.

# `jit`

JIT-compile a function so each op dispatches through the Vulkan backend.

Symmetric counterpart of `EXLA.jit/2` and `EMLX.jit/2`. There's no
kernel fusion in v0.1 — each `Nx.*` call inside the defn becomes one
shader dispatch via `Nx.Defn.Evaluator`. Combined-shader fusion is the
v0.2 work (see FUSION_RESEARCH.md).

Sets `Nx.Vulkan.Backend` as the global default if it isn't already, so
scalars and tensors created inside the defn land on the GPU. Calls
`Nx.Vulkan.init/0` (idempotent).

    iex> Nx.Vulkan.init()
    :ok
    iex> f = fn x -> Nx.add(x, x) end
    iex> Nx.Vulkan.jit(f).(Nx.tensor([1.0, 2.0]))
    #Nx.Tensor<f32[2] [2.0, 4.0]>

# `kinetic_energy`

Fused kinetic-energy primitive: `0.5 * sum(p² * inv_mass)` reduced
per workgroup. Returns a buffer of `ceil(n/256)` partial f32 sums;
caller does the final reduction (typically via `Nx.Vulkan.sum/1` or
on the host).
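
A host-side reference for the partial-sum layout, useful for checking results. This is a sketch: `KineticEnergySketch` is hypothetical, and it assumes the `0.5` factor is applied inside each workgroup's partial sum, which the text above does not state explicitly.

```elixir
defmodule KineticEnergySketch do
  @workgroup 256

  # One partial per 256-element workgroup: ceil(n / 256) values.
  # Assumption: the 0.5 factor is applied to each partial.
  def partials(p, inv_mass) do
    p
    |> Enum.zip(inv_mass)
    |> Enum.chunk_every(@workgroup)
    |> Enum.map(fn chunk ->
      0.5 * Enum.reduce(chunk, 0.0, fn {pi, mi}, acc -> acc + pi * pi * mi end)
    end)
  end

  # Final reduction, done by the caller.
  def total(p, inv_mass), do: p |> partials(inv_mass) |> Enum.sum()
end

p = List.duplicate(2.0, 300)
inv_mass = List.duplicate(1.0, 300)
KineticEnergySketch.partials(p, inv_mass)
# => [512.0, 88.0]
KineticEnergySketch.total(p, inv_mass)
# => 600.0
```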

# `leapfrog_chain_cauchy`

Phase 2 chain shader for Cauchy(loc, scale). Returns a 4-tuple of refs.
`log_pi_scale` is precomputed as `−log(π · scale)`.

# `leapfrog_chain_exponential`

Phase 2 sibling of `leapfrog_chain_normal/7` for the Exponential(lambda)
family on the unconstrained line (log-transform). Same I/O shape: returns
`{q_chain_ref, p_chain_ref, grad_chain_ref, logp_chain_ref}`.

Closed-form unconstrained gradient: `grad_q_uc = 1 - lambda * exp(q_uc)`.
`n ≤ 256` (single workgroup); see `leapfrog_chain_normal_lg/7` for the
multi-workgroup pattern when an `_lg` exponential variant is needed.

# `leapfrog_chain_halfnormal`

Phase 2 chain shader for HalfNormal(σ) on the unconstrained line
via log-transform `q_uc = log(q)`. Returns a 4-tuple of refs.
`log_const` is precomputed as `−log(σ) − ½ log(π)`.

**Numerical caveat**: the gradient `1 − exp(2·q_uc)/σ²` overflows
in f32 when `q_uc > ~44`; for σ ≈ 1, typical values of `q_uc` stay
well below that threshold.
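
The threshold follows from the largest finite f32 value: `exp(2·q_uc)` stops being representable once `2·q_uc` exceeds `log(f32_max)`. A small check, with `F32OverflowSketch` as a hypothetical module name:

```elixir
defmodule F32OverflowSketch do
  # Largest finite IEEE 754 single-precision value.
  @f32_max 3.4028234663852886e38

  # Largest q_uc for which exp(2 * q_uc) is still finite in f32.
  def max_q_uc, do: 0.5 * :math.log(@f32_max)
end

F32OverflowSketch.max_q_uc()
# roughly 44.36
```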

# `leapfrog_chain_normal`

Fused **K-step chain** of NUTS leapfrog steps for a univariate Normal
log-density model. Performs `k` consecutive leapfrog steps in one Vulkan
dispatch and returns all `k` intermediate states:
`{q_chain_ref, p_chain_ref, grad_chain_ref, logp_chain_ref}`.

- `q_chain`, `p_chain`, `grad_chain`: each `k * n` f32 elements,
  laid out row-major (step `s`, dimension `i` at offset `s*n + i`,
  with `0 ≤ s < k`).
- `logp_chain`: `k` f32 elements; per-step log-density reduced
  across the `n` dimensions.
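
Once a chain buffer has been downloaded to a flat list, the row-major layout can be indexed as below. This is a sketch: `ChainLayoutSketch` is hypothetical, and `s` denotes a 0-based step index.

```elixir
defmodule ChainLayoutSketch do
  # All n values of step `s` (0-based) from a flat k*n list.
  def step(flat, n, s), do: Enum.slice(flat, s * n, n)

  # Single value: step `s`, dimension `i`.
  def at(flat, n, s, i), do: Enum.at(flat, s * n + i)
end

# k = 3 steps, n = 2 dimensions.
flat = [0.0, 0.1, 1.0, 1.1, 2.0, 2.1]
ChainLayoutSketch.step(flat, 2, 1)
# => [1.0, 1.1]
ChainLayoutSketch.at(flat, 2, 2, 1)
# => 2.1
```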

Per-step amortized cost is `(per_dispatch_baseline + k * compute) / k`.
At `k=32` on the dev box this is ~16 µs per leapfrog step vs ~537 µs
for the single-step `leapfrog_normal` and ~6000 µs for the unfused
IR-walker path.

Constraints (Phase 1.5):
- `n ≤ 256` (single workgroup; multi-workgroup version is future work).
- f32 only; long chains (`k ≥ 64`) may accumulate measurable drift
  relative to a f64 reference.
- Univariate Normal log-density only — closed-form gradient
  `−(q − mu) / sigma²` baked into the shader.

# `leapfrog_chain_normal_f64`

f64 sibling of `leapfrog_chain_normal/7`. Same I/O contract but all
buffers use 8 bytes per element (input refs must be f64-typed Vulkan
tensors). Useful when chain integration needs higher precision than
f32 (e.g., long chains, sensitive log-densities).

# `leapfrog_chain_normal_lg`

Multi-workgroup variant of `leapfrog_chain_normal/7` for `n > 256`.
Returns `{q_chain_ref, p_chain_ref, grad_chain_ref, partial_logp_ref}`
where `partial_logp_ref` is a buffer of `k * num_workgroups` f32 values
(per-workgroup partial sums for each step). The caller sums across the
`num_workgroups` axis to recover the final per-step logp.
Workgroup 0 includes the constant term so the host sum gives final logp
directly.
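
The host-side sum could be written as follows once the partial buffer has been downloaded to a list. This sketch assumes a step-major layout (all workgroup partials for step 0, then step 1, and so on), which is not stated above; `PartialLogpSketch` is a hypothetical name.

```elixir
defmodule PartialLogpSketch do
  # Assumption: partials are laid out step-major, i.e.
  # [step0_wg0, step0_wg1, ..., step1_wg0, ...].
  def per_step_logp(partials, num_workgroups) do
    partials
    |> Enum.chunk_every(num_workgroups)
    |> Enum.map(&Enum.sum/1)
  end
end

# 2 steps, 3 workgroups.
PartialLogpSketch.per_step_logp([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3)
# => [6.0, 15.0]
```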

# `leapfrog_chain_studentt`

Phase 2 chain shader for Student-t(ν, μ, σ). Returns a 4-tuple
`{q_chain_ref, p_chain_ref, grad_chain_ref, logp_chain_ref}`.
`logp_const` should be precomputed by the caller as
`log Γ((ν+1)/2) − log Γ(ν/2) − ½ log(πν) − log σ`.

# `leapfrog_chain_weibull`

Phase 2 chain shader for Weibull(k, lambda) on the unconstrained
line via log-transform `q_uc = log(q)`. Returns a 4-tuple of refs.

Closed-form gradient: `∇logp(q_uc) = k · (1 − (exp(q_uc)/lambda)^k)`.
`logp_const` is precomputed as `n · (log(k) − k · log(lambda))` —
no `lgamma` in the shader.

# `leapfrog_normal`

Fused NUTS leapfrog step for a univariate Normal log-density model.
One Vulkan dispatch per leapfrog step instead of ~12 elementwise
dispatches via the IR walker. Returns `{q_new_ref, p_new_ref}`.

`q_ref`, `p_ref`, `inv_mass_ref` are f32 buffers of identical size.
`eps`, `mu`, `sigma` are scalars (f32 in the shader push constants;
f64 here for caller convenience). f32 only.

Closed-form gradient:
`grad_q log N(q | mu, sigma) = -(q - mu) / sigma²` — no autodiff
machinery in the shader.
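
For testing against the GPU result, a host-side reference step can be written with the closed-form gradient. The sketch below assumes the standard leapfrog scheme (half-step momentum update, full-step position update scaled by `inv_mass`, half-step momentum update); the shader's exact ordering is not documented here, and `LeapfrogSketch` is a hypothetical module.

```elixir
defmodule LeapfrogSketch do
  # Closed-form gradient of log N(q | mu, sigma) with respect to q.
  def grad(q, mu, sigma), do: -(q - mu) / (sigma * sigma)

  # Assumption: standard kick-drift-kick leapfrog, applied per element.
  def step(qs, ps, inv_masses, eps, mu, sigma) do
    [qs, ps, inv_masses]
    |> Enum.zip()
    |> Enum.map(fn {q, p, inv_mass} ->
      p_half = p + 0.5 * eps * grad(q, mu, sigma)
      q_new = q + eps * inv_mass * p_half
      p_new = p_half + 0.5 * eps * grad(q_new, mu, sigma)
      {q_new, p_new}
    end)
    |> Enum.unzip()
  end
end

{q_new, p_new} = LeapfrogSketch.step([0.0], [1.0], [1.0], 0.1, 0.0, 1.0)
# q_new is approximately [0.1], p_new approximately [0.995]
```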

# `less`

Elementwise `less` of two GPU tensors of equal length.
Returns `{:ok, tensor}` or `{:error, reason}`.

# `log`

Elementwise `log` of a GPU tensor.

# `logsumexp`

Numerically stable logsumexp over a single virtual reduce axis.
Computes `max(x) + log(sum(exp(x - max(x))))` in one shader dispatch
using a two-pass shader. f32 only.
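
The two passes can be mirrored on the host: the first finds the maximum, the second sums the shifted exponentials. This is a reference sketch with a hypothetical module name, not the shader's code.

```elixir
defmodule LogSumExpSketch do
  def logsumexp(xs) do
    # Pass 1: maximum, used to shift the inputs so exp cannot overflow.
    m = Enum.max(xs)

    # Pass 2: sum of shifted exponentials (each term is <= 1).
    sum = xs |> Enum.map(&:math.exp(&1 - m)) |> Enum.sum()

    m + :math.log(sum)
  end
end

LogSumExpSketch.logsumexp([1000.0, 1000.0])
# approximately 1000.6931 (1000 + log 2); exp(1000.0) alone would overflow
```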

# `matmul`

Matrix multiply: `C[M*N] = A[M*K] · B[K*N]`. All row-major f32.
Returns `{:ok, c_tensor}`.

Auto-selects the best shader variant based on `(M, N, K)`:

  * Tiny (M*N*K < 4096): naive `matmul.spv` — dispatch overhead
    dominates; tiling adds no win.
  * Medium (4096 ≤ M*N*K < 256³): `matmul_tiled.spv` (16×16 shared-
    memory tiles) — good cache behavior, modest GPU occupancy.
  * Large (M*N*K ≥ 256³ ≈ 16M): `matmul_tiled16x2.spv` — each thread
    computes 2 output rows; mac-248 measured **4.2× win at 1024×1024**
    vs the naive variant.

`matmul_tiled32.spv` exists in spirit too but only wins on Ampere+
(1024 threads/SM); on Kepler/Maxwell it loses to 16x2 due to shared
memory pressure (8 KB tile vs 3 KB). Not auto-selected; reachable
via `matmul_variant/6`.
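
The size thresholds above amount to a three-way branch on `M*N*K`. The sketch below restates them in Elixir; `MatmulPickSketch` is hypothetical and returns only the variant atom, whereas `pick_matmul` itself is documented as returning `{shader_name, tile_m, tile_n}`.

```elixir
defmodule MatmulPickSketch do
  @tiny 4096
  @large 256 * 256 * 256

  # Hypothetical restatement of the documented thresholds.
  def variant(m, n, k) do
    work = m * n * k

    cond do
      work < @tiny -> :matmul
      work < @large -> :matmul_tiled
      true -> :matmul_tiled16x2
    end
  end
end

MatmulPickSketch.variant(8, 8, 8)
# => :matmul
MatmulPickSketch.variant(64, 64, 64)
# => :matmul_tiled
MatmulPickSketch.variant(1024, 1024, 1024)
# => :matmul_tiled16x2
```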

# `matmul_variant`

Matrix multiply with explicit shader variant. Use when you know
better than the heuristic, or to benchmark.

    :matmul                 # naive, gx=ceil(N/16), gy=ceil(M/16)
    :matmul_tiled           # 16×16 shared-mem tiles
    :matmul_tiled32         # 32×32 tiles (Ampere wins)
    :matmul_tiled16x2       # 32×16 output (2 rows per thread)

# `max`

Elementwise `max` of two GPU tensors of equal length.
Returns `{:ok, tensor}` or `{:error, reason}`.

# `mean`

Mean of all elements (sum + host-side divide).

# `min`

Elementwise `min` of two GPU tensors of equal length.
Returns `{:ok, tensor}` or `{:error, reason}`.

# `multiply`

Elementwise `multiply` of two GPU tensors of equal length.
Returns `{:ok, tensor}` or `{:error, reason}`.

# `negate`

Elementwise `negate` of a GPU tensor.

# `normal`

Generate `n` standard-normal N(0,1) f32 values via Box-Muller.
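
For reference, the textbook Box-Muller transform maps two uniform samples to two independent standard normals. The sketch below shows that transform on the host; how the shader draws its uniforms and whether it uses both outputs is not specified here, and `BoxMullerSketch` is a hypothetical name.

```elixir
defmodule BoxMullerSketch do
  # u1 must be in (0, 1]; u2 in [0, 1).
  def transform(u1, u2) do
    r = :math.sqrt(-2.0 * :math.log(u1))
    theta = 2.0 * :math.pi() * u2
    {r * :math.cos(theta), r * :math.sin(theta)}
  end
end

# With r = 1 and theta = pi/2 the pair is approximately {0.0, 1.0}.
BoxMullerSketch.transform(:math.exp(-0.5), 0.25)
```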

# `normal_logpdf`

Fused Normal log-density primitive:
`-0.5*((x-mu)/sigma)² - log(sigma) - 0.5*log(2π)`.
Output shape matches `x`. f32 only.
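
A host-side reference for the same formula, handy for comparing against downloaded GPU output. `NormalLogpdfSketch` is a hypothetical module; it computes in Elixir's 64-bit floats, so expect small differences from the f32 shader.

```elixir
defmodule NormalLogpdfSketch do
  def logpdf(xs, mu, sigma) do
    # Terms that do not depend on x: -log(sigma) - 0.5*log(2*pi).
    const = -(:math.log(sigma) + 0.5 * :math.log(2.0 * :math.pi()))

    Enum.map(xs, fn x ->
      z = (x - mu) / sigma
      -0.5 * z * z + const
    end)
  end
end

NormalLogpdfSketch.logpdf([0.0, 1.0], 0.0, 1.0)
# approximately [-0.9189, -1.4189]
```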

# `pick_matmul`

Picks the best matmul shader for a given `(M, N, K)` shape. Returns
`{shader_name, tile_m, tile_n}`. Public so benchmarks can introspect
the heuristic.

# `pool_clear`

Release every pooled VkBuf back to the device. Call at idle time to
reclaim memory; otherwise the pool grows to working-set size and stays
there. Idempotent.

# `pool_stats`

Buffer pool stats. Returns `{:ok, %{hits, misses, freed,
size_classes, total_pooled}}`. `hits/misses` count alloc requests
served from / missed by the pool; `freed` counts buffers actually
vkFreeMemory'd (pool-overflow or explicit clear); `size_classes` is
the number of distinct sizes currently held; `total_pooled` is the
total VkBuf count waiting for reuse.

    iex> Nx.Vulkan.init()
    iex> Nx.Vulkan.pool_stats()
    {:ok, %{hits: _, misses: _, freed: _, size_classes: _, total_pooled: _}}

# `pow`

Elementwise `pow` of two GPU tensors of equal length.
Returns `{:ok, tensor}` or `{:error, reason}`.

# `reciprocal`

Elementwise `reciprocal` of a GPU tensor.

# `reduce_axis`

Per-axis reduction over a virtual 3-D layout (outer, reduce, inner).
`op`: 0=sum, 1=max, 2=min. Output is (outer * inner) f32.
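
The virtual layout means any single-axis reduction of a row-major tensor can be expressed by collapsing the axes before the reduced one into `outer` and those after it into `inner`. A host-side sketch, assuming element `(o, r, i)` sits at offset `(o * reduce + r) * inner + i`; `ReduceAxisSketch` is hypothetical.

```elixir
defmodule ReduceAxisSketch do
  # Assumes outer, reduce_len and inner are all >= 1.
  def reduce(flat, outer, reduce_len, inner, op) do
    data = List.to_tuple(flat)

    fun =
      case op do
        0 -> &Kernel.+/2
        1 -> &Kernel.max/2
        2 -> &Kernel.min/2
      end

    # One output per (outer, inner) pair, in row-major order.
    for o <- 0..(outer - 1), i <- 0..(inner - 1) do
      0..(reduce_len - 1)
      |> Enum.map(fn r -> elem(data, (o * reduce_len + r) * inner + i) end)
      |> Enum.reduce(fun)
    end
  end
end

# A {2, 3} tensor stored row-major.
flat = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ReduceAxisSketch.reduce(flat, 2, 3, 1, 0)
# => [6.0, 15.0]      (sum over axis 1)
ReduceAxisSketch.reduce(flat, 1, 2, 3, 1)
# => [4.0, 5.0, 6.0]  (max over axis 0)
```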

# `reduce_axis_f64`

f64 per-axis reduction; dispatches `reduce_axis_f64.spv`.

# `reduce_max`

Max of all elements.

# `reduce_min`

Min of all elements.

# `relu`

Elementwise `relu` of a GPU tensor.

# `select`

Branchless select: `cond ? t : f`. `cond` is a 0/1 tensor
(typically the output of `equal/2`, `less/2`, or `greater/2`).

Implemented compositionally as `cond * t + (1 - cond) * f`. Once
the v0.1 broadcast shader supports scalar broadcast we'll switch
to a 3-input shader; today this composition is the right shape
and adds two dispatches' worth of overhead per select.
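
The composition can be checked on the host with plain lists. `SelectSketch` is a hypothetical reference, not the library's code.

```elixir
defmodule SelectSketch do
  # cond_mask holds 0.0 / 1.0 values.
  def select(cond_mask, t, f) do
    [cond_mask, t, f]
    |> Enum.zip()
    |> Enum.map(fn {c, tv, fv} -> c * tv + (1.0 - c) * fv end)
  end
end

SelectSketch.select([1.0, 0.0], [10.0, 20.0], [30.0, 40.0])
# => [10.0, 40.0]
```

One property of this arithmetic form worth keeping in mind: under IEEE 754, `0 * Inf` and `0 * NaN` are NaN, so a non-finite value in the branch that is not selected can still contaminate the result.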

# `sigmoid`

Elementwise `sigmoid` of a GPU tensor.

# `sign`

Elementwise `sign` of a GPU tensor.

# `sqrt`

Elementwise `sqrt` of a GPU tensor.

# `square`

Elementwise `square` of a GPU tensor.

# `subtract`

Elementwise `subtract` of two GPU tensors of equal length.
Returns `{:ok, tensor}` or `{:error, reason}`.

# `sum`

Sum of all elements (returns a host-side f32).

# `tanh`

Elementwise `tanh` of a GPU tensor.

# `transpose_2d`

2D transpose: `c = a^T` where `a` is M×N and `c` is N×M, both
row-major f32. Returns `{:ok, c_tensor}`.

# `uniform`

Generate `n` uniform [0,1) f32 values, deterministic via `seed`.

# `upload_binary`

Upload a raw binary (already packed f32 little-endian) to GPU memory.
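
Packing a list into that format needs only Elixir's bitstring syntax. In this sketch `PackF32Sketch` is a hypothetical helper module; the result of `pack/1` is the kind of binary `upload_binary` expects.

```elixir
defmodule PackF32Sketch do
  # Pack floats as consecutive little-endian f32 values.
  def pack(values) do
    for x <- values, into: <<>>, do: <<x::float-32-little>>
  end

  # Inverse, for finite values only (the float pattern rejects NaN/Inf).
  def unpack(binary) do
    for <<x::float-32-little <- binary>>, do: x
  end
end

PackF32Sketch.pack([1.0])
# => <<0, 0, 128, 63>>
PackF32Sketch.unpack(PackF32Sketch.pack([1.0, 2.5]))
# => [1.0, 2.5]
```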

# `upload_f32`

Upload a list of f32 values to a freshly-allocated GPU buffer.
Returns `{:ok, tensor_ref}` where `tensor_ref` is an opaque
`ResourceArc` whose underlying VkBuf is freed when GC'd.

    iex> Nx.Vulkan.init()
    :ok
    iex> {:ok, t} = Nx.Vulkan.upload_f32([1.0, 2.0, 3.0, 4.0])
    iex> {:ok, [1.0, 2.0, 3.0, 4.0]} = Nx.Vulkan.download_f32(t, 4)

---

*Consult [api-reference.md](api-reference.md) for the complete listing*
