Rustler NIF for the pure-Rust (vulkano) compute backend.
Sibling of Nx.Vulkan.Native (the C++/spirit-backed NIF). Same
chain-shader dispatch contract, but resource lifetimes are
managed by Rust ownership (Arc<Buffer>) rather than opaque
VkBuf* pointers — so the stale-handle bug class that
surfaced as ArgumentError in Nx.Vulkan.Backend.to_binary
(Mission II R4) is structurally absent.
This module is the spike landing zone — for now it only
exposes leapfrog_chain_synth/6, taking bytes in and bytes
out (no persistent ResourceArc tensor handles). When the
vulkano backend gets feature-parity with the C++ path, this
expands to cover apply_binary, reduce, etc.
Compatibility
- Loads any SPV the existing pipeline emits (verified
byte-for-byte equivalence against
Nx.Vulkan.Native.leapfrog_chain_synth on the regime-model R3 fixture; see nx_vulkan/spike/vulkano_synth/README.md).
- Builds on Linux and FreeBSD 15.0 with vulkano 0.34.
Functions
Elementwise binary op. op_code selects which operation the
shader executes via a specialisation constant:
0=add 1=mul 2=sub 3=div 4=pow 5=max 6=min
Buffers must all be the same byte size. Returns :ok or
{:error, :size_mismatch} / {:error, :dispatch_failed, msg}.
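To check a dispatch result on the host, the op_code table can be mirrored on the CPU. The following is a minimal sketch, not the shader's source. binary_ref, apply_binary_ref and f32s_to_bytes are hypothetical helpers, and the sketch assumes little-endian f32 buffers and plain IEEE-754 semantics. The shader's rounding and NaN handling for pow, max and min may differ.

```rust
/// Hypothetical CPU reference for the binary op table:
/// 0=add 1=mul 2=sub 3=div 4=pow 5=max 6=min.
/// Assumes plain IEEE-754 f32 semantics; the shader may differ on NaN/rounding.
fn binary_ref(op_code: u32, a: f32, b: f32) -> Option<f32> {
    match op_code {
        0 => Some(a + b),
        1 => Some(a * b),
        2 => Some(a - b),
        3 => Some(a / b),
        4 => Some(a.powf(b)),
        5 => Some(a.max(b)),
        6 => Some(a.min(b)),
        _ => None, // op_code outside the documented table
    }
}

/// Encodes f32 values as little-endian bytes (assumed buffer format).
fn f32s_to_bytes(v: &[f32]) -> Vec<u8> {
    v.iter().flat_map(|x| x.to_le_bytes()).collect()
}

/// Applies `binary_ref` element by element over two byte buffers,
/// enforcing the "same byte size" rule with a size_mismatch error.
fn apply_binary_ref(op_code: u32, a: &[u8], b: &[u8]) -> Result<Vec<u8>, &'static str> {
    if a.len() != b.len() {
        return Err("size_mismatch");
    }
    if a.len() % 4 != 0 {
        return Err("not a whole number of f32 values"); // hypothetical extra check
    }
    let mut out = Vec::with_capacity(a.len());
    for (ca, cb) in a.chunks_exact(4).zip(b.chunks_exact(4)) {
        let x = f32::from_le_bytes([ca[0], ca[1], ca[2], ca[3]]);
        let y = f32::from_le_bytes([cb[0], cb[1], cb[2], cb[3]]);
        let z = binary_ref(op_code, x, y).ok_or("unknown op_code")?;
        out.extend_from_slice(&z.to_le_bytes());
    }
    Ok(out)
}
```

Compare GPU output against this reference with a small tolerance rather than byte equality, since pow in particular may not be bit-identical.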
Elementwise unary op. op_code selects:
0=exp 1=log 2=sqrt 3=abs 4=neg 5=sigmoid 6=tanh 7=relu
8=ceil 9=floor 10=sign 11=reciprocal 12=square
Buffers must be the same byte size.
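The unary table can be mirrored the same way. The following is a minimal sketch, assuming the usual mathematical definitions: log is the natural log, sigmoid(x) = 1 / (1 + e^-x), relu(x) = max(x, 0), and sign(0) = 0. Note that Rust's f32::signum returns 1.0 for +0.0, so sign is written out by hand. unary_ref is a hypothetical helper, not part of this module.

```rust
/// Hypothetical CPU reference for the unary op table (0=exp ... 12=square).
/// Assumes standard f32 definitions; the shader may not match bit-for-bit.
fn unary_ref(op_code: u32, x: f32) -> Option<f32> {
    let y = match op_code {
        0 => x.exp(),
        1 => x.ln(), // natural log
        2 => x.sqrt(),
        3 => x.abs(),
        4 => -x,
        5 => 1.0 / (1.0 + (-x).exp()), // sigmoid
        6 => x.tanh(),
        7 => x.max(0.0), // relu
        8 => x.ceil(),
        9 => x.floor(),
        10 => {
            // sign: unlike f32::signum, zero (and NaN) map to themselves
            if x > 0.0 {
                1.0
            } else if x < 0.0 {
                -1.0
            } else {
                x
            }
        }
        11 => 1.0 / x, // reciprocal
        12 => x * x,   // square
        _ => return None,
    };
    Some(y)
}
```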
Allocate a zero-initialised device buffer of n_bytes. Returns {:ok, ref}.
Return the buffer size in bytes as an integer (never crashes on a valid resource).
Read a device buffer back to a host binary. Returns {:ok, binary}.
Allocate a device buffer and upload data to it. Returns
{:ok, ref}. The ref is a Rustler resource that owns the
underlying Arc<Buffer>; when the BEAM garbage-collects it and no other
clone of the Arc remains, vulkano's Drop runs and the GPU memory is freed.
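The ownership model can be illustrated without a GPU. In this sketch, DeviceBuffer, BufferRef and upload are hypothetical stand-ins, not vulkano's or Rustler's types. A counter records when the stand-in's Drop runs. The point is that memory is released only when the last Arc clone goes away, so no surviving handle can refer to freed memory.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Illustration only: counts how many stand-in buffers have been dropped.
static FREED: AtomicUsize = AtomicUsize::new(0);

/// Hypothetical stand-in for a vulkano device buffer; its Drop plays the
/// role of "the GPU memory is freed".
struct DeviceBuffer {
    bytes: Vec<u8>,
}

impl Drop for DeviceBuffer {
    fn drop(&mut self) {
        FREED.fetch_add(1, Ordering::SeqCst);
    }
}

/// Hypothetical stand-in for the Rustler resource behind `ref`:
/// it holds one strong reference to the buffer.
struct BufferRef {
    buf: Arc<DeviceBuffer>,
}

/// Hypothetical upload: allocate the stand-in buffer and wrap it in a ref.
fn upload(data: &[u8]) -> BufferRef {
    BufferRef {
        buf: Arc::new(DeviceBuffer { bytes: data.to_vec() }),
    }
}
```

Dropping a BufferRef while another clone of its Arc is still alive leaves the stand-in buffer intact. The counter increments only when the final clone is dropped.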
Overwrite an existing device buffer with new host data.
Returns :ok or {:error, :size_mismatch} when sizes disagree.
Dispatch a K-step leapfrog chain against the synthesised SPV.
All inputs are binaries:
- q_init, p_init: d * 4 bytes each (little-endian f32)
- extras: (n_obs + d) * 4 bytes; obs followed by inv_mass in the R2.2.3 packed layout
- push: 20–128 bytes, the synth template's push block (K, n_obs, d, _pad, eps)
- k: leapfrog steps per dispatch (must match push.K)
- spv_path: filesystem path to the cached SPV blob
Returns {:ok, {q_chain_bin, p_chain_bin, grad_chain_bin, logp_chain_bin}} on success, the same shape the C++ NIF
returns after download_binary_batch4/4.
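The input contract above can be checked on the host before dispatch. This is a sketch under explicit assumptions. It assumes the push block is K, n_obs, d and _pad as little-endian u32 followed by eps as a little-endian f32 (20 bytes, with _pad written as zero). The field types and widths are assumptions, not taken from the synth template. pack_push, pack_extras and check_inputs are hypothetical helpers.

```rust
/// Hypothetical: builds a 20-byte push block.
/// ASSUMPTION: K, n_obs, d, _pad are little-endian u32; eps is a little-endian f32.
fn pack_push(k: u32, n_obs: u32, d: u32, eps: f32) -> Vec<u8> {
    let mut push = Vec::with_capacity(20);
    for word in &[k, n_obs, d, 0u32] {
        // the fourth word is _pad, written as zero
        push.extend_from_slice(&word.to_le_bytes());
    }
    push.extend_from_slice(&eps.to_le_bytes());
    push
}

/// Hypothetical: packs `extras` as the n_obs observations followed by the
/// d inverse-mass entries, each a little-endian f32.
fn pack_extras(obs: &[f32], inv_mass: &[f32]) -> Vec<u8> {
    obs.iter().chain(inv_mass).flat_map(|x| x.to_le_bytes()).collect()
}

/// Hypothetical: checks the documented size contract before dispatching.
fn check_inputs(q: &[u8], p: &[u8], extras: &[u8], push: &[u8], k: u32) -> Result<(), &'static str> {
    if push.len() < 20 || push.len() > 128 {
        return Err("push must be 20..=128 bytes");
    }
    // Reads the i-th little-endian u32 word of the push block.
    let word = |i: usize| {
        u32::from_le_bytes([push[4 * i], push[4 * i + 1], push[4 * i + 2], push[4 * i + 3]])
    };
    let (push_k, n_obs, d) = (word(0), word(1) as usize, word(2) as usize);
    if push_k != k {
        return Err("k must match push.K");
    }
    if q.len() != d * 4 || p.len() != d * 4 {
        return Err("q_init/p_init must be d * 4 bytes");
    }
    if extras.len() != (n_obs + d) * 4 {
        return Err("extras must be (n_obs + d) * 4 bytes");
    }
    Ok(())
}
```

For example, with K=8, n_obs=3 and d=2, q_init and p_init are 8 bytes each and extras is 20 bytes.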
2D matmul. C = A · B where A is M×K row-major, B is K×N row-major, and C is M×N row-major. All f32.
Buffers: a (m*k*4 bytes), b (k*n*4 bytes), out (m*n*4 bytes).
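A naive CPU reference makes the row-major indexing explicit. Element (i, j) of C is the dot product of row i of A (offset i*k) with column j of B (stride n). matmul_ref is a hypothetical helper for validating results, not the shader's algorithm.

```rust
/// Hypothetical CPU reference: C = A · B with A (m×k), B (k×n), C (m×n),
/// all row-major f32.
fn matmul_ref(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    let mut c = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for p in 0..k {
                // A[i][p] * B[p][j] in row-major layout
                acc += a[i * k + p] * b[p * n + j];
            }
            c[i * n + j] = acc;
        }
    }
    c
}
```

The GPU may sum in a different order, so the last bits of each element can differ from this reference.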
Per-axis reduction. op_code: 0=sum, 1=max, 2=min.
Input shape is interpreted as a virtual (outer, reduce_size, inner)
tensor; output shape is (outer, inner) — i.e. the reduction
collapses the middle axis. For full reductions use
outer=1, reduce_size=n, inner=1.
Buffers: out has outer * inner elements; a has
outer * reduce_size * inner elements.
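The virtual (outer, reduce_size, inner) view places element (o, r, i) at flat index (o * reduce_size + r) * inner + i. Summing axis 1 of a (2, 3, 2) tensor therefore uses outer=2, reduce_size=3, inner=2. Summing axis 0 of an M×N matrix uses outer=1, reduce_size=M, inner=N. The following is a minimal CPU sketch of that indexing. reduce_ref is hypothetical, and it assumes -inf and +inf as the starting values for max and min.

```rust
/// Hypothetical CPU reference for the per-axis reduction
/// (op_code 0=sum, 1=max, 2=min). `a` is read as a virtual row-major
/// (outer, reduce_size, inner) tensor and the middle axis is collapsed,
/// so the result has outer * inner elements.
fn reduce_ref(
    a: &[f32],
    outer: usize,
    reduce_size: usize,
    inner: usize,
    op_code: u32,
) -> Option<Vec<f32>> {
    assert_eq!(a.len(), outer * reduce_size * inner);
    // Identity element for each reduction.
    let init = match op_code {
        0 => 0.0,
        1 => f32::NEG_INFINITY,
        2 => f32::INFINITY,
        _ => return None,
    };
    let mut out = vec![init; outer * inner];
    for o in 0..outer {
        for r in 0..reduce_size {
            for i in 0..inner {
                // Flat index of element (o, r, i) in the input buffer.
                let v = a[(o * reduce_size + r) * inner + i];
                let idx = o * inner + i;
                let cur = out[idx];
                out[idx] = match op_code {
                    0 => cur + v,
                    1 => cur.max(v),
                    _ => cur.min(v),
                };
            }
        }
    }
    Some(out)
}
```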
2D transpose. Input A is M×N row-major; output is N×M row-major.
Buffers: a (m*n*4 bytes), out (m*n*4 bytes).
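In row-major terms, the transpose writes A[i][j] (flat index i*n + j) to out[j][i] (flat index j*m + i). Both buffers hold m*n f32 values, which is why they have the same byte size. transpose_ref is a hypothetical CPU reference, not the shader.

```rust
/// Hypothetical CPU reference: A is m×n row-major, the result is n×m
/// row-major, so out[j][i] = a[i][j].
fn transpose_ref(a: &[f32], m: usize, n: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * n);
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            out[j * m + i] = a[i * n + j];
        }
    }
    out
}
```

Transposing twice, swapping m and n the second time, returns the original buffer, which makes a cheap round-trip check on the GPU path.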