Nx.Defn.Compiler behaviour (Nx v0.1.0)

The specification and helper functions for custom defn compilers.

Summary

Callbacks

__jit__(key, vars, function, opts): Callback for JIT compilation.

__stream__(key, stream, acc, vars, function, opts): Callback for streaming (on top of JIT compilation).

Functions

Returns the current compiler.

Types

Specs

expr() :: Nx.t() | tuple() | %{optional(term()) => expr()}

Callbacks

__jit__(key, vars, function, opts)

Specs

__jit__(
  key :: term(),
  vars :: [Nx.t()],
  ([Nx.t()] -> expr()),
  opts :: keyword()
) :: expr()

Callback for JIT compilation.

It receives an opaque key used for caching, the function's vars, the function that builds the expression, and the compiler options.

It must call fun with the vars as a list of arguments. Note that the key does not take the vars into account. Therefore, if you want to cache the result of fun.(vars), you likely want to include the vars in the cache key. Since the vars are all tensors, this is usually a matter of retrieving each tensor's type, shape, and names.

The callback uses double underscores so it can be defined in root modules without affecting the module's main API.
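The caching advice above can be sketched as a small custom compiler. This is a minimal illustration, not a reference implementation: the module name MyCompiler is hypothetical, :persistent_term is just one possible cache, and the sketch assumes that Nx v0.1.0's Nx.Defn.Evaluator exposes __jit__/4 and __stream__/6 for delegation and that Nx.Defn.jit/3 accepts a :compiler option. The cache key combines the opaque key with each var's type, shape, and names, so fun.(vars) is only called on a cache miss.

```elixir
# Assumes Elixir 1.12+ (for Mix.install) and Nx v0.1.0.
Mix.install([{:nx, "~> 0.1.0"}])

defmodule MyCompiler do
  # Hypothetical compiler: caches the expression built by fun.(vars)
  # and delegates evaluation to Nx.Defn.Evaluator (an assumption of this sketch).
  @behaviour Nx.Defn.Compiler

  @impl true
  def __jit__(key, vars, fun, opts) do
    # The opaque key ignores the vars, so add each var's type, shape,
    # and names to build the full cache key.
    cache_key = {__MODULE__, key, Enum.map(vars, &var_signature/1)}

    caching_fun = fn params ->
      case :persistent_term.get(cache_key, nil) do
        nil ->
          # Cache miss: build the expression once and store it.
          expr = fun.(params)
          :persistent_term.put(cache_key, expr)
          expr

        expr ->
          # Cache hit: reuse the previously built expression.
          expr
      end
    end

    Nx.Defn.Evaluator.__jit__(key, vars, caching_fun, opts)
  end

  @impl true
  def __stream__(key, stream, acc, vars, fun, opts) do
    # Streaming is left entirely to the default evaluator in this sketch.
    Nx.Defn.Evaluator.__stream__(key, stream, acc, vars, fun, opts)
  end

  defp var_signature(tensor) do
    {Nx.type(tensor), Nx.shape(tensor), Nx.names(tensor)}
  end
end
```

Under the same assumptions, it would be invoked as Nx.Defn.jit(&Nx.add/2, [Nx.tensor(1), Nx.tensor(2)], compiler: MyCompiler). :persistent_term is chosen here only for brevity. Updating it is expensive, so a real compiler would more likely use ETS or a dedicated cache process.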

__stream__(key, stream, acc, vars, function, opts)

Specs

__stream__(
  key :: term(),
  stream,
  acc,
  vars :: [Nx.t()],
  ([Nx.t()] -> acc),
  opts :: keyword()
) :: Nx.Stream.t()
when stream: expr(), acc: expr()

Callback for streaming (on top of JIT compilation).

It receives the same arguments as __jit__/4, plus the stream and accumulator templates (the stream and acc arguments). It must return a struct that implements the Nx.Stream protocol.

Functions

Returns the current compiler.

Returns nil if we are not inside defn.