Nx.Vulkan.Compiler (nx_vulkan v0.1.0)


Path A.2 v2 (partial): an Nx.Defn.Compiler that auto-detects fusable elementwise chains and dispatches a single Nx.Vulkan.fused_chain/3 call instead of N separate shader calls.

What it does today

At __jit__ time, the compiler calls fun.(vars) once to materialize the IR, then walks the result looking for a chain of supported elementwise ops whose only inputs are the two function arguments. If a chain matches, it skips Nx.Defn.Evaluator entirely and dispatches a single fused_chain call.
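The chain walk can be pictured with a minimal sketch. It assumes a hypothetical, simplified tuple IR (not the real Nx.Defn.Expr structure): `{:param, i}` for a function argument, `{op, [arg]}` for a unary op, and `{op, [arg, {:param, i}]}` for a binary op whose second operand is an argument. The module name `FusionSketch` and the supported op lists are illustrative only. Walking from the root and prepending each op yields the ops in execution order; anything else returns `:error`, the analogue of falling through.

```elixir
defmodule FusionSketch do
  # Hypothetical simplified IR (NOT Nx.Defn.Expr):
  #   {:param, i}                a function argument (i is 0 or 1)
  #   {op, [arg]}                unary elementwise op
  #   {op, [arg, {:param, i}]}   binary elementwise op whose 2nd operand is a param
  @unary [:exp, :log, :negate, :sqrt, :tanh]
  @binary [:add, :subtract, :multiply, :divide]

  # Returns {:ok, start_param, ops} with ops in execution order,
  # or :error if the expression is not a linear elementwise chain.
  def flatten(expr), do: flatten(expr, [])

  # Reached the start of the chain: one of the two arguments.
  defp flatten({:param, i}, acc) when i in [0, 1], do: {:ok, i, acc}

  # Unary op: record it and keep walking toward the leaf.
  defp flatten({op, [arg]}, acc) when op in @unary,
    do: flatten(arg, [op | acc])

  # Binary op: the other operand must be an argument, not a computed node.
  defp flatten({op, [arg, {:param, i}]}, acc) when op in @binary and i in [0, 1],
    do: flatten(arg, [{op, i} | acc])

  # Anything else (reductions, reshapes, unknown params) is not fusable.
  defp flatten(_other, _acc), do: :error
end
```

For example, `exp(x + y)` encoded as `{:exp, [{:add, [{:param, 0}, {:param, 1}]}]}` flattens to `{:ok, 0, [{:add, 1}, :exp]}`: start from `x`, add `y`, then apply `exp`.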

What it doesn't do yet

  • Multi-output: only single-output chains are fused. A defn that returns a tuple falls through to Nx.Defn.Evaluator.
  • Branched chains: only linear chains are fused. If any intermediate node is used more than once, the function falls through.
  • More than 2 vars: only 2-argument functions are fused, matching the shader's two-input layout. Wider arities fall through.
  • Chains longer than 8 ops: the shader supports at most 8 ops, so longer chains fall through.
  • Non-elementwise ops: a chain containing any reduce, reshape, or dot falls through.
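The fall-through rules above amount to a gate that runs before dispatch. The sketch below is a hypothetical illustration, not the compiler's code: it assumes the chain has already been lowered to a list of `{id, op, inputs}` nodes in execution order, where inputs are node ids or `{:param, i}`, and the output is a single id or a tuple of ids. The module name `FusionGate`, the reason atoms, and the elementwise op list are all made up for the example.

```elixir
defmodule FusionGate do
  # Hypothetical gate mirroring the documented fall-through rules.
  @max_ops 8
  @elementwise [:add, :subtract, :multiply, :divide, :exp, :log, :negate, :tanh]

  # ops: [{id, op, inputs}] in execution order; inputs are ids or {:param, i}.
  # output: a single node id, or a tuple of ids for multi-output functions.
  def decide(arity, output, ops) do
    cond do
      is_tuple(output) -> {:fallback, :multi_output}
      arity != 2 -> {:fallback, :arity}
      length(ops) > @max_ops -> {:fallback, :too_many_ops}
      Enum.any?(ops, fn {_id, op, _inputs} -> op not in @elementwise end) ->
        {:fallback, :non_elementwise}
      branched?(ops) -> {:fallback, :branched}
      true -> :fuse
    end
  end

  # Branched if any intermediate node is consumed more than once.
  defp branched?(ops) do
    ops
    |> Enum.flat_map(fn {_id, _op, inputs} -> inputs end)
    |> Enum.reject(&match?({:param, _}, &1))
    |> Enum.frequencies()
    |> Enum.any?(fn {_id, count} -> count > 1 end)
  end
end
```

Checking the cheap structural conditions first (output shape, arity, length) keeps the common fall-through cases from scanning the node list at all.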

All fall-through cases delegate to Nx.Defn.Evaluator, so results stay correct; the worst case is no fusion and the same speed as before.
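Fusion changes how many passes are made over the data, not what is computed per element, which is why fused and unfused paths can be interchanged safely. The CPU-only sketch below illustrates that idea with plain lists standing in for tensors; it is an analogy, not the Vulkan implementation. The op encoding (`:exp`, `:negate`, `{:add, i}`, `{:multiply, i}` with `i` selecting `x` or `y`) and the module name `FusedVsStepwise` are hypothetical.

```elixir
defmodule FusedVsStepwise do
  # Hypothetical op encoding (not the nx_vulkan format):
  #   :exp | :negate          unary op on the running value
  #   {:add | :multiply, i}   combine the running value with param i (0 = x, 1 = y)

  # Fused: a single pass over the elements, applying the whole chain per
  # element (analogous to one fused shader dispatch).
  def fused(start, ops, x, y) do
    Enum.zip(x, y)
    |> Enum.map(fn {xi, yi} ->
      params = {xi, yi}
      Enum.reduce(ops, elem(params, start), &apply_op(&1, &2, params))
    end)
  end

  # Stepwise: one full pass per op, materializing every intermediate list
  # (analogous to N separate shader dispatches).
  def stepwise(start, ops, x, y) do
    Enum.reduce(ops, elem({x, y}, start), fn op, acc ->
      Enum.zip([acc, x, y])
      |> Enum.map(fn {a, xi, yi} -> apply_op(op, a, {xi, yi}) end)
    end)
  end

  defp apply_op(:exp, a, _params), do: :math.exp(a)
  defp apply_op(:negate, a, _params), do: -a
  defp apply_op({:add, i}, a, params), do: a + elem(params, i)
  defp apply_op({:multiply, i}, a, params), do: a * elem(params, i)
end
```

With `x = [1.0, 2.0]`, `y = [3.0, 4.0]`, and ops `[{:add, 1}, {:multiply, 0}]` (that is, `(x + y) * x`), both functions return `[4.0, 12.0]`; the stepwise version simply builds one extra intermediate list along the way.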

Configuration

config :exmc, :compiler, :vulkan

With this setting, Exmc.JIT routes to Nx.Vulkan.jit/2, which uses this compiler when available and otherwise defaults to Nx.Defn.Evaluator.
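The routing decision can be sketched as a config lookup plus an availability check. This is a hypothetical simplification, not the actual Exmc.JIT or Nx.Vulkan.jit/2 code: it collapses both hops into one function, treats "available" as "the module can be loaded" (via `Code.ensure_loaded?/1`), and takes the check as an injectable argument so the fallback path can be exercised without a GPU. The module name `CompilerRouting` is made up.

```elixir
defmodule CompilerRouting do
  # Hypothetical sketch of config-driven compiler selection; NOT the real
  # Exmc.JIT / Nx.Vulkan.jit/2 logic. `loaded?` defaults to a module-load check.
  def pick_compiler(loaded? \\ &Code.ensure_loaded?/1) do
    case Application.get_env(:exmc, :compiler) do
      :vulkan ->
        # Use the fusing compiler only if it is actually available.
        if loaded?.(Nx.Vulkan.Compiler), do: Nx.Vulkan.Compiler, else: Nx.Defn.Evaluator

      _other ->
        # No (or another) compiler configured: use the plain evaluator.
        Nx.Defn.Evaluator
    end
  end
end
```

Keeping the evaluator as the default in every branch mirrors the documented guarantee: a missing or unsupported Vulkan path costs speed, never correctness.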