Nx.Defn.Compiler behaviour (Nx v0.1.0)
The specification and helper functions for custom defn compilers.
Types
Specs
Callbacks
Specs
Callback for JIT compilation.
It receives an opaque key used for caching, the function vars, the function which builds an expression, and the compiler options.
It must call fun with the vars as a list of arguments.
Note that the key does not account for the vars. Therefore, if you want to cache the result of fun.(vars), you likely want to include the vars in the cache key. Since the vars are all tensors, this is often a matter of retrieving their types, shapes, and names.
The callback uses double underscores so it can be defined at root modules without affecting the module's main API.
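The caching advice above can be sketched as follows. This is a minimal illustration under stated assumptions, not how any real compiler does it: the module name MyCachingCompiler is hypothetical, the argument order (key, vars, fun, opts) is taken from the prose description of __jit__/4, the process dictionary stands in for a real cache, and each var is assumed to expose type, shape, and names fields.

```elixir
defmodule MyCachingCompiler do
  # Hypothetical module for illustration only. A real compiler would
  # declare `@behaviour Nx.Defn.Compiler` and also implement __stream__/6.

  # Argument order assumed from the description: key, vars, fun, opts.
  def __jit__(key, vars, fun, _opts) do
    # The opaque key says nothing about the vars, so extend it with
    # each var's type, shape, and names.
    cache_key = {__MODULE__, key, Enum.map(vars, &signature/1)}

    case Process.get(cache_key) do
      nil ->
        # fun must be called with the vars as a list of arguments.
        result = fun.(vars)
        # Assumption: the process dictionary is a stand-in cache.
        Process.put(cache_key, result)
        result

      cached ->
        cached
    end
  end

  # Assumption: each var is a map or struct with these three fields.
  defp signature(var), do: {var.type, var.shape, var.names}
end
```

With this sketch, two calls that share a key and whose vars have the same type, shape, and names invoke fun only once, while a call whose vars differ in shape gets its own cache entry.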
Specs
__stream__( key :: term(), stream, acc, vars :: [Nx.t()], ([Nx.t()] -> acc), opts :: keyword() ) :: Nx.Stream.t() when stream: expr(), acc: expr()
Callback for streaming (on top of JIT compilation).
It receives the same arguments as __jit__/4 with the addition of the streaming and accumulator templates. It must return a struct that implements the Nx.Stream protocol.
Functions
Returns the current compiler.
Returns nil if we are not inside defn.