Nx.Defn.Compiler behaviour (Nx v0.4.0)
The specification and helper functions for custom defn compilers.
Summary
Callbacks
__compile__/4: Callback for compilation.
__jit__/5: Callback for compilation.
__stream__/7: Callback for streaming (on top of JIT compilation).
Callbacks
@callback __compile__( key :: term(), vars :: [Nx.Container.t()], fun :: ([Nx.Container.t()] -> Nx.Container.t()), opts :: keyword() ) :: ([[Nx.t()]] -> [Nx.Container.t()])
Callback for compilation.
It receives an opaque key used for caching, the function vars, the function fun which builds a defn expression, and the compiler options. It must call fun with the vars as arguments.
It returns a function that, as the spec above indicates, receives a list of argument lists and returns a list of results.
The callback uses double underscores so it can be defined at root modules without affecting the module's main API.
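The contract above can be sketched with a compiler that logs and then delegates everything to Nx.Defn.Evaluator, the evaluator that ships with Nx. This is a minimal sketch, not a recommended implementation: the module name LoggingCompiler is hypothetical, and the script assumes network access so that Mix.install can fetch Nx.

```elixir
# Assumption: run as a standalone script with network access to fetch Nx.
Mix.install([{:nx, "~> 0.4.0"}])

defmodule LoggingCompiler do
  # Hypothetical module name; every callback delegates to Nx.Defn.Evaluator.
  @behaviour Nx.Defn.Compiler

  @impl true
  def __compile__(key, vars, fun, opts) do
    IO.puts("compiling with #{length(vars)} vars")
    # The delegate is responsible for calling fun.(vars) to build the
    # expression and for returning the function described above.
    Nx.Defn.Evaluator.__compile__(key, vars, fun, opts)
  end

  @impl true
  def __jit__(key, vars, fun, args_list, opts) do
    Nx.Defn.Evaluator.__jit__(key, vars, fun, args_list, opts)
  end

  @impl true
  def __stream__(key, input, acc, vars, fun, args_list, opts) do
    Nx.Defn.Evaluator.__stream__(key, input, acc, vars, fun, args_list, opts)
  end
end

# Precompile Nx.add/2 for two scalar s64 inputs using the sketch compiler.
templates = [Nx.template({}, {:s, 64}), Nx.template({}, {:s, 64})]
add = Nx.Defn.compile(&Nx.add/2, templates, compiler: LoggingCompiler)

result = add.(Nx.tensor(1, type: {:s, 64}), Nx.tensor(2, type: {:s, 64}))
IO.inspect(result)
```

Delegating keeps the sketch short; a real compiler would instead call fun.(vars) itself and lower the resulting expression to its own backend.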
@callback __jit__( key :: term(), vars :: [Nx.Container.t()], fun :: ([Nx.Container.t()] -> Nx.Container.t()), args_list :: [[(-> Nx.t())]], opts :: keyword() ) :: [Nx.Container.t()]
Callback for compilation.
It receives an opaque key used for caching, the function vars, the function fun which builds a defn expression, a list of argument lists in args_list, and the compiler options.
It must call fun with the vars as arguments. Note the key does not include the vars in its cache. Therefore, if you want to cache the result of fun.(vars), you likely want to include the vars in the cache key. vars is a list of container expressions.
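One way to follow the caching advice above is to derive a signature from the vars and combine it with the key. The sketch below is an illustration only: CacheSketch and its fetch/3 function are hypothetical helpers, the process dictionary stands in for a real cache, and plain maps with type, shape, and names fields stand in for the tensor expressions a compiler would receive. It also assumes the vars are already flat.

```elixir
defmodule CacheSketch do
  # Hypothetical helper, not part of Nx.
  # Caches the result of fun.(vars) under the opaque key plus a signature
  # derived from the vars, so different shapes do not share an entry.
  def fetch(key, vars, fun) do
    signature = Enum.map(vars, &{&1.type, &1.shape, &1.names})
    cache_key = {__MODULE__, key, signature}

    case Process.get(cache_key) do
      nil ->
        built = fun.(vars)
        Process.put(cache_key, built)
        {:miss, built}

      built ->
        {:hit, built}
    end
  end
end

# Stand-ins for vars: maps carrying the same fields as a tensor template.
small = [%{type: {:f, 32}, shape: {2}, names: [nil]}]
large = [%{type: {:f, 32}, shape: {3}, names: [nil]}]

# Stand-in for the expression-building function.
build = fn vars -> {:expr_for, Enum.map(vars, & &1.shape)} end

{:miss, _} = CacheSketch.fetch(:my_fun, small, build)
{:hit, _} = CacheSketch.fetch(:my_fun, small, build)
# Same key, different vars: the signature forces a rebuild.
{:miss, _} = CacheSketch.fetch(:my_fun, large, build)
```

Keying on the signature rather than on the vars themselves is a design choice of this sketch; it avoids depending on whatever internal data the expressions carry.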
Once the expression is built and compiled, it must be invoked for each list of arguments in args_list. In a nutshell, vars are used to build the expression from fun, which is then invoked for each list of arguments in args_list. All lists in args_list are guaranteed to be flat lists of the same length, containing zero-arity functions that return tensors of the same type, shape, and name.
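The invocation step described above can be sketched in plain Elixir. InvokeSketch and run/2 are hypothetical names, the compiled function is a stand-in for whatever a compiler produces, and integers stand in for tensors so the snippet runs without Nx.

```elixir
defmodule InvokeSketch do
  # Hypothetical helper, not part of Nx.
  # Runs an already compiled function once per argument list. Each argument
  # is a zero-arity function that must be called to obtain the actual value.
  def run(compiled, args_list) do
    for args <- args_list do
      values = Enum.map(args, fn arg_fun -> arg_fun.() end)
      compiled.(values)
    end
  end
end

# Two argument lists of the same length, as the callback guarantees.
args_list = [
  [fn -> 1 end, fn -> 2 end],
  [fn -> 10 end, fn -> 20 end]
]

# Stand-in for the compiled expression: adds its two inputs.
compiled = fn [a, b] -> a + b end

results = InvokeSketch.run(compiled, args_list)
IO.inspect(results)
```

The result is one entry per argument list, which matches the list return type in the __jit__ spec.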
The callback uses double underscores so it can be defined at root modules without affecting the module's main API.
@callback __stream__( key :: term(), input, acc, vars :: [Nx.t()], fun :: ([Nx.t()] -> {output, acc}), args_list :: [[(-> Nx.t())]], opts :: keyword() ) :: [Nx.Stream.t()] when input: Nx.Container.t(), output: Nx.Container.t(), acc: Nx.Container.t()
Callback for streaming (on top of JIT compilation).
It receives the same arguments as __jit__/5 with the addition of the streaming input and accumulator templates. If the input and accumulator are containers, they are kept in their container shapes. As in __jit__/5, both vars and args_list are flat lists of tensors (without their container shape).
It must return a struct that implements the Nx.Stream protocol.
Functions
Returns the current compiler.
Returns nil if we are not inside defn.