Nx.Vulkan.Fuse (nx_vulkan v0.1.0)

Path A.2 — compile-time fusion of elementwise op chains.

Given a two-argument anonymous function whose body is a chain of Nx.* calls, the fuse macro rewrites it at compile time (during macro expansion) into a single Nx.Vulkan.fused_chain/3 dispatch. A chain of N ops then costs one shader dispatch instead of N.

Example

import Nx.Vulkan.Fuse

# a_ref and b_ref are same-shape Nx.Vulkan tensor refs (creation not shown)
f = fuse(fn a, b -> Nx.exp(Nx.add(Nx.multiply(a, b), b)) end)
{:ok, c} = f.(a_ref, b_ref)
# one fused dispatch instead of three

Recognized pattern

The macro recognizes function bodies of the form:

Nx.<op_n>(Nx.<op_{n-1}>(... Nx.<op_1>(a, b) ...))

where:

  • a and b are the two function arguments. The innermost call takes (a, b), and b is the second operand of every later binary op in the chain.
  • Each op_k is from the supported elementwise set:
    • Binary (combine register with b): :add, :subtract, :multiply, :divide, :pow, :max, :min
    • Unary (transform register): :exp, :log, :sqrt, :abs, :negate, :sigmoid, :tanh, :relu, :ceil, :floor, :sign, :reciprocal, :square
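The "register" wording above can be read as a running value that starts as a, is combined with b by each binary op, and is transformed in place by each unary op. The following is a minimal scalar sketch of that reading, not the library's implementation: the module name `ChainSketch` and its `run/3` function are hypothetical, it works on plain floats rather than tensors, and it covers only a subset of the supported ops.

```elixir
defmodule ChainSketch do
  # Hypothetical scalar reference model of an op chain; not part of Nx.Vulkan.

  # Binary steps combine the register with b.
  defp step(:add, r, b), do: r + b
  defp step(:subtract, r, b), do: r - b
  defp step(:multiply, r, b), do: r * b
  defp step(:divide, r, b), do: r / b
  defp step(:max, r, b), do: max(r, b)
  defp step(:min, r, b), do: min(r, b)

  # Unary steps transform the register and ignore b.
  defp step(:exp, r, _b), do: :math.exp(r)
  defp step(:sqrt, r, _b), do: :math.sqrt(r)
  defp step(:abs, r, _b), do: abs(r)
  defp step(:negate, r, _b), do: -r
  defp step(:square, r, _b), do: r * r

  # ops is listed innermost first (op_1 ... op_n); the register starts as a.
  def run(ops, a, b) do
    Enum.reduce(ops, a, fn op, register -> step(op, register, b) end)
  end
end

# Nx.add(Nx.multiply(a, b), b) with a = 2.0, b = 3.0
ChainSketch.run([:multiply, :add], 2.0, 3.0)
# => 9.0
```

Listing the ops innermost first mirrors the order in which a fused kernel would have to apply them.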

Chains longer than 8 ops fall back to non-fused composition (one shader dispatch per op).
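As a rough illustration of the cost model this implies, the sketch below counts shader dispatches for a chain of a given length. The module `DispatchCount` and its function are hypothetical; only the limit of 8 and the one-dispatch-per-op fallback come from the text above.

```elixir
defmodule DispatchCount do
  # Hypothetical helper; the limit of 8 is taken from the module docs.
  @max_fused_ops 8

  # A chain within the limit runs as one fused dispatch.
  def dispatches(ops) when length(ops) <= @max_fused_ops, do: 1

  # A longer chain falls back to one dispatch per op.
  def dispatches(ops), do: length(ops)
end

DispatchCount.dispatches([:multiply, :add, :exp])
# => 1

DispatchCount.dispatches([:add | List.duplicate(:exp, 8)])
# => 9
```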

Limitations (v0.2 work)

  • No autograd integration: the function produced by fuse returns an Nx.Vulkan ref tuple, not a defn-traceable value.
  • Binary ops only see b as the second operand. A call such as Nx.add(a, c), where c is a third tensor, doesn't fuse.
  • No reshape/broadcast fusion. Only same-shape elementwise ops are fused.
  • Auto-detection inside defn blocks requires a real Nx.Defn.Compiler implementation. That is the v0.2 follow-up, which would recognize chains in any defn body without macro opt-in.

Functions

fuse(orig)

(macro)

Macro entry point. See module docs.

Returns the original function unchanged if the body doesn't fit the recognized chain pattern — graceful fallback, never a compile error.
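To make the fallback concrete, here is one way a recognizer over the quoted function could decide between fusing and falling back. This is a sketch under stated assumptions, not the macro's actual source: the module `FuseSketch`, its `classify/1` function, and the `{:fused, ops}` / `:fallback` return values are all hypothetical. It only inspects the AST, so it runs without Nx installed.

```elixir
defmodule FuseSketch do
  # Hypothetical recognizer; the real macro's internals are not documented here.
  @binary [:add, :subtract, :multiply, :divide, :pow, :max, :min]
  @unary [:exp, :log, :sqrt, :abs, :negate, :sigmoid, :tanh, :relu,
          :ceil, :floor, :sign, :reciprocal, :square]
  @max_ops 8

  # Expects the quoted form of `fn a, b -> body end`.
  def classify({:fn, _, [{:->, _, [[{a, _, ca}, {b, _, cb}], body]}]})
      when is_atom(a) and is_atom(b) and is_atom(ca) and is_atom(cb) do
    case walk(body, a, b) do
      {:ok, ops} when length(ops) <= @max_ops -> {:fused, ops}
      _ -> :fallback
    end
  end

  def classify(_other), do: :fallback

  # Innermost call: a binary op applied directly to (a, b).
  defp walk({{:., _, [{:__aliases__, _, [:Nx]}, op]}, _, [{a, _, ca}, {b, _, cb}]}, a, b)
       when op in @binary and is_atom(ca) and is_atom(cb),
       do: {:ok, [op]}

  # Outer binary op: the second operand must be the variable b.
  defp walk({{:., _, [{:__aliases__, _, [:Nx]}, op]}, _, [inner, {b, _, cb}]}, a, b)
       when op in @binary and is_atom(cb),
       do: push(walk(inner, a, b), op)

  # Outer unary op wrapping the rest of the chain.
  defp walk({{:., _, [{:__aliases__, _, [:Nx]}, op]}, _, [inner]}, a, b)
       when op in @unary,
       do: push(walk(inner, a, b), op)

  # Anything else does not fit the pattern.
  defp walk(_other, _a, _b), do: :error

  defp push({:ok, ops}, op), do: {:ok, ops ++ [op]}
  defp push(:error, _op), do: :error
end

fits = quote do
  fn a, b -> Nx.exp(Nx.add(Nx.multiply(a, b), b)) end
end

FuseSketch.classify(fits)
# => {:fused, [:multiply, :add, :exp]}

third_tensor = quote do
  fn a, b -> Nx.add(Nx.multiply(a, b), c) end
end

FuseSketch.classify(third_tensor)
# => :fallback
```

Returning a plain `:fallback` value instead of raising matches the documented behavior: a body that doesn't fit is left alone rather than rejected at compile time.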