Nx.Defn.Compiler behaviour (Nx v0.2.1)
The specification and helper functions for custom defn compilers.
Callbacks
@callback __jit__( key :: term(), vars :: [Nx.t()], fun :: ([Nx.t()] -> Nx.Container.t()), args_list :: [[Nx.t()]], opts :: keyword() ) :: [Nx.Container.t()]
Callback for JIT compilation.
It receives an opaque key used for caching, the function vars, the function fun which builds a defn expression, a list of argument lists in args_list, and the compiler options.
It must call fun
with the vars
as arguments. Note the key
does not include the vars
in its cache. Therefore, if you want
to cache the result of fun.(vars)
, you likely want to include
the vars in the cache key. vars
is a flat list of tensor
templates, so they can be added directly as part of the cache
key or, most often, in function of their type and shape.
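As a minimal sketch of the cache-key advice above, the helper below combines the opaque key with the type and shape of each template. The module CacheKeySketch and the function cache_key/2 are hypothetical names, not part of Nx, and the maps stand in for tensor templates on the assumption that only their type and shape fields are read.

```elixir
defmodule CacheKeySketch do
  # Hypothetical helper, not part of Nx: pairs the opaque `key` with the
  # type and shape of every template in `vars`.
  def cache_key(key, vars) do
    {key, Enum.map(vars, fn %{type: type, shape: shape} -> {type, shape} end)}
  end
end

# Plain maps used as stand-ins for tensor templates (an assumption for
# illustration); only the :type and :shape fields are read.
vars = [%{type: {:f, 32}, shape: {2, 3}}, %{type: {:s, 64}, shape: {}}]

CacheKeySketch.cache_key(:my_fun, vars)
# => {:my_fun, [{{:f, 32}, {2, 3}}, {{:s, 64}, {}}]}
```

Two calls with the same key but different types or shapes produce different cache keys, so a compiled result for one set of templates is not reused for another.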
Once the expression is built and compiled, it must be invoked for each list of arguments in args_list. In a nutshell, vars are used to build the expression from fun, which is then invoked for each list of arguments in args_list. All lists in args_list are guaranteed to be flat lists of the same length, containing tensors of the same type, shape, and name.
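The build-once, invoke-per-argument-list flow described above can be sketched in plain Elixir. This is only an illustration of the control flow, not an Nx compiler: JitFlowSketch, run/4, and the compile function are hypothetical, and atoms and integers stand in for tensor templates and tensors.

```elixir
defmodule JitFlowSketch do
  # Hypothetical stand-in for the body of a `__jit__/5` implementation.
  # `compile` is an assumed function that turns the built expression into
  # a function accepting one flat argument list.
  def run(vars, fun, args_list, compile) do
    # Build the expression once from the templates.
    expr = fun.(vars)
    # Compile the expression once.
    compiled = compile.(expr)
    # Invoke the compiled function once per argument list.
    Enum.map(args_list, compiled)
  end
end

# Toy expression builder and compiler, used only for illustration.
build = fn [a, b] -> {:add, a, b} end
compile = fn {:add, _, _} -> fn [x, y] -> x + y end end

JitFlowSketch.run([:a, :b], build, [[1, 2], [3, 4]], compile)
# => [3, 7]
```

The result has one entry per list in args_list, mirroring the list return type of the callback.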
The callback uses double underscores so it can be defined at root modules without affecting the module's main API.
@callback __stream__( key :: term(), input, acc, vars :: [Nx.t()], fun :: ([Nx.t()] -> {output, acc}), args_list :: [[Nx.t()]], opts :: keyword() ) :: [Nx.Stream.t()] when input: Nx.Container.t(), output: Nx.Container.t(), acc: Nx.Container.t()
Callback for streaming (on top of JIT compilation).
It receives the same arguments as __jit__/5 with the addition of the streaming input and accumulator templates. If the input and accumulator are containers, they are kept in their container shapes. As in __jit__/5, both vars and args_list are flat lists of tensors (without their container shape).
It must return a struct that implements the Nx.Stream protocol.
Functions
Returns the current compiler.
Returns nil if we are not inside defn.