Path A.2 — compile-time fusion of elementwise op chains.
Given a 2-arg function whose body is a chain of Nx.* calls, rewrites
it at macro time into a single Nx.Vulkan.fused_chain/3 dispatch.
Replaces N shader dispatches with one.
Example
import Nx.Vulkan.Fuse
f = fuse(fn a, b -> Nx.exp(Nx.add(Nx.multiply(a, b), b)) end)
{:ok, c} = f.(a_ref, b_ref)
# one fused dispatch instead of three
Recognized pattern
The macro recognizes function bodies of the form:
Nx.<op_n>(Nx.<op_{n-1}>(... Nx.<op_1>(a, b) ...))
where:
- a and b are the two function arguments (b may be reused for every binary op).
- Each op_k is from the supported elementwise set:
  - Binary (combine register with b): :add, :subtract, :multiply, :divide, :pow, :max, :min
  - Unary (transform register): :exp, :log, :sqrt, :abs, :negate, :sigmoid, :tanh, :relu, :ceil, :floor, :sign, :reciprocal, :square
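The "register" wording above can be read as a simple fold: the register starts as a, each binary op combines it with b, and each unary op transforms it in place. The following is a minimal scalar sketch of that reading, not the shader or the library's implementation. The module RegisterSketch and its run/3 function are hypothetical names introduced for illustration, and only a subset of the ops is modelled.

```elixir
defmodule RegisterSketch do
  # Hypothetical scalar model of the chain semantics; not Nx.Vulkan code.
  # `ops` is listed innermost first, e.g. [:multiply, :add, :exp].
  def run(ops, a, b) do
    # The register starts as `a`; every step sees the same `b`.
    Enum.reduce(ops, a, fn op, reg -> step(op, reg, b) end)
  end

  # Binary ops: register is the first operand, `b` the second.
  defp step(:add, reg, b), do: reg + b
  defp step(:subtract, reg, b), do: reg - b
  defp step(:multiply, reg, b), do: reg * b
  defp step(:divide, reg, b), do: reg / b
  defp step(:max, reg, b), do: max(reg, b)
  defp step(:min, reg, b), do: min(reg, b)

  # Unary ops: transform the register and ignore `b`.
  defp step(:exp, reg, _b), do: :math.exp(reg)
  defp step(:negate, reg, _b), do: -reg
  defp step(:abs, reg, _b), do: abs(reg)
  defp step(:square, reg, _b), do: reg * reg
end

# exp((a * b) + b) with a = 2.0, b = 3.0, evaluated as one pass over the op list.
IO.inspect(RegisterSketch.run([:multiply, :add, :exp], 2.0, 3.0))
```

Under this reading, the whole chain needs only two inputs and one running value per element, which is presumably what lets it run as a single dispatch.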
Chains longer than 8 ops fall back to non-fused composition (one shader dispatch per op).
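To illustrate how such a pattern could be recognized at macro time, here is a small sketch that walks the quoted form of a two-argument fn and pulls out the op list. It is an assumption-laden illustration, not the actual Nx.Vulkan.Fuse source: the module ChainSketch, the function extract/1, and the :fallback return value are all hypothetical, and the sketch only inspects the AST without calling Nx or Nx.Vulkan.

```elixir
defmodule ChainSketch do
  # Hypothetical helper, not part of Nx.Vulkan.Fuse.
  @binary ~w(add subtract multiply divide pow max min)a
  @unary ~w(exp log sqrt abs negate sigmoid tanh relu ceil floor sign reciprocal square)a
  @max_ops 8

  # Takes the quoted form of `fn a, b -> body end`.
  # Returns {:ok, ops} with ops listed innermost first, or :fallback.
  def extract({:fn, _, [{:->, _, [[{a, _, ctx_a}, {b, _, ctx_b}], body]}]})
      when is_atom(a) and is_atom(b) and is_atom(ctx_a) and is_atom(ctx_b) do
    case walk(body, a, b, []) do
      {:ok, ops} when length(ops) <= @max_ops -> {:ok, ops}
      _ -> :fallback
    end
  end

  def extract(_other), do: :fallback

  # Innermost call: a binary op applied directly to the variables (a, b).
  defp walk({{:., _, [{:__aliases__, _, [:Nx]}, op]}, _, [{a, _, ca}, {b, _, cb}]}, a, b, acc)
       when op in @binary and is_atom(ca) and is_atom(cb),
       do: {:ok, [op | acc]}

  # Binary op: first operand is the rest of the chain, second must be `b`.
  defp walk({{:., _, [{:__aliases__, _, [:Nx]}, op]}, _, [inner, {b, _, cb}]}, a, b, acc)
       when op in @binary and is_atom(cb),
       do: walk(inner, a, b, [op | acc])

  # Unary op wrapping the rest of the chain.
  defp walk({{:., _, [{:__aliases__, _, [:Nx]}, op]}, _, [inner]}, a, b, acc)
       when op in @unary,
       do: walk(inner, a, b, [op | acc])

  # Anything else (third tensor, unsupported op, non-Nx call) is not a chain.
  defp walk(_other, _a, _b, _acc), do: :error
end

ast = quote do: fn a, b -> Nx.exp(Nx.add(Nx.multiply(a, b), b)) end
IO.inspect(ChainSketch.extract(ast))
# => {:ok, [:multiply, :add, :exp]}
```

A real macro would then emit either the fused call or the original body depending on the result. The sketch stops at extraction because the argument format of Nx.Vulkan.fused_chain/3 is not shown here.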
Limitations (v0.2 work)
- No autograd integration: fuse returns an Nx.Vulkan ref tuple, not a defn-traceable value.
- Binary ops only see b as the second operand. Nx.add(a, c), where c is a third tensor, doesn't fuse.
- No reshape/broadcast fusion. Only same-shape elementwise.
- Auto-detection inside defn blocks requires a real Nx.Defn.Compiler. That's the v0.2 follow-up that recognizes chains in any defn body without macro opt-in.
Summary
Functions
Macro entry point. See module docs.