Path A.2 v2 (partial) — Nx.Defn.Compiler that auto-detects fusable
elementwise chains and dispatches Nx.Vulkan.fused_chain/3 instead
of N separate shader calls.
What it does today
At __jit__ time, calls fun.(vars) once to materialize the IR.
Walks the result looking for a chain of supported elementwise ops
whose only inputs are the two function arguments. If matched: skip
Evaluator entirely and dispatch a single fused_chain call.
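The walk described above can be illustrated with a minimal sketch in plain Elixir. It is not the compiler's actual code and does not use the real Nx.Defn.Expr structure. The toy IR, the op lists, the module name ChainSketch, and the chain shape (seeded at the first argument, with binary ops taking the second argument as the other operand) are all assumptions made for illustration. Only the 8-op limit comes from the list of current limitations.

```elixir
defmodule ChainSketch do
  # Hypothetical toy IR (NOT Nx.Defn.Expr):
  #   {:var, 0} / {:var, 1}   the two function arguments
  #   {op, [operand, ...]}    an operation node
  # The op lists below are illustrative; the real supported set may differ.
  @unary [:exp, :negate, :abs]
  @binary [:add, :subtract, :multiply]
  @max_ops 8

  # Returns {:ok, ops} with ops in application order, or :fallback.
  def match(expr) do
    with {:ok, ops} <- walk(expr, []),
         n when n >= 1 and n <= @max_ops <- length(ops) do
      {:ok, ops}
    else
      _ -> :fallback
    end
  end

  # Walk from the root (last op) inward; prepending each op as we go
  # leaves the accumulator in application order once we reach the seed.
  defp walk({:var, 0}, acc), do: {:ok, acc}

  defp walk({op, [inner]}, acc) when op in @unary,
    do: walk(inner, [op | acc])

  defp walk({op, [inner, {:var, 1}]}, acc) when op in @binary,
    do: walk(inner, [{op, :b} | acc])

  # Anything else (tuples, reductions, unknown ops) is not a fusable chain.
  defp walk(_other, _acc), do: :fallback
end

# negate(exp(add(a, b))) is a linear three-op chain.
chain = {:negate, [{:exp, [{:add, [{:var, 0}, {:var, 1}]}]}]}
IO.inspect(ChainSketch.match(chain))
# prints {:ok, [{:add, :b}, :exp, :negate]}

# A reduction anywhere in the chain is rejected.
IO.inspect(ChainSketch.match({:sum, [{:add, [{:var, 0}, {:var, 1}]}]}))
# prints :fallback
```

In this sketch a caller would dispatch one fused call on `{:ok, ops}` and hand the function to Nx.Defn.Evaluator on `:fallback`.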
What it doesn't do yet
- Multi-output: only single-output chains. A defn that returns
  a tuple falls through to Nx.Defn.Evaluator.
- Branched chains: only linear chains. A node that's used twice falls through.
- More than 2 vars: 2-arg functions only (matches the shader's two-input layout). Wider arities fall through.
- Chains > 8 ops: shader limit; longer chains fall through.
- Non-elementwise ops: any reduce/reshape/dot in the chain falls through.
All fall-through cases delegate to Nx.Defn.Evaluator so behavior
stays correct — the worst case is "no fusion, same speed as before."
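To make the cases above concrete, here is a small set of example functions, one per category. The module name FusionExamples is hypothetical, and whether add, exp, and negate are among the ops the fused shader supports is an assumption. The comments say which path each function would be expected to take under the rules listed above, not what was measured.

```elixir
# Assumes :nx is available, e.g. Mix.install([{:nx, "~> 0.7"}]) in a script.
defmodule FusionExamples do
  import Nx.Defn

  # Linear, single-output, two-argument elementwise chain: a fusion
  # candidate, assuming add/exp/negate are supported ops.
  defn fusable(a, b), do: a |> Nx.add(b) |> Nx.exp() |> Nx.negate()

  # Multi-output: returns a tuple, so it would fall through.
  defn multi_output(a, b), do: {Nx.add(a, b), Nx.multiply(a, b)}

  # Branched: the intermediate `s` is used twice, so it would fall through.
  defn branched(a, b) do
    s = Nx.add(a, b)
    Nx.multiply(s, s)
  end

  # More than two arguments: would fall through.
  defn three_args(a, b, c), do: a |> Nx.add(b) |> Nx.multiply(c)

  # Non-elementwise: ends in a reduction, so it would fall through.
  defn reduced(a, b), do: a |> Nx.add(b) |> Nx.sum()
end
```

All five functions return the same results whichever path they take; the fall-through cases simply run through Nx.Defn.Evaluator.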
Configuration
    config :exmc, :compiler, :vulkan
Exmc.JIT then routes to Nx.Vulkan.jit/2, which uses this compiler
when available (defaults to Nx.Defn.Evaluator if unsupported).