Nx.Backend behaviour (Nx v0.11.0)

Copy Markdown View Source

The behaviour for tensor backends.

Each backend is a module that defines a struct and implements the callbacks defined in this module. The callbacks are mostly implementations of the functions in the Nx module, with the tensor output shape given as the first argument.

Nx backends come in two flavors: opaque backends, whose data you should not access directly except through the functions in the Nx module, and public ones, whose data can be directly accessed and traversed. Opaque backends typically have the Backend suffix.

Nx ships with the following backends:

  • Nx.BinaryBackend - an opaque backend written in pure Elixir that stores the data in Elixir's binaries. This is the default backend used by the Nx module. The backend itself (and its data) is private and must not be accessed directly.

  • Nx.TemplateBackend - an opaque backend that works as a template in APIs to declare the type, shape, and names of tensors to be expected in the future.

  • Nx.Defn.Expr - a public backend used by defn to build expression trees that are traversed by custom compilers.

This module also includes functions that are meant to be shared across backends.

Summary

Callbacks

Invoked for execution of optional callbacks with a default implementation.

Functions

Inspects the tensor whose data is given as a binary.

Types

axes()

@type axes() :: Nx.Tensor.axes()

axis()

@type axis() :: Nx.Tensor.axis()

backend_options()

@type backend_options() :: term()

shape()

@type shape() :: Nx.Tensor.shape()

t()

@type t() :: %{__struct__: atom()}

tensor()

@type tensor() :: Nx.Tensor.t()

Callbacks

abs(out, tensor)

@callback abs(out :: tensor(), tensor()) :: tensor()

acos(out, tensor)

@callback acos(out :: tensor(), tensor()) :: tensor()

acosh(out, tensor)

@callback acosh(out :: tensor(), tensor()) :: tensor()

add(out, tensor, tensor)

@callback add(out :: tensor(), tensor(), tensor()) :: tensor()

all(out, tensor, keyword)

@callback all(out :: tensor(), tensor(), keyword()) :: tensor()

all_close(out, tensor, tensor, keyword)

(optional)
@callback all_close(out :: tensor(), tensor(), tensor(), keyword()) :: tensor()

any(out, tensor, keyword)

@callback any(out :: tensor(), tensor(), keyword()) :: tensor()

argmax(out, tensor, keyword)

@callback argmax(out :: tensor(), tensor(), keyword()) :: tensor()

argmin(out, tensor, keyword)

@callback argmin(out :: tensor(), tensor(), keyword()) :: tensor()

argsort(out, tensor, keyword)

@callback argsort(out :: tensor(), tensor(), keyword()) :: tensor()

as_type(out, tensor)

@callback as_type(out :: tensor(), tensor()) :: tensor()

asin(out, tensor)

@callback asin(out :: tensor(), tensor()) :: tensor()

asinh(out, tensor)

@callback asinh(out :: tensor(), tensor()) :: tensor()

atan2(out, tensor, tensor)

@callback atan2(out :: tensor(), tensor(), tensor()) :: tensor()

atan(out, tensor)

@callback atan(out :: tensor(), tensor()) :: tensor()

atanh(out, tensor)

@callback atanh(out :: tensor(), tensor()) :: tensor()

backend_copy(tensor, module, backend_options)

@callback backend_copy(tensor(), module(), backend_options()) :: tensor()

backend_deallocate(tensor)

@callback backend_deallocate(tensor()) :: :ok | :already_deallocated

backend_transfer(tensor, module, backend_options)

@callback backend_transfer(tensor(), module(), backend_options()) :: tensor()

bitcast(out, tensor)

@callback bitcast(out :: tensor(), tensor()) :: tensor()

bitwise_and(out, tensor, tensor)

@callback bitwise_and(out :: tensor(), tensor(), tensor()) :: tensor()

bitwise_not(out, tensor)

@callback bitwise_not(out :: tensor(), tensor()) :: tensor()

bitwise_or(out, tensor, tensor)

@callback bitwise_or(out :: tensor(), tensor(), tensor()) :: tensor()

bitwise_xor(out, tensor, tensor)

@callback bitwise_xor(out :: tensor(), tensor(), tensor()) :: tensor()

broadcast(out, tensor, shape, axes)

@callback broadcast(out :: tensor(), tensor(), shape(), axes()) :: tensor()

cbrt(out, tensor)

@callback cbrt(out :: tensor(), tensor()) :: tensor()

ceil(out, tensor)

@callback ceil(out :: tensor(), tensor()) :: tensor()

cholesky(out, tensor)

(optional)
@callback cholesky(out :: tensor(), tensor()) :: tensor()

clip(out, tensor, min, max)

@callback clip(out :: tensor(), tensor(), min :: tensor(), max :: tensor()) :: tensor()

concatenate(out, tensor, axis)

@callback concatenate(out :: tensor(), tensor(), axis()) :: tensor()

conjugate(out, tensor)

@callback conjugate(out :: tensor(), tensor()) :: tensor()

constant(out, arg2, backend_options)

@callback constant(out :: tensor(), number() | Complex.t(), backend_options()) :: tensor()

conv(out, tensor, kernel, keyword)

@callback conv(out :: tensor(), tensor(), kernel :: tensor(), keyword()) :: tensor()

cos(out, tensor)

@callback cos(out :: tensor(), tensor()) :: tensor()

cosh(out, tensor)

@callback cosh(out :: tensor(), tensor()) :: tensor()

count_leading_zeros(out, tensor)

@callback count_leading_zeros(out :: tensor(), tensor()) :: tensor()

cumulative_max(out, t, keyword)

(optional)
@callback cumulative_max(out :: tensor(), t :: tensor(), keyword()) :: tensor()

cumulative_min(out, t, keyword)

(optional)
@callback cumulative_min(out :: tensor(), t :: tensor(), keyword()) :: tensor()

cumulative_product(out, t, keyword)

(optional)
@callback cumulative_product(out :: tensor(), t :: tensor(), keyword()) :: tensor()

cumulative_sum(out, t, keyword)

(optional)
@callback cumulative_sum(out :: tensor(), t :: tensor(), keyword()) :: tensor()

determinant(out, t)

(optional)
@callback determinant(out :: tensor(), t :: tensor()) :: tensor()

divide(out, tensor, tensor)

@callback divide(out :: tensor(), tensor(), tensor()) :: tensor()

dot(out, tensor, axes, axes, tensor, axes, axes)

@callback dot(out :: tensor(), tensor(), axes(), axes(), tensor(), axes(), axes()) ::
  tensor()

eigh({}, tensor, keyword)

(optional)
@callback eigh({eigenvals :: tensor(), eigenvecs :: tensor()}, tensor(), keyword()) ::
  tensor()

equal(out, tensor, tensor)

@callback equal(out :: tensor(), tensor(), tensor()) :: tensor()

erf(out, tensor)

@callback erf(out :: tensor(), tensor()) :: tensor()

erf_inv(out, tensor)

@callback erf_inv(out :: tensor(), tensor()) :: tensor()

erfc(out, tensor)

@callback erfc(out :: tensor(), tensor()) :: tensor()

exp(out, tensor)

@callback exp(out :: tensor(), tensor()) :: tensor()

expm1(out, tensor)

@callback expm1(out :: tensor(), tensor()) :: tensor()

eye(tensor, backend_options)

@callback eye(tensor(), backend_options()) :: tensor()

fft2(out, tensor, keyword)

(optional)
@callback fft2(out :: tensor(), tensor(), keyword()) :: tensor()

fft(out, tensor, keyword)

@callback fft(out :: tensor(), tensor(), keyword()) :: tensor()

floor(out, tensor)

@callback floor(out :: tensor(), tensor()) :: tensor()

from_binary(out, binary, backend_options)

@callback from_binary(out :: tensor(), binary(), backend_options()) :: tensor()

from_pointer(opaque_pointer, type, shape, backend_opts, opts)

@callback from_pointer(
  opaque_pointer :: term(),
  type :: tuple(),
  shape :: tuple(),
  backend_opts :: keyword(),
  opts :: keyword()
) :: tensor() | no_return()

gather(out, input, indices, keyword)

@callback gather(out :: tensor(), input :: tensor(), indices :: tensor(), keyword()) ::
  tensor()

greater(out, tensor, tensor)

@callback greater(out :: tensor(), tensor(), tensor()) :: tensor()

greater_equal(out, tensor, tensor)

@callback greater_equal(out :: tensor(), tensor(), tensor()) :: tensor()

ifft2(out, tensor, keyword)

(optional)
@callback ifft2(out :: tensor(), tensor(), keyword()) :: tensor()

ifft(out, tensor, keyword)

@callback ifft(out :: tensor(), tensor(), keyword()) :: tensor()

imag(out, tensor)

@callback imag(out :: tensor(), tensor()) :: tensor()

indexed_add(out, tensor, indices, updates, keyword)

@callback indexed_add(
  out :: tensor(),
  tensor(),
  indices :: tensor(),
  updates :: tensor(),
  keyword()
) ::
  tensor()

indexed_put(out, tensor, indices, updates, keyword)

@callback indexed_put(
  out :: tensor(),
  tensor(),
  indices :: tensor(),
  updates :: tensor(),
  keyword()
) ::
  tensor()

init(keyword)

@callback init(keyword()) :: backend_options()

inspect(tensor, t)

@callback inspect(tensor(), Inspect.Opts.t()) :: tensor()

iota(tensor, arg2, backend_options)

@callback iota(tensor(), axis() | nil, backend_options()) :: tensor()

is_infinity(out, tensor)

@callback is_infinity(out :: tensor(), tensor()) :: tensor()

is_nan(out, tensor)

@callback is_nan(out :: tensor(), tensor()) :: tensor()

left_shift(out, tensor, tensor)

@callback left_shift(out :: tensor(), tensor(), tensor()) :: tensor()

less(out, tensor, tensor)

@callback less(out :: tensor(), tensor(), tensor()) :: tensor()

less_equal(out, tensor, tensor)

@callback less_equal(out :: tensor(), tensor(), tensor()) :: tensor()

log1p(out, tensor)

@callback log1p(out :: tensor(), tensor()) :: tensor()

log(out, tensor)

@callback log(out :: tensor(), tensor()) :: tensor()

logical_and(out, tensor, tensor)

@callback logical_and(out :: tensor(), tensor(), tensor()) :: tensor()

logical_not(out, t)

(optional)
@callback logical_not(out :: tensor(), t :: tensor()) :: tensor()

logical_or(out, tensor, tensor)

@callback logical_or(out :: tensor(), tensor(), tensor()) :: tensor()

logical_xor(out, tensor, tensor)

@callback logical_xor(out :: tensor(), tensor(), tensor()) :: tensor()

lu({}, tensor, keyword)

(optional)
@callback lu({p :: tensor(), l :: tensor(), u :: tensor()}, tensor(), keyword()) ::
  tensor()

max(out, tensor, tensor)

@callback max(out :: tensor(), tensor(), tensor()) :: tensor()

min(out, tensor, tensor)

@callback min(out :: tensor(), tensor(), tensor()) :: tensor()

multiply(out, tensor, tensor)

@callback multiply(out :: tensor(), tensor(), tensor()) :: tensor()

negate(out, tensor)

@callback negate(out :: tensor(), tensor()) :: tensor()

not_equal(out, tensor, tensor)

@callback not_equal(out :: tensor(), tensor(), tensor()) :: tensor()

optional(atom, list, fun)

(optional)
@callback optional(atom(), [term()], fun()) :: tensor()

Invoked for execution of optional callbacks with a default implementation.

First we attempt to call the optional callback itself (one of the many optional callbacks defined in this module), then we attempt to call this callback (which is also optional), and finally we fall back to the default implementation.

pad(out, tensor, pad_value, padding_config)

@callback pad(out :: tensor(), tensor(), pad_value :: tensor(), padding_config :: list()) ::
  tensor()

phase(out, t)

(optional)
@callback phase(out :: tensor(), t :: tensor()) :: tensor()

population_count(out, tensor)

@callback population_count(out :: tensor(), tensor()) :: tensor()

pow(out, tensor, tensor)

@callback pow(out :: tensor(), tensor(), tensor()) :: tensor()

product(out, tensor, keyword)

@callback product(out :: tensor(), tensor(), keyword()) :: tensor()

put_slice(out, tensor, tensor, list)

@callback put_slice(out :: tensor(), tensor(), tensor(), list()) :: tensor()

qr({}, tensor, keyword)

(optional)
@callback qr({q :: tensor(), r :: tensor()}, tensor(), keyword()) :: tensor()

quotient(out, tensor, tensor)

@callback quotient(out :: tensor(), tensor(), tensor()) :: tensor()

real(out, tensor)

@callback real(out :: tensor(), tensor()) :: tensor()

reduce(out, tensor, acc, keyword, fun)

@callback reduce(out :: tensor(), tensor(), acc :: tensor(), keyword(), fun()) :: tensor()

reduce_max(out, tensor, keyword)

@callback reduce_max(out :: tensor(), tensor(), keyword()) :: tensor()

reduce_min(out, tensor, keyword)

@callback reduce_min(out :: tensor(), tensor(), keyword()) :: tensor()

remainder(out, tensor, tensor)

@callback remainder(out :: tensor(), tensor(), tensor()) :: tensor()

reshape(out, tensor)

@callback reshape(out :: tensor(), tensor()) :: tensor()

reverse(out, tensor, axes)

@callback reverse(out :: tensor(), tensor(), axes()) :: tensor()

right_shift(out, tensor, tensor)

@callback right_shift(out :: tensor(), tensor(), tensor()) :: tensor()

round(out, tensor)

@callback round(out :: tensor(), tensor()) :: tensor()

rsqrt(out, tensor)

@callback rsqrt(out :: tensor(), tensor()) :: tensor()

select(out, tensor, tensor, tensor)

@callback select(out :: tensor(), tensor(), tensor(), tensor()) :: tensor()

sigmoid(out, tensor)

@callback sigmoid(out :: tensor(), tensor()) :: tensor()

sign(out, tensor)

@callback sign(out :: tensor(), tensor()) :: tensor()

sin(out, tensor)

@callback sin(out :: tensor(), tensor()) :: tensor()

sinh(out, tensor)

@callback sinh(out :: tensor(), tensor()) :: tensor()

slice(out, tensor, list, list, list)

@callback slice(out :: tensor(), tensor(), list(), list(), list()) :: tensor()

solve(out, a, b)

(optional)
@callback solve(out :: tensor(), a :: tensor(), b :: tensor()) :: tensor()

sort(out, tensor, keyword)

@callback sort(out :: tensor(), tensor(), keyword()) :: tensor()

sqrt(out, tensor)

@callback sqrt(out :: tensor(), tensor()) :: tensor()

squeeze(out, tensor, axes)

@callback squeeze(out :: tensor(), tensor(), axes()) :: tensor()

stack(out, tensor, axis)

@callback stack(out :: tensor(), tensor(), axis()) :: tensor()

subtract(out, tensor, tensor)

@callback subtract(out :: tensor(), tensor(), tensor()) :: tensor()

sum(out, tensor, keyword)

@callback sum(out :: tensor(), tensor(), keyword()) :: tensor()

svd({}, tensor, keyword)

(optional)
@callback svd({u :: tensor(), s :: tensor(), v :: tensor()}, tensor(), keyword()) ::
  tensor()

take(out, input, indices, keyword)

(optional)
@callback take(out :: tensor(), input :: tensor(), indices :: tensor(), keyword()) ::
  tensor()

take_along_axis(out, input, indices, keyword)

(optional)
@callback take_along_axis(
  out :: tensor(),
  input :: tensor(),
  indices :: tensor(),
  keyword()
) :: tensor()

tan(out, tensor)

@callback tan(out :: tensor(), tensor()) :: tensor()

tanh(out, tensor)

@callback tanh(out :: tensor(), tensor()) :: tensor()

to_batched(out, tensor, keyword)

@callback to_batched(out :: tensor(), tensor(), keyword()) :: [tensor()]

to_binary(tensor, limit)

@callback to_binary(tensor(), limit :: non_neg_integer()) :: binary()

to_pointer(tensor, opts)

@callback to_pointer(tensor(), opts :: keyword()) :: term() | no_return()

top_k(out, tensor, keyword)

(optional)
@callback top_k(out :: tensor(), tensor(), keyword()) :: tensor()

transpose(out, tensor, axes)

@callback transpose(out :: tensor(), tensor(), axes()) :: tensor()

triangular_solve(out, a, b, keyword)

@callback triangular_solve(out :: tensor(), a :: tensor(), b :: tensor(), keyword()) ::
  tensor()

window_max(out, tensor, shape, keyword)

@callback window_max(out :: tensor(), tensor(), shape(), keyword()) :: tensor()

window_min(out, tensor, shape, keyword)

@callback window_min(out :: tensor(), tensor(), shape(), keyword()) :: tensor()

window_product(out, tensor, shape, keyword)

@callback window_product(out :: tensor(), tensor(), shape(), keyword()) :: tensor()

window_reduce(out, tensor, acc, shape, keyword, fun)

@callback window_reduce(
  out :: tensor(),
  tensor(),
  acc :: tensor(),
  shape(),
  keyword(),
  fun()
) :: tensor()

window_scatter_max(out, tensor, tensor, tensor, shape, keyword)

@callback window_scatter_max(
  out :: tensor(),
  tensor(),
  tensor(),
  tensor(),
  shape(),
  keyword()
) :: tensor()

window_scatter_min(out, tensor, tensor, tensor, shape, keyword)

@callback window_scatter_min(
  out :: tensor(),
  tensor(),
  tensor(),
  tensor(),
  shape(),
  keyword()
) :: tensor()

window_sum(out, tensor, shape, keyword)

@callback window_sum(out :: tensor(), tensor(), shape(), keyword()) :: tensor()

Functions

inspect(map, binary, inspect_opts)

Inspects the tensor whose data is given as a binary.

Note that the binary may have fewer elements than the tensor size, but in such cases it must have strictly more elements than inspect_opts.limit.