Nx.Vulkan.NativeV (nx_vulkan v0.1.0)


Rustler NIF for the pure-Rust (vulkano) compute backend.

Sibling of Nx.Vulkan.Native (the C++/spirit-backed NIF). Same chain-shader dispatch contract, but resource lifetimes are managed by Rust ownership (Arc<Buffer>) rather than opaque VkBuf* pointers — so the stale-handle bug class that surfaced as ArgumentError in Nx.Vulkan.Backend.to_binary (Mission II R4) is structurally absent.

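This module is the spike landing zone for the vulkano backend. leapfrog_chain_synth/6 takes bytes in and returns bytes out, with no persistent ResourceArc tensor handles. The buf_* functions return Rustler resource refs that own device buffers, and apply_binary/6, apply_unary/5, matmul/7, reduce_axis/7 and transpose_2d/5 dispatch shaders over device buffers. The surface is expected to expand further as the vulkano backend approaches feature parity with the C++ path.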

Compatibility

  • Loads the SPV blobs the existing pipeline emits. Output was verified byte-for-byte against Nx.Vulkan.Native.leapfrog_chain_synth on the regime-model R3 fixture; see nx_vulkan/spike/vulkano_synth/README.md.
  • Builds on Linux and FreeBSD 15.0 with vulkano 0.34.

Summary

Functions

apply_binary(out, a, b, n, op_code, spv_path): elementwise binary op; op_code selects the operation via a specialisation constant.

apply_unary(out, a, n, op_code, spv_path): elementwise unary op; op_code selects the operation.

buf_alloc(n_bytes): allocate a zero-initialised device buffer of n_bytes. Returns {:ok, ref}.

buf_byte_size(ref): buffer size in bytes.

buf_download(ref): read a device buffer back to a host binary. Returns {:ok, binary}.

buf_upload(data): allocate a device buffer and upload data to it. Returns {:ok, ref}.

buf_upload_into(ref, data): overwrite an existing device buffer with new host data.

leapfrog_chain_synth(q, p, extras, push, k, spv_path): dispatch a K-step leapfrog chain against the synthesised SPV.

matmul(out, a, b, m, n, k, spv_path): 2D f32 matmul, C = A · B, all row-major.

reduce_axis(out, a, outer, reduce_size, inner, op_code, spv_path): per-axis reduction (0=sum, 1=max, 2=min).

transpose_2d(out, a, m, n, spv_path): 2D transpose of an M×N row-major matrix into N×M.

Functions

apply_binary(out, a, b, n, op_code, spv_path)

Elementwise binary op. op_code selects which operation the shader executes via a specialisation constant:

0=add  1=mul  2=sub  3=div  4=pow  5=max  6=min

Buffers must all be the same byte size. Returns :ok or {:error, :size_mismatch} / {:error, :dispatch_failed, msg}.
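When checking a result fetched with buf_download/1, it can help to have a host-side reference for each op_code. The sketch below is not the NIF's code. binary_ref and apply_binary_ref are hypothetical names, and it assumes the shader applies plain IEEE-754 f32 arithmetic per element, which the docs do not confirm for edge cases such as NaN handling in max/min.

```rust
/// Hypothetical host reference for apply_binary's op codes.
/// ASSUMPTION: the shader uses plain IEEE-754 f32 semantics for each op.
pub fn binary_ref(op_code: u32, a: f32, b: f32) -> Option<f32> {
    match op_code {
        0 => Some(a + b),
        1 => Some(a * b),
        2 => Some(a - b),
        3 => Some(a / b),
        4 => Some(a.powf(b)),
        5 => Some(a.max(b)),
        6 => Some(a.min(b)),
        _ => None, // op_code outside the documented 0..=6 range
    }
}

/// Apply the reference op elementwise, mirroring the rule that all
/// buffers must be the same size (the NIF returns {:error, :size_mismatch}).
pub fn apply_binary_ref(op_code: u32, a: &[f32], b: &[f32]) -> Result<Vec<f32>, &'static str> {
    if a.len() != b.len() {
        return Err("size_mismatch");
    }
    a.iter()
        .zip(b)
        .map(|(&x, &y)| binary_ref(op_code, x, y).ok_or("unknown op_code"))
        .collect()
}
```

Comparing a downloaded binary against apply_binary_ref element by element (with a small tolerance for pow and div) is a cheap way to catch a wrong specialisation constant.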

apply_unary(out, a, n, op_code, spv_path)

Elementwise unary op. op_code selects:

0=exp  1=log  2=sqrt  3=abs  4=neg  5=sigmoid  6=tanh  7=relu
8=ceil  9=floor  10=sign  11=reciprocal  12=square

Buffers must be the same byte size.
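A matching host reference for the unary op codes is sketched below. unary_ref is a hypothetical helper, not part of this module. Two behaviours are assumptions: sigmoid is taken as 1 / (1 + exp(-x)), and sign follows the GLSL convention of returning 0 for 0 (Rust's f32::signum would return 1.0 for +0.0).

```rust
/// Hypothetical host reference for apply_unary's op codes.
/// ASSUMPTION: sigmoid = 1 / (1 + e^-x); sign(0) = 0 as in GLSL's sign().
pub fn unary_ref(op_code: u32, x: f32) -> Option<f32> {
    let y = match op_code {
        0 => x.exp(),
        1 => x.ln(),
        2 => x.sqrt(),
        3 => x.abs(),
        4 => -x,
        5 => 1.0 / (1.0 + (-x).exp()),
        6 => x.tanh(),
        7 => x.max(0.0), // relu
        8 => x.ceil(),
        9 => x.floor(),
        // sign: explicit three-way branch so that 0.0 maps to 0.0
        10 => if x > 0.0 { 1.0 } else if x < 0.0 { -1.0 } else { 0.0 },
        11 => 1.0 / x,
        12 => x * x,
        _ => return None, // op_code outside the documented 0..=12 range
    };
    Some(y)
}
```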

buf_alloc(n_bytes)

Allocate a zero-initialised device buffer of n_bytes. Returns {:ok, ref}.

buf_byte_size(ref)

Returns the buffer's size in bytes as an integer; never raises on a valid resource.

buf_download(ref)

Read a device buffer back to a host binary. Returns {:ok, binary}.

buf_upload(data)

Allocate a device buffer and upload data to it. Returns {:ok, ref}. The ref is a Rustler resource that owns the underlying Arc<Buffer>: when the BEAM garbage-collects it, vulkano's Drop runs and the GPU memory is freed.
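The ownership model described above can be illustrated without a GPU. In this sketch MockBuffer stands in for the vulkano buffer and BufRef for the Rustler resource. Both names, the FREED counter and the buf_upload helper are hypothetical. The point it shows is that a handle holding only an Arc cannot outlive its buffer, and the buffer's Drop runs exactly once, when the last handle goes away.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Counts how many mock buffers have been dropped ("freed").
pub static FREED: AtomicUsize = AtomicUsize::new(0);

/// Hypothetical stand-in for a device buffer; dropping it models
/// vulkano releasing the GPU memory.
pub struct MockBuffer {
    pub bytes: Vec<u8>,
}

impl Drop for MockBuffer {
    fn drop(&mut self) {
        FREED.fetch_add(1, Ordering::SeqCst);
    }
}

/// Hypothetical stand-in for the resource handle: it holds only an Arc,
/// so clones share one buffer and there is no raw pointer to go stale.
#[derive(Clone)]
pub struct BufRef(pub Arc<MockBuffer>);

/// Models buf_upload/1: allocate and copy the host data in.
pub fn buf_upload(data: &[u8]) -> BufRef {
    BufRef(Arc::new(MockBuffer { bytes: data.to_vec() }))
}
```

Dropping one clone leaves the buffer alive; dropping the last one runs MockBuffer's Drop. In the real NIF the BEAM's garbage collector plays the role of the final drop.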

buf_upload_into(ref, data)

Overwrite an existing device buffer with new host data. Returns :ok or {:error, :size_mismatch} when sizes disagree.
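The buffer functions' size contract can be modelled with plain host memory. HostBuf and BufError below are hypothetical names, not types exported by the NIF. The sketch also assumes that a rejected buf_upload_into/2 leaves the destination untouched, which the docs do not state explicitly.

```rust
#[derive(Debug, PartialEq)]
pub enum BufError {
    SizeMismatch,
}

/// Hypothetical host-memory model of a device buffer, used only to
/// exercise the documented contract without a GPU.
pub struct HostBuf {
    bytes: Vec<u8>,
}

impl HostBuf {
    /// Like buf_alloc/1: zero-initialised, n_bytes long.
    pub fn alloc(n_bytes: usize) -> Self {
        HostBuf { bytes: vec![0u8; n_bytes] }
    }

    /// Like buf_byte_size/1.
    pub fn byte_size(&self) -> usize {
        self.bytes.len()
    }

    /// Like buf_download/1: copy the contents back to the host.
    pub fn download(&self) -> Vec<u8> {
        self.bytes.clone()
    }

    /// Like buf_upload_into/2: overwrite in place, rejecting size mismatches.
    /// ASSUMPTION: on mismatch the buffer is left unchanged.
    pub fn upload_into(&mut self, data: &[u8]) -> Result<(), BufError> {
        if data.len() != self.bytes.len() {
            return Err(BufError::SizeMismatch);
        }
        self.bytes.copy_from_slice(data);
        Ok(())
    }
}
```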

leapfrog_chain_synth(q, p, extras, push, k, spv_path)

Dispatch a K-step leapfrog chain against the synthesised SPV.

All inputs are binaries:

  • q, p: initial position and momentum, d * 4 bytes each (little-endian f32)
  • extras: (n_obs + d) * 4 bytes, obs followed by inv_mass in the R2.2.3 packed layout
  • push: 20–128 bytes, the synth template's push block (K, n_obs, d, _pad, eps)
  • k: leapfrog steps per dispatch (must match push.K)
  • spv_path: filesystem path to the cached SPV blob

Returns {:ok, {q_chain_bin, p_chain_bin, grad_chain_bin, logp_chain_bin}} on success — same shape as the C++ NIF's return after download_binary_batch4/4.
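The input binaries can be built and sanity-checked on the host before dispatch, as sketched below. The helper names are hypothetical. The push layout is an assumption inferred from the 20-byte minimum and the field list (K, n_obs, d, _pad, eps): four little-endian u32 values followed by an f32 eps, 4 bytes each. The R2.2.3 packed layout may impose details this sketch does not capture.

```rust
/// Encode f32 values as little-endian bytes (the layout q, p and extras use).
pub fn f32s_to_le(xs: &[f32]) -> Vec<u8> {
    let mut out = Vec::with_capacity(xs.len() * 4);
    for x in xs {
        out.extend_from_slice(&x.to_le_bytes());
    }
    out
}

/// extras = obs followed by inv_mass, (n_obs + d) * 4 bytes in total.
pub fn pack_extras(obs: &[f32], inv_mass: &[f32]) -> Vec<u8> {
    let mut out = f32s_to_le(obs);
    out.extend(f32s_to_le(inv_mass));
    out
}

/// ASSUMPTION: push block = K, n_obs, d, _pad as LE u32, then eps as LE f32.
pub fn pack_push(k: u32, n_obs: u32, d: u32, eps: f32) -> Vec<u8> {
    let mut out = Vec::with_capacity(20);
    out.extend_from_slice(&k.to_le_bytes());
    out.extend_from_slice(&n_obs.to_le_bytes());
    out.extend_from_slice(&d.to_le_bytes());
    out.extend_from_slice(&0u32.to_le_bytes()); // _pad
    out.extend_from_slice(&eps.to_le_bytes());
    out
}

/// Check the documented size constraints and the k == push.K rule.
pub fn check_inputs(
    q: &[u8], p: &[u8], extras: &[u8], push: &[u8],
    k: u32, d: usize, n_obs: usize,
) -> Result<(), &'static str> {
    if q.len() != d * 4 || p.len() != d * 4 {
        return Err("q and p must be d * 4 bytes");
    }
    if extras.len() != (n_obs + d) * 4 {
        return Err("extras must be (n_obs + d) * 4 bytes");
    }
    if push.len() < 20 || push.len() > 128 {
        return Err("push must be 20..=128 bytes");
    }
    // ASSUMPTION: K is the first u32 of the push block.
    let push_k = u32::from_le_bytes([push[0], push[1], push[2], push[3]]);
    if push_k != k {
        return Err("k must match push.K");
    }
    Ok(())
}
```

Checking these constraints before the NIF call turns a malformed binary into a readable error rather than a failed dispatch.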

matmul(out, a, b, m, n, k, spv_path)

2D matmul. C = A · B where A is M×K row-major, B is K×N row-major, C is M×N row-major. All f32. Buffers: a (m*k*4 bytes), b (k*n*4 bytes), out (m*n*4 bytes).
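The row-major indexing above can be spelled out as a naive host reference, useful for validating a downloaded out buffer. matmul_ref is a hypothetical name. Its dimension arguments follow the NIF's (m, n, k) order, and it says nothing about how the shader tiles the work.

```rust
/// Hypothetical host reference: C (m×n) = A (m×k) · B (k×n), all row-major f32.
/// Element (i, j) of a row-major r×c matrix lives at index i * c + j.
pub fn matmul_ref(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * k, "a must hold m*k elements (m*k*4 bytes)");
    assert_eq!(b.len(), k * n, "b must hold k*n elements (k*n*4 bytes)");
    let mut c = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            // Dot product of row i of A with column j of B.
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            c[i * n + j] = acc;
        }
    }
    c
}
```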

reduce_axis(out, a, outer, reduce_size, inner, op_code, spv_path)

Per-axis reduction. op_code: 0=sum, 1=max, 2=min.

Input shape is interpreted as a virtual (outer, reduce_size, inner) tensor; output shape is (outer, inner) — i.e. the reduction collapses the middle axis. For full reductions use outer=1, reduce_size=n, inner=1.

Buffers: out has outer * inner elements; a has outer * reduce_size * inner elements.
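The virtual (outer, reduce_size, inner) view can be made concrete with a host reference: input element (o, r, i) sits at index (o * reduce_size + r) * inner + i, and the output element (o, i) at o * inner + i. reduce_axis_ref is a hypothetical name, and the sketch assumes max/min start from -inf/+inf, which the docs do not specify.

```rust
/// Hypothetical host reference for reduce_axis: collapses the middle axis
/// of a virtual (outer, reduce_size, inner) tensor. op_code: 0=sum, 1=max, 2=min.
pub fn reduce_axis_ref(
    a: &[f32], outer: usize, reduce_size: usize, inner: usize, op_code: u32,
) -> Option<Vec<f32>> {
    assert_eq!(a.len(), outer * reduce_size * inner);
    let mut out = Vec::with_capacity(outer * inner);
    for o in 0..outer {
        for i in 0..inner {
            // All values along the reduced axis for this (o, i) pair.
            let vals = (0..reduce_size).map(|r| a[(o * reduce_size + r) * inner + i]);
            let acc = match op_code {
                0 => vals.sum::<f32>(),
                1 => vals.fold(f32::NEG_INFINITY, f32::max),
                2 => vals.fold(f32::INFINITY, f32::min),
                _ => return None,
            };
            out.push(acc); // lands at index o * inner + i
        }
    }
    Some(out)
}
```

A full reduction is the special case outer=1, reduce_size=n, inner=1, which yields a single-element output.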

transpose_2d(out, a, m, n, spv_path)

2D transpose. Input A is M×N row-major; output is N×M row-major. Buffers: a (m*n*4 bytes), out (m*n*4 bytes).
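The index mapping behind transpose_2d can be written as a short host reference: element (i, j) of the M×N input, at i * n + j, moves to (j, i) of the N×M output, at j * m + i. This also shows why both buffers are m*n*4 bytes. transpose_2d_ref is a hypothetical helper, not the shader's implementation.

```rust
/// Hypothetical host reference: row-major M×N input to row-major N×M output.
pub fn transpose_2d_ref(a: &[f32], m: usize, n: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * n);
    // Same element count in and out, hence equal byte sizes.
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            out[j * m + i] = a[i * n + j];
        }
    }
    out
}
```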