Nx.Backend behaviour (Nx v0.12.0)

Copy Markdown View Source

The behaviour for tensor backends.

Each backend is a module that defines a struct and implements the callbacks defined in this module. The callbacks are mostly implementations of the functions in the Nx module, with the expected output tensor (which describes the output shape and type) given as the first argument.

Nx backends come in two flavors: opaque backends, whose data should not be accessed directly except through the functions in the Nx module, and public ones, whose data can be directly accessed and traversed. Opaque backends typically have the Backend suffix.

Nx ships with the following backends:

  • Nx.BinaryBackend - an opaque backend written in pure Elixir that stores the data in Elixir's binaries. This is the default backend used by the Nx module. The backend itself (and its data) is private and must not be accessed directly.

  • Nx.TemplateBackend - an opaque backend that works as a template in APIs to declare the type, shape, and names of tensors to be expected in the future.

  • Nx.Defn.Expr - a public backend used by defn to build expression trees that are traversed by custom compilers.

This module also includes functions that are meant to be shared across backends.

Summary

Callbacks

Invoked for execution of Nx.block/4.

Functions

Inspects the given tensor, whose data is given as a binary.

Types

axes()

@type axes() :: Nx.Tensor.axes()

axis()

@type axis() :: Nx.Tensor.axis()

backend_options()

@type backend_options() :: term()

shape()

@type shape() :: Nx.Tensor.shape()

t()

@type t() :: %{__struct__: atom()}

tensor()

@type tensor() :: Nx.Tensor.t()

Callbacks

abs(out, tensor)

@callback abs(out :: tensor(), tensor()) :: tensor()

acos(out, tensor)

@callback acos(out :: tensor(), tensor()) :: tensor()

acosh(out, tensor)

@callback acosh(out :: tensor(), tensor()) :: tensor()

add(out, tensor, tensor)

@callback add(out :: tensor(), tensor(), tensor()) :: tensor()

all(out, tensor, keyword)

@callback all(out :: tensor(), tensor(), keyword()) :: tensor()

any(out, tensor, keyword)

@callback any(out :: tensor(), tensor(), keyword()) :: tensor()

argmax(out, tensor, keyword)

@callback argmax(out :: tensor(), tensor(), keyword()) :: tensor()

argmin(out, tensor, keyword)

@callback argmin(out :: tensor(), tensor(), keyword()) :: tensor()

argsort(out, tensor, keyword)

@callback argsort(out :: tensor(), tensor(), keyword()) :: tensor()

as_type(out, tensor)

@callback as_type(out :: tensor(), tensor()) :: tensor()

asin(out, tensor)

@callback asin(out :: tensor(), tensor()) :: tensor()

asinh(out, tensor)

@callback asinh(out :: tensor(), tensor()) :: tensor()

atan2(out, tensor, tensor)

@callback atan2(out :: tensor(), tensor(), tensor()) :: tensor()

atan(out, tensor)

@callback atan(out :: tensor(), tensor()) :: tensor()

atanh(out, tensor)

@callback atanh(out :: tensor(), tensor()) :: tensor()

backend_copy(tensor, module, backend_options)

@callback backend_copy(tensor(), module(), backend_options()) :: tensor()

backend_deallocate(tensor)

@callback backend_deallocate(tensor()) :: :ok | :already_deallocated

backend_transfer(tensor, module, backend_options)

@callback backend_transfer(tensor(), module(), backend_options()) :: tensor()

bitcast(out, tensor)

@callback bitcast(out :: tensor(), tensor()) :: tensor()

bitwise_and(out, tensor, tensor)

@callback bitwise_and(out :: tensor(), tensor(), tensor()) :: tensor()

bitwise_not(out, tensor)

@callback bitwise_not(out :: tensor(), tensor()) :: tensor()

bitwise_or(out, tensor, tensor)

@callback bitwise_or(out :: tensor(), tensor(), tensor()) :: tensor()

bitwise_xor(out, tensor, tensor)

@callback bitwise_xor(out :: tensor(), tensor(), tensor()) :: tensor()

block(struct, output, args, fun)

@callback block(struct(), output :: tensor() | tuple(), args :: [term()], fun()) ::
  tensor() | tuple()

Invoked for execution of Nx.block/4.

output is the result template (Nx.Tensor or tuple of tensors). args are the tensors (and optional trailing keyword lists) passed to Nx.block/4. Backends should dispatch on struct (see Nx.Block.*) and either run a native implementation or invoke fun as apply(fun, [struct | args]).

broadcast(out, tensor, shape, axes)

@callback broadcast(out :: tensor(), tensor(), shape(), axes()) :: tensor()

cbrt(out, tensor)

@callback cbrt(out :: tensor(), tensor()) :: tensor()

ceil(out, tensor)

@callback ceil(out :: tensor(), tensor()) :: tensor()

clip(out, tensor, min, max)

@callback clip(out :: tensor(), tensor(), min :: tensor(), max :: tensor()) :: tensor()

concatenate(out, tensor, axis)

@callback concatenate(out :: tensor(), tensor(), axis()) :: tensor()

conjugate(out, tensor)

@callback conjugate(out :: tensor(), tensor()) :: tensor()

constant(out, arg2, backend_options)

@callback constant(out :: tensor(), number() | Complex.t(), backend_options()) :: tensor()

conv(out, tensor, kernel, keyword)

@callback conv(out :: tensor(), tensor(), kernel :: tensor(), keyword()) :: tensor()

cos(out, tensor)

@callback cos(out :: tensor(), tensor()) :: tensor()

cosh(out, tensor)

@callback cosh(out :: tensor(), tensor()) :: tensor()

count_leading_zeros(out, tensor)

@callback count_leading_zeros(out :: tensor(), tensor()) :: tensor()

divide(out, tensor, tensor)

@callback divide(out :: tensor(), tensor(), tensor()) :: tensor()

dot(out, tensor, axes, axes, tensor, axes, axes)

@callback dot(out :: tensor(), tensor(), axes(), axes(), tensor(), axes(), axes()) ::
  tensor()

equal(out, tensor, tensor)

@callback equal(out :: tensor(), tensor(), tensor()) :: tensor()

erf(out, tensor)

@callback erf(out :: tensor(), tensor()) :: tensor()

erf_inv(out, tensor)

@callback erf_inv(out :: tensor(), tensor()) :: tensor()

erfc(out, tensor)

@callback erfc(out :: tensor(), tensor()) :: tensor()

exp(out, tensor)

@callback exp(out :: tensor(), tensor()) :: tensor()

expm1(out, tensor)

@callback expm1(out :: tensor(), tensor()) :: tensor()

eye(tensor, backend_options)

@callback eye(tensor(), backend_options()) :: tensor()

fft(out, tensor, keyword)

@callback fft(out :: tensor(), tensor(), keyword()) :: tensor()

floor(out, tensor)

@callback floor(out :: tensor(), tensor()) :: tensor()

from_binary(out, binary, backend_options)

@callback from_binary(out :: tensor(), binary(), backend_options()) :: tensor()

from_pointer(opaque_pointer, type, shape, backend_opts, opts)

@callback from_pointer(
  opaque_pointer :: term(),
  type :: tuple(),
  shape :: tuple(),
  backend_opts :: keyword(),
  opts :: keyword()
) :: tensor() | no_return()

gather(out, input, indices, keyword)

@callback gather(out :: tensor(), input :: tensor(), indices :: tensor(), keyword()) ::
  tensor()

greater(out, tensor, tensor)

@callback greater(out :: tensor(), tensor(), tensor()) :: tensor()

greater_equal(out, tensor, tensor)

@callback greater_equal(out :: tensor(), tensor(), tensor()) :: tensor()

ifft(out, tensor, keyword)

@callback ifft(out :: tensor(), tensor(), keyword()) :: tensor()

imag(out, tensor)

@callback imag(out :: tensor(), tensor()) :: tensor()

indexed_add(out, tensor, indices, updates, keyword)

@callback indexed_add(
  out :: tensor(),
  tensor(),
  indices :: tensor(),
  updates :: tensor(),
  keyword()
) ::
  tensor()

indexed_put(out, tensor, indices, updates, keyword)

@callback indexed_put(
  out :: tensor(),
  tensor(),
  indices :: tensor(),
  updates :: tensor(),
  keyword()
) ::
  tensor()

init(keyword)

@callback init(keyword()) :: backend_options()

inspect(tensor, t)

@callback inspect(tensor(), Inspect.Opts.t()) :: tensor()

iota(tensor, arg2, backend_options)

@callback iota(tensor(), axis() | nil, backend_options()) :: tensor()

is_infinity(out, tensor)

@callback is_infinity(out :: tensor(), tensor()) :: tensor()

is_nan(out, tensor)

@callback is_nan(out :: tensor(), tensor()) :: tensor()

left_shift(out, tensor, tensor)

@callback left_shift(out :: tensor(), tensor(), tensor()) :: tensor()

less(out, tensor, tensor)

@callback less(out :: tensor(), tensor(), tensor()) :: tensor()

less_equal(out, tensor, tensor)

@callback less_equal(out :: tensor(), tensor(), tensor()) :: tensor()

log1p(out, tensor)

@callback log1p(out :: tensor(), tensor()) :: tensor()

log(out, tensor)

@callback log(out :: tensor(), tensor()) :: tensor()

logical_and(out, tensor, tensor)

@callback logical_and(out :: tensor(), tensor(), tensor()) :: tensor()

logical_or(out, tensor, tensor)

@callback logical_or(out :: tensor(), tensor(), tensor()) :: tensor()

logical_xor(out, tensor, tensor)

@callback logical_xor(out :: tensor(), tensor(), tensor()) :: tensor()

max(out, tensor, tensor)

@callback max(out :: tensor(), tensor(), tensor()) :: tensor()

min(out, tensor, tensor)

@callback min(out :: tensor(), tensor(), tensor()) :: tensor()

multiply(out, tensor, tensor)

@callback multiply(out :: tensor(), tensor(), tensor()) :: tensor()

negate(out, tensor)

@callback negate(out :: tensor(), tensor()) :: tensor()

not_equal(out, tensor, tensor)

@callback not_equal(out :: tensor(), tensor(), tensor()) :: tensor()

pad(out, tensor, pad_value, padding_config)

@callback pad(out :: tensor(), tensor(), pad_value :: tensor(), padding_config :: list()) ::
  tensor()

population_count(out, tensor)

@callback population_count(out :: tensor(), tensor()) :: tensor()

pow(out, tensor, tensor)

@callback pow(out :: tensor(), tensor(), tensor()) :: tensor()

product(out, tensor, keyword)

@callback product(out :: tensor(), tensor(), keyword()) :: tensor()

put_slice(out, tensor, tensor, list)

@callback put_slice(out :: tensor(), tensor(), tensor(), list()) :: tensor()

quotient(out, tensor, tensor)

@callback quotient(out :: tensor(), tensor(), tensor()) :: tensor()

real(out, tensor)

@callback real(out :: tensor(), tensor()) :: tensor()

reduce(out, tensor, acc, keyword, fun)

@callback reduce(out :: tensor(), tensor(), acc :: tensor(), keyword(), fun()) :: tensor()

reduce_max(out, tensor, keyword)

@callback reduce_max(out :: tensor(), tensor(), keyword()) :: tensor()

reduce_min(out, tensor, keyword)

@callback reduce_min(out :: tensor(), tensor(), keyword()) :: tensor()

remainder(out, tensor, tensor)

@callback remainder(out :: tensor(), tensor(), tensor()) :: tensor()

reshape(out, tensor)

@callback reshape(out :: tensor(), tensor()) :: tensor()

reverse(out, tensor, axes)

@callback reverse(out :: tensor(), tensor(), axes()) :: tensor()

right_shift(out, tensor, tensor)

@callback right_shift(out :: tensor(), tensor(), tensor()) :: tensor()

round(out, tensor)

@callback round(out :: tensor(), tensor()) :: tensor()

rsqrt(out, tensor)

@callback rsqrt(out :: tensor(), tensor()) :: tensor()

select(out, tensor, tensor, tensor)

@callback select(out :: tensor(), tensor(), tensor(), tensor()) :: tensor()

sigmoid(out, tensor)

@callback sigmoid(out :: tensor(), tensor()) :: tensor()

sign(out, tensor)

@callback sign(out :: tensor(), tensor()) :: tensor()

sin(out, tensor)

@callback sin(out :: tensor(), tensor()) :: tensor()

sinh(out, tensor)

@callback sinh(out :: tensor(), tensor()) :: tensor()

slice(out, tensor, list, list, list)

@callback slice(out :: tensor(), tensor(), list(), list(), list()) :: tensor()

sort(out, tensor, keyword)

@callback sort(out :: tensor(), tensor(), keyword()) :: tensor()

sqrt(out, tensor)

@callback sqrt(out :: tensor(), tensor()) :: tensor()

squeeze(out, tensor, axes)

@callback squeeze(out :: tensor(), tensor(), axes()) :: tensor()

stack(out, tensor, axis)

@callback stack(out :: tensor(), tensor(), axis()) :: tensor()

subtract(out, tensor, tensor)

@callback subtract(out :: tensor(), tensor(), tensor()) :: tensor()

sum(out, tensor, keyword)

@callback sum(out :: tensor(), tensor(), keyword()) :: tensor()

tan(out, tensor)

@callback tan(out :: tensor(), tensor()) :: tensor()

tanh(out, tensor)

@callback tanh(out :: tensor(), tensor()) :: tensor()

to_batched(out, tensor, keyword)

@callback to_batched(out :: tensor(), tensor(), keyword()) :: [tensor()]

to_binary(tensor, limit)

@callback to_binary(tensor(), limit :: non_neg_integer()) :: binary()

to_pointer(tensor, opts)

@callback to_pointer(tensor(), opts :: keyword()) :: term() | no_return()

transpose(out, tensor, axes)

@callback transpose(out :: tensor(), tensor(), axes()) :: tensor()

triangular_solve(out, a, b, keyword)

@callback triangular_solve(out :: tensor(), a :: tensor(), b :: tensor(), keyword()) ::
  tensor()

window_max(out, tensor, shape, keyword)

@callback window_max(out :: tensor(), tensor(), shape(), keyword()) :: tensor()

window_min(out, tensor, shape, keyword)

@callback window_min(out :: tensor(), tensor(), shape(), keyword()) :: tensor()

window_product(out, tensor, shape, keyword)

@callback window_product(out :: tensor(), tensor(), shape(), keyword()) :: tensor()

window_reduce(out, tensor, acc, shape, keyword, fun)

@callback window_reduce(
  out :: tensor(),
  tensor(),
  acc :: tensor(),
  shape(),
  keyword(),
  fun()
) :: tensor()

window_scatter_max(out, tensor, tensor, tensor, shape, keyword)

@callback window_scatter_max(
  out :: tensor(),
  tensor(),
  tensor(),
  tensor(),
  shape(),
  keyword()
) :: tensor()

window_scatter_min(out, tensor, tensor, tensor, shape, keyword)

@callback window_scatter_min(
  out :: tensor(),
  tensor(),
  tensor(),
  tensor(),
  shape(),
  keyword()
) :: tensor()

window_sum(out, tensor, shape, keyword)

@callback window_sum(out :: tensor(), tensor(), shape(), keyword()) :: tensor()

Functions

inspect(map, binary, inspect_opts)

Inspects the given tensor, whose data is given as a binary.

Note that the binary may have fewer elements than the tensor size but, in such cases, it must have strictly more elements than inspect_opts.limit.