Extension point for lowering selected Nx.block/4 blocks to XLA custom calls
(stablehlo.custom_call in StableHLO MLIR).
Other blocks (for example gather-based take or FFT) are lowered inline in
EXLA.Defn and do not use this protocol.
When EXLA.Defn calls it
During compilation with compiler: EXLA, each Nx.block(block, inputs, outputs, fn ... end)
is processed by this protocol. EXLA invokes call/4 once per block.
If call/4 returns :skip, EXLA compiles the block's default callback
(the anonymous function body) instead of emitting a custom call.
Default lowerings are provided for Nx.Block.LinAlg.QR and Nx.Block.LinAlg.Eigh.
call/4 arguments
The callback signature is call(struct, out, args, client). Its first three
arguments correspond to those of Nx.block(block, inputs, outputs, fn ... end),
but with outputs placed before inputs (block, outputs, inputs, then client).
struct - the block passed as the first argument to Nx.block/4 (your own defstruct or an existing block such as %Nx.Block.LinAlg.QR{}).
out - the output template tuple passed to Nx.block/4 (expression metadata for shapes and types, not runtime tensors).
args - list of input templates, in the same order as inputs in Nx.block/4.
client - the active EXLA.Client (use e.g. client.platform to gate host-only lowerings).
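The checks a call/4 implementation typically makes on these arguments can be sketched in plain Elixir. The module below is a hypothetical helper, not part of EXLA. It assumes the templates expose type and shape fields the way Nx tensors do, and it uses plain maps as stand-ins so the sketch runs without Nx or EXLA loaded.

```elixir
defmodule MyApp.QrLoweringGate do
  @moduledoc """
  Hypothetical helper (not part of EXLA): decides whether a QR-like block
  is a candidate for a custom-call lowering, based only on template metadata.
  """

  # `out` is the {q, r} output template tuple, `args` the list of input
  # templates, and `client` anything with a :platform field (an EXLA.Client
  # in real use). Reusing the variable `type` in the pattern requires both
  # outputs and the single input to share one element type.
  def lowerable?(
        {%{type: type}, %{type: type}},
        [%{type: type, shape: shape}],
        %{platform: :host}
      ) do
    # Assumed policy for this sketch: rank-2 input, f32 or f64 only.
    tuple_size(shape) == 2 and supported_type?(type)
  end

  # Wrong arity, mismatched types, or a non-host platform.
  def lowerable?(_out, _args, _client), do: false

  defp supported_type?({:f, size}) when size in [32, 64], do: true
  defp supported_type?(_type), do: false
end
```

A call/4 clause could delegate to such a helper and return :skip whenever it yields false, which keeps the gating logic testable without compiling a defn.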
call/4 return value
:skip - this implementation does not apply (unsupported type, non-host platform, wrong arity, etc.). The default block implementation is used instead.
{:ok, %EXLA.CustomCall.Spec{}} - emit a StableHLO custom call; see EXLA.CustomCall.Spec for call_target_name, optional attributes ([{name, attr}] string pairs for the stablehlo.custom_call backend_config dictionary), and optional operand_element_types (operands are converted when these differ from the lowered inputs).
Dispatch
The protocol uses @fallback_to_any true. Built-in lowerings for known blocks
live in defimpl EXLA.CustomCall, for: Any. Your application or dependency can
add defimpl EXLA.CustomCall, for: YourStruct; that implementation is chosen
whenever the block is %YourStruct{}, instead of the Any fallback.
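The dispatch rule described above is ordinary Elixir protocol behavior and can be illustrated without EXLA. The following is a minimal sketch using a stand-in protocol; Demo.Lowering, Demo.KnownBlock, Demo.UserBlock, and the returned atoms are all hypothetical names chosen for illustration and do not mirror EXLA's actual internals.

```elixir
# Stand-in protocol with the same shape as the one documented here.
defprotocol Demo.Lowering do
  @fallback_to_any true
  def call(struct, out, args, client)
end

defmodule Demo.KnownBlock do
  defstruct []
end

defmodule Demo.UserBlock do
  defstruct []
end

# The Any implementation plays the role of the built-in lowerings:
# it recognizes the blocks it knows about and skips everything else.
defimpl Demo.Lowering, for: Any do
  def call(%Demo.KnownBlock{}, _out, _args, _client), do: {:ok, :builtin_lowering}
  def call(_other, _out, _args, _client), do: :skip
end

# A struct-specific implementation takes precedence over the Any fallback.
defimpl Demo.Lowering, for: Demo.UserBlock do
  def call(_block, _out, _args, _client), do: {:ok, :user_lowering}
end
```

With these definitions, calling Demo.Lowering.call/4 on a %Demo.UserBlock{} reaches the struct-specific implementation, while a %Demo.KnownBlock{} or any other value goes through the Any fallback.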
Native handlers
Emitting a custom call in MLIR is only half of the story: the target name
must be registered with XLA on the relevant platform (typically via a native
library loaded into the process). That registration is not configured
through config :exla, ...; you load or link the native code by the same
means you would for any other NIF-backed extension.
Example
defmodule MyApp.CustomQrBlock do
  defstruct []
end

defimpl EXLA.CustomCall, for: MyApp.CustomQrBlock do
  def call(_block, {%{type: {kind, size}}, _r_expr}, [_input], %{platform: :host})
      when kind != :c and kind in [:f, :bf] and size in [16, 32, 64] do
    {:ok, %EXLA.CustomCall.Spec{call_target_name: "my_custom_qr_target"}}
  end

  def call(_, _, _, _), do: :skip
end

Then use Nx.block(%MyApp.CustomQrBlock{}, ...) inside a defn compiled with
compiler: EXLA.
Summary
Functions
Returns :skip or {:ok, %EXLA.CustomCall.Spec{}}.
Types
@type t() :: term()
All the types that implement this protocol.