Nx.Defn.Compiler behaviour (Nx v0.5.3)

The specification and helper functions for custom defn compilers.

Summary

Callbacks

Callback for compilation.

Callback for JIT compilation.

Receives a keyword list of compiler options and returns a list of compiler options, each to run on a separate partition/device.

Callback for streaming (on top of JIT compilation).

Functions

Returns the current compiler.

Returns whether we are inside defn at compilation time.

Callbacks

__compile__(key, vars, fun, opts)

@callback __compile__(
  key :: term(),
  vars :: vars,
  fun :: (vars -> Nx.Container.t()),
  opts :: keyword()
) :: ([[Nx.Tensor.t()]] -> [Nx.Container.t()])
when vars: [Nx.Container.t()]

Callback for compilation.

It receives an opaque key used for caching, the function vars, the function fun which builds a defn expression, and the compiler options. It must call fun with the vars as arguments.

It returns a function that receives a list of argument lists and returns a list of results, one per argument list.

The callback uses double underscores so it can be defined at root modules without affecting the module's main API.
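
A minimal sketch, assuming Nx v0.5.x, where Nx.Defn.Evaluator (the default compiler) implements this behaviour and Nx.Defn.compile/3 and Nx.template/2 are available. The hypothetical LoggingCompiler wraps the function returned by the evaluator's __compile__/4 and keeps the same contract (a list of argument lists in, a list of results out). It only logs; it does not compile anything itself.

```elixir
# Assumes network access to fetch Nx; pins the version this page documents.
Mix.install([{:nx, "~> 0.5.3"}])

defmodule LoggingCompiler do
  # Hypothetical compiler that delegates all real work to Nx.Defn.Evaluator.
  @behaviour Nx.Defn.Compiler

  @impl true
  def __compile__(key, vars, fun, opts) do
    IO.puts("building the defn expression")
    # The evaluator is responsible for calling fun.(vars) to build the expression.
    run = Nx.Defn.Evaluator.__compile__(key, vars, fun, opts)

    # Same contract as the wrapped function: argument lists in, results out.
    fn args_list ->
      IO.puts("running #{length(args_list)} argument list(s)")
      run.(args_list)
    end
  end

  @impl true
  def __jit__(key, vars, fun, args_list, opts),
    do: __compile__(key, vars, fun, opts).(args_list)

  @impl true
  def __stream__(key, input, acc, vars, fun, args_list, opts),
    do: Nx.Defn.Evaluator.__stream__(key, input, acc, vars, fun, args_list, opts)

  @impl true
  def __partitions_options__(opts), do: Nx.Defn.Evaluator.__partitions_options__(opts)
end

# Precompile Nx.add/2 for two s64 vectors of length 2, then call it.
templates = [Nx.template({2}, {:s, 64}), Nx.template({2}, {:s, 64})]
add = Nx.Defn.compile(&Nx.add/2, templates, compiler: LoggingCompiler)
result = add.(Nx.tensor([1, 2]), Nx.tensor([3, 4]))
```

Routing __jit__/5 through __compile__/4 keeps the logging in one place; the remaining callbacks simply delegate.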

__jit__(key, vars, fun, args_list, opts)

@callback __jit__(
  key :: term(),
  vars,
  fun :: (vars -> Nx.Container.t()),
  args_list :: [[(-> Nx.Tensor.t())]],
  opts :: keyword()
) :: [Nx.Container.t()]
when vars: [Nx.Container.t()]

Callback for JIT compilation.

It receives an opaque key used for caching, the function vars, the function fun which builds a defn expression, a list of argument lists in args_list, and the compiler options.

It must call fun with the vars as arguments. Note that the key does not take the vars into account. Therefore, if you want to cache the result of fun.(vars), you likely want to include the vars in the cache key. vars is a list of container expressions.

Once the expression is built and compiled, it must be invoked for each list of arguments in args_list. In a nutshell, vars are used to build the expression from fun, which is then invoked once per argument list in args_list. All lists in args_list are guaranteed to be flat lists of the same length, containing zero-arity functions that return tensors of the same type, shape, and name.

The callback uses double underscores so it can be defined at root modules without affecting the module's main API.
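
The sketch below, assuming Nx v0.5.x, illustrates the shape of args_list. The hypothetical ShapeLoggingCompiler calls each zero-arity function exactly once, logs the resulting tensor's shape, and re-wraps the tensor in a new zero-arity function before delegating to Nx.Defn.Evaluator, so the delegate still receives arguments in the form this callback expects.

```elixir
# Assumes network access to fetch Nx; pins the version this page documents.
Mix.install([{:nx, "~> 0.5.3"}])

defmodule ShapeLoggingCompiler do
  # Hypothetical compiler: logs argument shapes, then delegates to the evaluator.
  @behaviour Nx.Defn.Compiler

  @impl true
  def __jit__(key, vars, fun, args_list, opts) do
    # Force each zero-arity function once, then re-wrap the tensor so the
    # delegate still receives zero-arity functions.
    forced =
      Enum.map(args_list, fn args ->
        Enum.map(args, fn thunk ->
          tensor = thunk.()
          IO.inspect(Nx.shape(tensor), label: "argument shape")
          fn -> tensor end
        end)
      end)

    Nx.Defn.Evaluator.__jit__(key, vars, fun, forced, opts)
  end

  @impl true
  def __compile__(key, vars, fun, opts),
    do: Nx.Defn.Evaluator.__compile__(key, vars, fun, opts)

  @impl true
  def __stream__(key, input, acc, vars, fun, args_list, opts),
    do: Nx.Defn.Evaluator.__stream__(key, input, acc, vars, fun, args_list, opts)

  @impl true
  def __partitions_options__(opts), do: Nx.Defn.Evaluator.__partitions_options__(opts)
end

# The compiler is selected through the :compiler option of Nx.Defn.jit/2.
add = Nx.Defn.jit(&Nx.add/2, compiler: ShapeLoggingCompiler)
result = add.(Nx.tensor([1, 2]), Nx.tensor([3, 4]))
```

Forcing each function only once matters because, depending on the backend, producing the tensor may involve work such as a device transfer.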

__partitions_options__(keyword)

@callback __partitions_options__(keyword()) :: [keyword()]

Receives a keyword list of compiler options and returns a list of compiler options, each to run on a separate partition/device.
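
As an illustration, a compiler could fan one set of options out into one keyword list per device. The :num_partitions and :device_id options below are hypothetical, not options defined by Nx, and the sketch uses only the Elixir standard library.

```elixir
defmodule PartitionedCompiler do
  # Hypothetical sketch: a real compiler module would also declare
  # `@behaviour Nx.Defn.Compiler` and implement the other callbacks.
  # :num_partitions and :device_id are made-up option names.
  def __partitions_options__(opts) do
    {num_partitions, opts} = Keyword.pop(opts, :num_partitions, 1)

    # One keyword list per partition, each pinned to a different device id.
    # The explicit //1 step makes the range empty when num_partitions is 0.
    for device_id <- 0..(num_partitions - 1)//1 do
      Keyword.put(opts, :device_id, device_id)
    end
  end
end

PartitionedCompiler.__partitions_options__(num_partitions: 2, verbose: true)
```

Each keyword list in the result would then be used for a separate partition, so two partitions yield two option lists that differ only in :device_id.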

__stream__(key, input, acc, vars, fun, args_list, opts)

@callback __stream__(
  key :: term(),
  input,
  acc,
  vars,
  fun :: (vars -> {output, acc}),
  args_list :: [[(-> Nx.t())]],
  opts :: keyword()
) :: [Nx.Stream.t()]
when input: Nx.Container.t(),
     output: Nx.Container.t(),
     acc: Nx.Container.t(),
     vars: [Nx.Container.t()]

Callback for streaming (on top of JIT compilation).

It receives the same arguments as __jit__/5 with the addition of the streaming input and accumulator templates. If the input and accumulator are containers, they are kept in their container shapes. As in __jit__/5, both vars and args_list are flat lists of tensors (without their container shape).

It must return a list of structs that implement the Nx.Stream protocol, one for each argument list in args_list.

Functions

Returns the current compiler.

Returns nil if we are not inside defn.

Returns whether we are inside defn at compilation time.

This is meant to be invoked inside a macro that has defn-specific logic.