Nx.Backend behaviour (Nx v0.9.0)

The behaviour for tensor backends.

Each backend is a module that defines a struct and implements the callbacks defined in this module. Most callbacks implement a function from the Nx module and receive the output tensor, which describes the expected shape of the result, as their first argument.
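To make this calling convention concrete, here is a minimal sketch in plain Elixir that does not use Nx at all. The `ToyTensor` struct and `ToyBackend` module are hypothetical stand-ins for Nx.Tensor and a real backend: the caller builds `out` with the expected shape before the call, and the callback only computes the data.

```elixir
defmodule ToyTensor do
  # Hypothetical stand-in for Nx.Tensor: a shape tuple and a flat list of values.
  defstruct shape: {}, data: nil
end

defmodule ToyBackend do
  # Shaped like an Nx.Backend callback such as add(out, left, right): `out`
  # already carries the expected shape, so the backend only fills in the data.
  # The repeated `shape` variable makes this head match only equal shapes.
  def add(
        %ToyTensor{shape: shape} = out,
        %ToyTensor{shape: shape, data: left},
        %ToyTensor{shape: shape, data: right}
      ) do
    %ToyTensor{out | data: Enum.zip_with(left, right, &+/2)}
  end
end

out = %ToyTensor{shape: {3}}
left = %ToyTensor{shape: {3}, data: [1, 2, 3]}
right = %ToyTensor{shape: {3}, data: [10, 20, 30]}

result = ToyBackend.add(out, left, right)
IO.inspect(result.data)
# prints [11, 22, 33]
```

For simplicity the sketch rejects mismatched shapes instead of broadcasting; the point is only that the shape of the result is decided before the backend is called.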

Nx backends come in two flavors: opaque backends, whose data must not be accessed directly except through the functions in the Nx module, and public backends, whose data can be accessed and traversed directly. Opaque backends typically have the Backend suffix.

Nx ships with the following backends:

  • Nx.BinaryBackend - an opaque backend written in pure Elixir that stores the data in Elixir's binaries. This is the default backend used by the Nx module. The backend itself (and its data) is private and must not be accessed directly.

  • Nx.TemplateBackend - an opaque backend that works as a template in APIs to declare the type, shape, and names of tensors to be expected in the future.

  • Nx.Defn.Expr - a public backend used by defn to build expression trees that are traversed by custom compilers.
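Nx.BinaryBackend keeps tensor data in Elixir binaries, and callbacks such as from_binary/3 and to_binary/2 move data into and out of a backend. As a rough illustration of binary-backed storage in general (not of BinaryBackend's private internals), the hypothetical `BinaryStorageSketch` module below packs 32-bit floats into a binary, reads them back, and truncates to a maximum number of elements, similar in spirit to the `limit` argument of to_binary/2. The native-endian f32 layout is an assumption made for the example.

```elixir
defmodule BinaryStorageSketch do
  # Hypothetical helpers; every element is assumed to be a native-endian f32.
  @bytes_per_element 4

  # Packs a list of numbers into a binary, 4 bytes per element.
  def pack(floats) when is_list(floats) do
    for x <- floats, into: <<>>, do: <<x::float-32-native>>
  end

  # Decodes the binary back into a list of floats.
  def unpack(binary) when is_binary(binary) do
    for <<x::float-32-native <- binary>>, do: x
  end

  # Keeps at most `limit` elements by taking a prefix of the binary.
  def take(binary, limit) when is_integer(limit) and limit >= 0 do
    binary_part(binary, 0, min(byte_size(binary), limit * @bytes_per_element))
  end
end

bin = BinaryStorageSketch.pack([1.0, 2.5, -3.0])
IO.inspect(byte_size(bin))
# prints 12
IO.inspect(BinaryStorageSketch.unpack(BinaryStorageSketch.take(bin, 2)))
# prints [1.0, 2.5]
```

The values are chosen to be exactly representable as 32-bit floats, so they survive the round trip unchanged.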

This module also includes functions that are meant to be shared across backends.

Types

@type axes() :: Nx.Tensor.axes()
@type axis() :: Nx.Tensor.axis()
@type backend_options() :: term()
@type shape() :: Nx.Tensor.shape()
@type t() :: %{__struct__: atom()}
@type tensor() :: Nx.Tensor.t()

Callbacks

@callback abs(out :: tensor(), tensor()) :: tensor()
@callback acos(out :: tensor(), tensor()) :: tensor()
@callback acosh(out :: tensor(), tensor()) :: tensor()

add(out, tensor, tensor)

@callback add(out :: tensor(), tensor(), tensor()) :: tensor()

all(out, tensor, keyword)

@callback all(out :: tensor(), tensor(), keyword()) :: tensor()

all_close(out, tensor, tensor, keyword) (optional)

@callback all_close(out :: tensor(), tensor(), tensor(), keyword()) :: tensor()

any(out, tensor, keyword)

@callback any(out :: tensor(), tensor(), keyword()) :: tensor()

argmax(out, tensor, keyword)

@callback argmax(out :: tensor(), tensor(), keyword()) :: tensor()

argmin(out, tensor, keyword)

@callback argmin(out :: tensor(), tensor(), keyword()) :: tensor()

argsort(out, tensor, keyword)

@callback argsort(out :: tensor(), tensor(), keyword()) :: tensor()

@callback as_type(out :: tensor(), tensor()) :: tensor()
@callback asin(out :: tensor(), tensor()) :: tensor()
@callback asinh(out :: tensor(), tensor()) :: tensor()

atan2(out, tensor, tensor)

@callback atan2(out :: tensor(), tensor(), tensor()) :: tensor()

@callback atan(out :: tensor(), tensor()) :: tensor()
@callback atanh(out :: tensor(), tensor()) :: tensor()

backend_copy(tensor, module, backend_options)

@callback backend_copy(tensor(), module(), backend_options()) :: tensor()

backend_deallocate(tensor)

@callback backend_deallocate(tensor()) :: :ok | :already_deallocated

backend_transfer(tensor, module, backend_options)

@callback backend_transfer(tensor(), module(), backend_options()) :: tensor()

@callback bitcast(out :: tensor(), tensor()) :: tensor()

bitwise_and(out, tensor, tensor)

@callback bitwise_and(out :: tensor(), tensor(), tensor()) :: tensor()

bitwise_not(out, tensor)

@callback bitwise_not(out :: tensor(), tensor()) :: tensor()

bitwise_or(out, tensor, tensor)

@callback bitwise_or(out :: tensor(), tensor(), tensor()) :: tensor()

bitwise_xor(out, tensor, tensor)

@callback bitwise_xor(out :: tensor(), tensor(), tensor()) :: tensor()

broadcast(out, tensor, shape, axes)

@callback broadcast(out :: tensor(), tensor(), shape(), axes()) :: tensor()

@callback cbrt(out :: tensor(), tensor()) :: tensor()
@callback ceil(out :: tensor(), tensor()) :: tensor()

cholesky(out, tensor) (optional)

@callback cholesky(out :: tensor(), tensor()) :: tensor()

clip(out, tensor, min, max)

@callback clip(out :: tensor(), tensor(), min :: tensor(), max :: tensor()) :: tensor()

concatenate(out, tensor, axis)

@callback concatenate(out :: tensor(), tensor(), axis()) :: tensor()

@callback conjugate(out :: tensor(), tensor()) :: tensor()

constant(out, arg2, backend_options)

@callback constant(out :: tensor(), number() | Complex.t(), backend_options()) :: tensor()

conv(out, tensor, kernel, keyword)

@callback conv(out :: tensor(), tensor(), kernel :: tensor(), keyword()) :: tensor()

@callback cos(out :: tensor(), tensor()) :: tensor()
@callback cosh(out :: tensor(), tensor()) :: tensor()

count_leading_zeros(out, tensor)

@callback count_leading_zeros(out :: tensor(), tensor()) :: tensor()

cumulative_max(out, t, keyword) (optional)

@callback cumulative_max(out :: tensor(), t :: tensor(), keyword()) :: tensor()

cumulative_min(out, t, keyword) (optional)

@callback cumulative_min(out :: tensor(), t :: tensor(), keyword()) :: tensor()

cumulative_product(out, t, keyword) (optional)

@callback cumulative_product(out :: tensor(), t :: tensor(), keyword()) :: tensor()

cumulative_sum(out, t, keyword) (optional)

@callback cumulative_sum(out :: tensor(), t :: tensor(), keyword()) :: tensor()

determinant(out, t) (optional)

@callback determinant(out :: tensor(), t :: tensor()) :: tensor()

divide(out, tensor, tensor)

@callback divide(out :: tensor(), tensor(), tensor()) :: tensor()

dot(out, tensor, axes, axes, tensor, axes, axes)

@callback dot(out :: tensor(), tensor(), axes(), axes(), tensor(), axes(), axes()) ::
  tensor()

eigh({eigenvals, eigenvecs}, tensor, keyword) (optional)

@callback eigh({eigenvals :: tensor(), eigenvecs :: tensor()}, tensor(), keyword()) ::
  tensor()

equal(out, tensor, tensor)

@callback equal(out :: tensor(), tensor(), tensor()) :: tensor()

@callback erf(out :: tensor(), tensor()) :: tensor()
@callback erf_inv(out :: tensor(), tensor()) :: tensor()
@callback erfc(out :: tensor(), tensor()) :: tensor()
@callback exp(out :: tensor(), tensor()) :: tensor()
@callback expm1(out :: tensor(), tensor()) :: tensor()

eye(tensor, backend_options)

@callback eye(tensor(), backend_options()) :: tensor()

fft2(out, tensor, keyword) (optional)

@callback fft2(out :: tensor(), tensor(), keyword()) :: tensor()

fft(out, tensor, keyword)

@callback fft(out :: tensor(), tensor(), keyword()) :: tensor()

@callback floor(out :: tensor(), tensor()) :: tensor()

from_binary(out, binary, backend_options)

@callback from_binary(out :: tensor(), binary(), backend_options()) :: tensor()

from_pointer(opaque_pointer, type, shape, backend_opts, opts)

@callback from_pointer(
  opaque_pointer :: term(),
  type :: tuple(),
  shape :: tuple(),
  backend_opts :: keyword(),
  opts :: keyword()
) :: {:ok, tensor()} | {:error, term()}

gather(out, input, indices, keyword)

@callback gather(out :: tensor(), input :: tensor(), indices :: tensor(), keyword()) ::
  tensor()

greater(out, tensor, tensor)

@callback greater(out :: tensor(), tensor(), tensor()) :: tensor()

greater_equal(out, tensor, tensor)

@callback greater_equal(out :: tensor(), tensor(), tensor()) :: tensor()

ifft2(out, tensor, keyword) (optional)

@callback ifft2(out :: tensor(), tensor(), keyword()) :: tensor()

ifft(out, tensor, keyword)

@callback ifft(out :: tensor(), tensor(), keyword()) :: tensor()

@callback imag(out :: tensor(), tensor()) :: tensor()

indexed_add(out, tensor, indices, updates, keyword)

@callback indexed_add(
  out :: tensor(),
  tensor(),
  indices :: tensor(),
  updates :: tensor(),
  keyword()
) ::
  tensor()

indexed_put(out, tensor, indices, updates, keyword)

@callback indexed_put(
  out :: tensor(),
  tensor(),
  indices :: tensor(),
  updates :: tensor(),
  keyword()
) ::
  tensor()

@callback init(keyword()) :: backend_options()
@callback inspect(tensor(), Inspect.Opts.t()) :: tensor()

iota(tensor, arg2, backend_options)

@callback iota(tensor(), axis() | nil, backend_options()) :: tensor()

is_infinity(out, tensor)

@callback is_infinity(out :: tensor(), tensor()) :: tensor()

@callback is_nan(out :: tensor(), tensor()) :: tensor()

left_shift(out, tensor, tensor)

@callback left_shift(out :: tensor(), tensor(), tensor()) :: tensor()

less(out, tensor, tensor)

@callback less(out :: tensor(), tensor(), tensor()) :: tensor()

less_equal(out, tensor, tensor)

@callback less_equal(out :: tensor(), tensor(), tensor()) :: tensor()

@callback log1p(out :: tensor(), tensor()) :: tensor()
@callback log(out :: tensor(), tensor()) :: tensor()

logical_and(out, tensor, tensor)

@callback logical_and(out :: tensor(), tensor(), tensor()) :: tensor()

logical_not(out, t) (optional)

@callback logical_not(out :: tensor(), t :: tensor()) :: tensor()

logical_or(out, tensor, tensor)

@callback logical_or(out :: tensor(), tensor(), tensor()) :: tensor()

logical_xor(out, tensor, tensor)

@callback logical_xor(out :: tensor(), tensor(), tensor()) :: tensor()

@callback lu({p :: tensor(), l :: tensor(), u :: tensor()}, tensor(), keyword()) ::
  tensor()

max(out, tensor, tensor)

@callback max(out :: tensor(), tensor(), tensor()) :: tensor()

min(out, tensor, tensor)

@callback min(out :: tensor(), tensor(), tensor()) :: tensor()

multiply(out, tensor, tensor)

@callback multiply(out :: tensor(), tensor(), tensor()) :: tensor()

@callback negate(out :: tensor(), tensor()) :: tensor()

not_equal(out, tensor, tensor)

@callback not_equal(out :: tensor(), tensor(), tensor()) :: tensor()

optional(atom, list, function) (optional)

@callback optional(atom(), [term()], (... -> any())) :: tensor()

Invoked for execution of optional callbacks with a default implementation.

First, we attempt to call the optional callback itself (one of the many optional callbacks defined in this module). If the backend does not implement it, we attempt to call this callback (which is also optional). If neither is implemented, we fall back to the default implementation.
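The three-step resolution order can be sketched in plain Elixir with `function_exported?/3`. Everything here is hypothetical and for illustration only: `SketchBehaviour` plays the role of Nx.Backend, `double/1` stands in for an optional callback, and unlike the real callbacks the sketch does not pass an output tensor. This is not Nx's own dispatch code.

```elixir
defmodule SketchBehaviour do
  # Hypothetical behaviour: `double/1` is an optional operation and `optional/3`
  # is an optional catch-all hook, loosely mirroring Nx.Backend.optional/3.
  @callback double(integer()) :: integer()
  @callback optional(atom(), [term()], (... -> any())) :: any()
  @optional_callbacks double: 1, optional: 3

  # Resolves a call in the order described above:
  # 1. the backend's own callback, if it exports one;
  # 2. otherwise the backend's optional/3 hook, if it exports one;
  # 3. otherwise the default implementation.
  def dispatch(backend, name, args, default_impl) do
    Code.ensure_loaded(backend)

    cond do
      function_exported?(backend, name, length(args)) -> apply(backend, name, args)
      function_exported?(backend, :optional, 3) -> backend.optional(name, args, default_impl)
      true -> apply(default_impl, args)
    end
  end
end

defmodule NativeBackend do
  @behaviour SketchBehaviour
  @impl true
  def double(x), do: x * 2
end

defmodule HookBackend do
  @behaviour SketchBehaviour
  # Tags the result so it is visible that the hook ran, then uses the default.
  @impl true
  def optional(name, args, default_impl), do: {:hook, name, apply(default_impl, args)}
end

defmodule PlainBackend do
  @behaviour SketchBehaviour
end

default = fn x -> {:default, x * 2} end
IO.inspect(SketchBehaviour.dispatch(NativeBackend, :double, [21], default))
# prints 42
IO.inspect(SketchBehaviour.dispatch(HookBackend, :double, [21], default))
# prints {:hook, :double, {:default, 42}}
IO.inspect(SketchBehaviour.dispatch(PlainBackend, :double, [21], default))
# prints {:default, 42}
```

Tagging the default result makes it easy to see which of the three paths each backend took.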

pad(out, tensor, pad_value, padding_config)

@callback pad(out :: tensor(), tensor(), pad_value :: tensor(), padding_config :: list()) ::
  tensor()

phase(out, t) (optional)

@callback phase(out :: tensor(), t :: tensor()) :: tensor()

population_count(out, tensor)

@callback population_count(out :: tensor(), tensor()) :: tensor()

pow(out, tensor, tensor)

@callback pow(out :: tensor(), tensor(), tensor()) :: tensor()

product(out, tensor, keyword)

@callback product(out :: tensor(), tensor(), keyword()) :: tensor()

put_slice(out, tensor, tensor, list)

@callback put_slice(out :: tensor(), tensor(), tensor(), list()) :: tensor()

qr({q, r}, tensor, keyword) (optional)

@callback qr({q :: tensor(), r :: tensor()}, tensor(), keyword()) :: tensor()

quotient(out, tensor, tensor)

@callback quotient(out :: tensor(), tensor(), tensor()) :: tensor()

@callback real(out :: tensor(), tensor()) :: tensor()

reduce(out, tensor, acc, keyword, function)

@callback reduce(out :: tensor(), tensor(), acc :: tensor(), keyword(), (... -> any())) ::
  tensor()

reduce_max(out, tensor, keyword)

@callback reduce_max(out :: tensor(), tensor(), keyword()) :: tensor()

reduce_min(out, tensor, keyword)

@callback reduce_min(out :: tensor(), tensor(), keyword()) :: tensor()

remainder(out, tensor, tensor)

@callback remainder(out :: tensor(), tensor(), tensor()) :: tensor()

@callback reshape(out :: tensor(), tensor()) :: tensor()

reverse(out, tensor, axes)

@callback reverse(out :: tensor(), tensor(), axes()) :: tensor()

right_shift(out, tensor, tensor)

@callback right_shift(out :: tensor(), tensor(), tensor()) :: tensor()

@callback round(out :: tensor(), tensor()) :: tensor()
@callback rsqrt(out :: tensor(), tensor()) :: tensor()

select(out, tensor, tensor, tensor)

@callback select(out :: tensor(), tensor(), tensor(), tensor()) :: tensor()

@callback sigmoid(out :: tensor(), tensor()) :: tensor()
@callback sign(out :: tensor(), tensor()) :: tensor()
@callback sin(out :: tensor(), tensor()) :: tensor()
@callback sinh(out :: tensor(), tensor()) :: tensor()

slice(out, tensor, list, list, list)

@callback slice(out :: tensor(), tensor(), list(), list(), list()) :: tensor()

solve(out, a, b) (optional)

@callback solve(out :: tensor(), a :: tensor(), b :: tensor()) :: tensor()

sort(out, tensor, keyword)

@callback sort(out :: tensor(), tensor(), keyword()) :: tensor()

@callback sqrt(out :: tensor(), tensor()) :: tensor()

squeeze(out, tensor, axes)

@callback squeeze(out :: tensor(), tensor(), axes()) :: tensor()

stack(out, tensor, axis)

@callback stack(out :: tensor(), tensor(), axis()) :: tensor()

subtract(out, tensor, tensor)

@callback subtract(out :: tensor(), tensor(), tensor()) :: tensor()

sum(out, tensor, keyword)

@callback sum(out :: tensor(), tensor(), keyword()) :: tensor()

svd({u, s, v}, tensor, keyword) (optional)

@callback svd({u :: tensor(), s :: tensor(), v :: tensor()}, tensor(), keyword()) ::
  tensor()

take(out, input, indices, keyword) (optional)

@callback take(out :: tensor(), input :: tensor(), indices :: tensor(), keyword()) ::
  tensor()

take_along_axis(out, input, indices, keyword) (optional)

@callback take_along_axis(
  out :: tensor(),
  input :: tensor(),
  indices :: tensor(),
  keyword()
) :: tensor()

@callback tan(out :: tensor(), tensor()) :: tensor()
@callback tanh(out :: tensor(), tensor()) :: tensor()

to_batched(out, tensor, keyword)

@callback to_batched(out :: tensor(), tensor(), keyword()) :: [tensor()]

to_binary(tensor, limit)

@callback to_binary(tensor(), limit :: non_neg_integer()) :: binary()

to_pointer(tensor, opts)

@callback to_pointer(tensor(), opts :: keyword()) :: {:ok, term()} | {:error, term()}

top_k(out, tensor, keyword) (optional)

@callback top_k(out :: tensor(), tensor(), keyword()) :: tensor()

transpose(out, tensor, axes)

@callback transpose(out :: tensor(), tensor(), axes()) :: tensor()

triangular_solve(out, a, b, keyword)

@callback triangular_solve(out :: tensor(), a :: tensor(), b :: tensor(), keyword()) ::
  tensor()

window_max(out, tensor, shape, keyword)

@callback window_max(out :: tensor(), tensor(), shape(), keyword()) :: tensor()

window_min(out, tensor, shape, keyword)

@callback window_min(out :: tensor(), tensor(), shape(), keyword()) :: tensor()

window_product(out, tensor, shape, keyword)

@callback window_product(out :: tensor(), tensor(), shape(), keyword()) :: tensor()

window_reduce(out, tensor, acc, shape, keyword, function)

@callback window_reduce(
  out :: tensor(),
  tensor(),
  acc :: tensor(),
  shape(),
  keyword(),
  (... -> any())
) ::
  tensor()

window_scatter_max(out, tensor, tensor, tensor, shape, keyword)

@callback window_scatter_max(
  out :: tensor(),
  tensor(),
  tensor(),
  tensor(),
  shape(),
  keyword()
) :: tensor()

window_scatter_min(out, tensor, tensor, tensor, shape, keyword)

@callback window_scatter_min(
  out :: tensor(),
  tensor(),
  tensor(),
  tensor(),
  shape(),
  keyword()
) :: tensor()

window_sum(out, tensor, shape, keyword)

@callback window_sum(out :: tensor(), tensor(), shape(), keyword()) :: tensor()

Functions

complex_to_string(complex, precision)

inspect(map, binary, inspect_opts)

Inspects the given tensor, whose data is given by binary.

Note that the binary may have fewer elements than the tensor size but, in such cases, it must have strictly more elements than inspect_opts.limit.
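The constraint on the binary can be written as a small predicate. `InspectBinaryRule.valid?/4` is a hypothetical helper for illustration, not a function exported by Nx.Backend; the 32-bit element size in the usage is only an example, and treating a limit of :infinity as requiring a complete binary is an assumption read from the wording above.

```elixir
defmodule InspectBinaryRule do
  # Hypothetical helper encoding the rule above. `bits_per_element` is the
  # element size of the tensor type (e.g. 32 for f32), `size` is the total
  # number of elements in the tensor, and `limit` is inspect_opts.limit.
  def valid?(binary, bits_per_element, size, limit) do
    count = div(bit_size(binary), bits_per_element)

    cond do
      # A complete binary is always acceptable.
      count >= size -> true
      # A partial binary needs strictly more elements than a finite limit.
      is_integer(limit) -> count > limit
      # With limit: :infinity, a partial binary cannot satisfy the rule.
      true -> false
    end
  end
end

# Three 32-bit elements (12 bytes of zeros).
three_f32 = :binary.copy(<<0::32>>, 3)
IO.inspect(InspectBinaryRule.valid?(three_f32, 32, 10, 2))
# prints true (3 > 2)
IO.inspect(InspectBinaryRule.valid?(three_f32, 32, 10, 3))
# prints false (3 is not strictly greater than 3)
IO.inspect(InspectBinaryRule.valid?(three_f32, 32, 3, 50))
# prints true (the binary is complete)
```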

Options

The following options must be passed through Inspect's :custom_options:

  • :nx_precision - Configures the floating-point number printing precision. If set, floating-point numbers are printed in scientific notation using the specified number of significant digits. Otherwise, default Elixir printing rules are applied.
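With Nx, the option would be given as, for example, `inspect(tensor, custom_options: [nx_precision: 3])`. The self-contained sketch below shows the same plumbing with a hypothetical `Reading` struct: its Inspect implementation reads :nx_precision from `opts.custom_options` and, when present, formats the float with the :scientific option of Erlang's :erlang.float_to_binary/2. That formatting choice is an assumption for illustration; Nx's own number formatting may differ.

```elixir
defmodule Reading do
  # Hypothetical struct used only to demonstrate passing :custom_options.
  defstruct [:value]
end

defimpl Inspect, for: Reading do
  def inspect(%Reading{value: value}, opts) do
    case Keyword.get(opts.custom_options, :nx_precision) do
      nil ->
        # No precision requested: default Elixir float printing.
        "#Reading<#{value}>"

      precision when is_integer(precision) and precision >= 1 ->
        # :scientific counts digits after the decimal point, so N significant
        # digits means N - 1 decimals.
        "#Reading<#{:erlang.float_to_binary(value, scientific: precision - 1)}>"
    end
  end
end

reading = %Reading{value: 3.14159}
IO.puts(inspect(reading))
# prints #Reading<3.14159>
IO.puts(inspect(reading, custom_options: [nx_precision: 3]))
# prints #Reading<3.14e+00>
```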