Nx Quickstart

Prerequisites

To properly use Nx, you will need to know a bit of Elixir. For a refresher, check out the Elixir Getting Started Guide.

To work through the examples, you can run this guide using the "Run in Livebook" button on this page.

Learning Objectives

This is an overview of Nx tensors. In this section, we'll look at some of the tools for creating and interacting with tensors.

After reading, you should understand:

  • How to create 1-, 2- and N-dimensional tensors in Nx;
  • How to index, slice and iterate through tensors;
  • Basic tensor functions;
  • How to apply some linear algebra operations to n-dimensional tensors without using for-loops;
  • Axis and shape properties for n-dimensional tensors.

The Basics

First, let's install Nx with Mix.install.

Mix.install([
  {:nx, "~> 0.9"}
])

The IEx.Helpers module will assist our exploration of core tensor concepts.

import IEx.Helpers

Creating tensors

The argument for Nx.tensor/1 must be one of:

  • a tensor;
  • a number (which means the tensor is scalar/zero-dimensional);
  • a boolean (also scalar/zero-dimensional);
  • an arbitrarily nested list of numbers and booleans;
  • the special atoms :nan, :infinity, :neg_infinity, which represent non-finite numbers that Elixir floats do not support.

If a new tensor is allocated, it will be allocated in the backend defined by the :backend option. If it is not provided, Nx.default_backend/0 will be used instead.
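As a minimal sketch of the :backend option described above, the snippet below queries the default backend and then allocates a tensor explicitly on Nx.BinaryBackend. It assumes Nx is already installed through the Mix.install call in this guide and that no other backend has been configured.

```elixir
# Assumes Nx is already installed via the Mix.install call in this guide.

# Ask Nx which backend is used when :backend is not given.
# The result is a {module, options} tuple.
{default_module, _options} = Nx.default_backend()

# Explicitly allocate a tensor on Nx.BinaryBackend.
t = Nx.tensor([1, 2, 3], backend: Nx.BinaryBackend)

# The backend shows up in the struct stored in the tensor's :data field.
%Nx.Tensor{data: %Nx.BinaryBackend{}} = t
```

Passing the option explicitly is mostly useful when a different default backend is configured and a particular tensor should live elsewhere.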

Examples

A number returns a tensor of zero dimensions, also known as a scalar:

Nx.tensor(0)
Nx.tensor(1.0)

A list returns a one-dimensional tensor, also known as a vector:

Nx.tensor([1, 2, 3])
Nx.tensor([1.2, 2.3, 3.4, 4.5])

Higher dimensional tensors are also possible:

Nx.tensor([[1, 2, 3], [4, 5, 6]])
Nx.tensor([[1, 2], [3, 4], [5, 6]])
Nx.tensor([[[1, 2], [3, 4], [5, 6]], [[-1, -2], [-3, -4], [-5, -6]]])

Tensors can also be given as inputs, which is useful for functions that don't care about the input kind:

Nx.tensor(Nx.tensor([1, 2, 3]))
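To see how list nesting maps to dimensions, the sketch below builds one tensor of each kind from the examples above and inspects it with Nx.shape/1 and Nx.rank/1. The variable names are only illustrative.

```elixir
# Assumes Nx is already installed via the Mix.install call in this guide.
scalar = Nx.tensor(0)
vector = Nx.tensor([1, 2, 3])
matrix = Nx.tensor([[1, 2, 3], [4, 5, 6]])
cube = Nx.tensor([[[1, 2], [3, 4], [5, 6]], [[-1, -2], [-3, -4], [-5, -6]]])

# One shape entry per nesting level: [{}, {3}, {2, 3}, {2, 3, 2}]
shapes = Enum.map([scalar, vector, matrix, cube], &Nx.shape/1)

# The rank is the number of dimensions: [0, 1, 2, 3]
ranks = Enum.map([scalar, vector, matrix, cube], &Nx.rank/1)

# Passing a tensor to Nx.tensor/1 gives back a tensor with the same contents.
same = Nx.tensor(vector)
```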

Naming dimensions

You can provide names for tensor dimensions. Names are atoms:

Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y])

Names make your code more expressive:

Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch, :height, :width])

We created a tensor of shape {1, 3, 3} with three axes named batch, height and width.

You can also leave dimension names as nil (which is the default):

Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch, nil, nil])

However, you must provide a name for every dimension in the tensor. For example, the following code snippet raises an error because 1 name is given, but there are 3 dimensions:

Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch])
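Named dimensions can stand in for positional axis indices in other functions. The following sketch reads the names back and sums along the axis called :width; the variable names are illustrative.

```elixir
# Assumes Nx is already installed via the Mix.install call in this guide.
t = Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch, :height, :width])

# Read the names and the size of a single named axis.
names = Nx.names(t)
height = Nx.axis_size(t, :height)

# Refer to an axis by name instead of by position.
# Summing over :width leaves the :batch and :height axes.
row_sums = Nx.sum(t, axes: [:width])
```

Here row_sums holds one value per row of the 3x3 block, so code that refers to :width keeps working even if the axis order changes later.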

Indexing and Slicing tensor values

We can get any cell of the tensor:

tensor = Nx.tensor([[1, 2], [3, 4]], names: [:y, :x])
tensor[[0, 1]]

Negative indices will start counting from the end of the axis. -1 is the last entry, -2 the second to last and so on.

tensor = Nx.tensor([[1, 2], [3, 4], [5, 6]], names: [:y, :x])
tensor[[-1, -1]]

We can also get a whole dimension:

tensor[x: 1]

or a range:

tensor[y: 0..1]

tensor[[.., 1]] will achieve the same result as tensor[x: 1]. This is because Elixir has the syntax sugar .. for a 0..-1//1 range.
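The indexing forms above can be compared side by side. This sketch rebuilds the same 3x2 tensor and converts the results to plain Elixir terms with Nx.to_number/1 and Nx.to_list/1 so they are easy to inspect.

```elixir
# Assumes Nx is already installed via the Mix.install call in this guide.
tensor = Nx.tensor([[1, 2], [3, 4], [5, 6]], names: [:y, :x])

# Single cells, addressed from the start and from the end of each axis.
first_row_second_col = Nx.to_number(tensor[[0, 1]])
last_cell = Nx.to_number(tensor[[-1, -1]])

# The same column selected by axis name and by position.
column_by_name = Nx.to_list(tensor[x: 1])
column_by_position = Nx.to_list(tensor[[.., 1]])

# The first two rows, selected with a range on the :y axis.
first_rows = Nx.to_list(tensor[y: 0..1])
```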

Tensor shape and reshape

We can inspect the shape of a tensor with Nx.shape/1:

Nx.shape(tensor)

We can also create a new tensor with the given shape using Nx.reshape/2:

Nx.reshape(tensor, {1, 6}, names: [:batches, :values])

This operation generally reuses all of the tensor data and simply changes the metadata, so it has no notable cost. The new tensor has the same type, but a new shape.
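A reshape only succeeds when the new shape holds the same number of elements as the original. The sketch below reshapes a 3x2 tensor (6 elements) in a few ways, including the :auto placeholder that lets Nx infer one dimension.

```elixir
# Assumes Nx is already installed via the Mix.install call in this guide.
tensor = Nx.tensor([[1, 2], [3, 4], [5, 6]], names: [:y, :x])

# 3 * 2 = 6 elements, so any shape whose sizes multiply to 6 is valid.
flat = Nx.reshape(tensor, {6})
wide = Nx.reshape(tensor, {2, 3})

# :auto asks Nx to infer that dimension from the element count.
inferred = Nx.reshape(tensor, {:auto, 3})

# The elements keep their row-major order; only the shape changes.
same_order? = Nx.to_flat_list(wide) == Nx.to_flat_list(tensor)
```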

Floats and Complex numbers

Besides single-precision (32-bit) floats, which are the default, Nx floating-point numbers can also use other precisions, such as half-precision (16-bit) or double-precision (64-bit):

Nx.tensor([0.0, 0.2, 0.4, 1.0], type: :f16)
Nx.tensor([0.0, 0.2, 0.4, 1.0], type: :f64)

Brain floats are also supported:

Nx.tensor([0.0, 0.2, 0.4, 1.0], type: :bf16)

Certain backends and compilers support 8-bit floats. The 8-bit float format may differ between backends, so you must be careful when transferring data across different backends. Nx.BinaryBackend implements F8E5M2:

Nx.tensor([1, 2, 3], type: :f8)

In all cases, the non-finite values negative infinity (-Inf), infinity (Inf), and "not a number" (NaN) can be represented by the atoms :neg_infinity, :infinity, and :nan, respectively:

Nx.tensor([:neg_infinity, :nan, :infinity])

Finally, complex numbers are also supported in tensors, with 32-bit or 64-bit floating-point components (types :c64 and :c128, respectively):

Nx.tensor(Complex.new(1, -1))

Check out the documentation for Nx.tensor/2 for more details on the accepted options.
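To confirm which type a tensor ended up with, Nx.type/1 returns a {kind, bits} tuple. The sketch below checks the float types used above and converts an existing tensor with Nx.as_type/2.

```elixir
# Assumes Nx is already installed via the Mix.install call in this guide.
values = [0.0, 0.2, 0.4, 1.0]

# Floats default to single precision when no :type is given.
default_type = Nx.type(Nx.tensor(values))

# Request other precisions explicitly: [{:f, 16}, {:f, 64}, {:bf, 16}]
requested =
  for type <- [:f16, :f64, :bf16] do
    Nx.type(Nx.tensor(values, type: type))
  end

# A complex number without an explicit type produces a c64 tensor.
complex_type = Nx.type(Nx.tensor(Complex.new(1, -1)))

# Convert an existing tensor to another type.
as_double = Nx.as_type(Nx.tensor([1, 2, 3]), :f64)
```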

Basic operations

Nx supports element-wise arithmetic operations for tensors and broadcasting when necessary.
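Broadcasting means that operands with different but compatible shapes are stretched to a common shape before the element-wise operation runs. A short sketch with a 2x3 matrix, a 3-element row, and a plain number:

```elixir
# Assumes Nx is already installed via the Mix.install call in this guide.
matrix = Nx.tensor([[1, 2, 3], [4, 5, 6]])
row = Nx.tensor([10, 20, 30])

# The {3} row is broadcast across both rows of the {2, 3} matrix.
shifted = Nx.add(matrix, row)

# A plain number is treated as a scalar tensor and broadcast to every element.
doubled = Nx.multiply(matrix, 2)
```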

Addition

Nx.add/2: Adds corresponding elements of two tensors.

a = Nx.tensor([1, 2, 3])
b = Nx.tensor([0, 1, 2])
Nx.add(a, b)

Subtraction

Nx.subtract/2: Subtracts the elements of the second tensor from the first.

a = Nx.tensor([10, 20, 30])
b = Nx.tensor([0, 1, 2])
Nx.subtract(a, b)

Multiplication

Nx.multiply/2: Multiplies corresponding elements of two tensors.

a = Nx.tensor([2, 3, 4])
b = Nx.tensor([0, 1, 2])
Nx.multiply(a, b)

Division

Nx.divide/2: Divides the elements of the first tensor by the second tensor.

a = Nx.tensor([10, 30, 40])
b = Nx.tensor([5, 6, 8])
Nx.divide(a, b)

Exponentiation

Nx.pow/2: Raises each element of the first tensor to the power of the corresponding element in the second tensor.

a = Nx.tensor([2, 3, 4])
b = Nx.tensor([2])
Nx.pow(a, b)

Quotient

Nx.quotient/2: Returns a new tensor where each element is the result of the integer division (div/2) of the corresponding elements.

a = Nx.tensor([10, 20, 30])
b = Nx.tensor([3, 7, 4])

Nx.quotient(a, b)

Remainder

Nx.remainder/2: Computes the integer division remainder.

a = Nx.tensor([27, 32, 43])
b = Nx.tensor([2, 3, 4])
Nx.remainder(a, b)
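The three division-related functions fit together as follows: Nx.divide/2 produces a floating-point result, while Nx.quotient/2 and Nx.remainder/2 split an integer division into two parts. This sketch checks that the quotient and remainder recombine into the original values.

```elixir
# Assumes Nx is already installed via the Mix.install call in this guide.
a = Nx.tensor([10, 20, 30])
b = Nx.tensor([3, 7, 4])

q = Nx.quotient(a, b)
r = Nx.remainder(a, b)

# For each element: a == q * b + r
rebuilt = Nx.add(Nx.multiply(q, b), r)

# Nx.divide/2 returns a float tensor even for integer inputs.
{kind, _bits} = Nx.type(Nx.divide(a, b))
```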

Negation

Nx.negate/1: Negates each element of a tensor.

a = Nx.tensor([2, 3, 4])
Nx.negate(a)

Square Root

Nx.sqrt/1: Computes the element-wise square root.

a = Nx.tensor([4, 9, 16])
Nx.sqrt(a)
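These element-wise functions compose without explicit loops. As an illustration, the Euclidean length of a vector can be sketched by chaining multiplication, a sum, and a square root:

```elixir
# Assumes Nx is already installed via the Mix.install call in this guide.
v = Nx.tensor([3, 4])

# Square each element, add the squares up, then take the square root.
norm =
  v
  |> Nx.multiply(v)
  |> Nx.sum()
  |> Nx.sqrt()

# Convert the scalar tensor to a plain Elixir number.
length = Nx.to_number(norm)
```

The pipe operator keeps the order of operations readable, which is the usual style for chaining tensor functions.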

Element-Wise Comparison

The following operations return a u8 tensor where 1 represents true and 0 represents false.

Equality and Inequality

Nx.equal/2, Nx.not_equal/2

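# Two integer tensors that differ only in the last element
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4, 9, -16])
Nx.equal(a, b)

# Integer and float tensors can be compared with each other
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4.0, 9.0, -16.0])
Nx.not_equal(a, b)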

Greater and Less

Nx.greater/2, Nx.less/2

# Element-wise a > b
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4, 8, 17])
Nx.greater(a, b)

# Element-wise a < b, comparing integers with floats
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4.2, 9.0, 15.9])
Nx.less(a, b)

Greater_Equal and Less_Equal

Nx.greater_equal/2, Nx.less_equal/2

# Element-wise a >= b
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4, 8, 17])
Nx.greater_equal(a, b)

# Element-wise a <= b, comparing integers with floats
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4.2, 9.0, 15.9])
Nx.less_equal(a, b)
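Because comparison results are ordinary tensors of 0s and 1s, they can feed into other operations. The sketch below counts the matching positions with Nx.sum/1 and uses the result as a mask for Nx.select/3, which picks elements from one of two tensors.

```elixir
# Assumes Nx is already installed via the Mix.install call in this guide.
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4, 8, 17])

# 1 where a > b, 0 elsewhere.
mask = Nx.greater(a, b)

# Summing the mask counts how many positions satisfy the comparison.
count = Nx.to_number(Nx.sum(mask))

# Take the element from a where the mask is 1, otherwise from b.
# This yields the element-wise maximum of the two tensors.
larger = Nx.select(mask, a, b)
```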

Aggregate functions

These operations aggregate values across tensor axes.

See also the aggregation guide for a more in-depth exploration of the subject.

Sum

Nx.sum/1: Sums all elements.

a = Nx.tensor([[4, 9, 16], [4.2, 9.0, 16.7]])
Nx.sum(a)

Mean

Nx.mean/1: Computes the mean value of the elements.

a = Nx.tensor([[4, 9, 16], [4.2, 9.0, 16.7]])
Nx.mean(a)

Product

Nx.product/1: Computes the product of the elements.

a = Nx.tensor([[4, 9, 16], [4.2, 9.0, 16.7]])
Nx.product(a)
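The examples above reduce over every element. To aggregate along specific axes instead, the same functions accept an :axes option, given either as positions or as names. A small sketch with an integer matrix, using illustrative axis names:

```elixir
# Assumes Nx is already installed via the Mix.install call in this guide.
a = Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:row, :col])

# Collapse axis 0 (the rows): one sum per column.
column_sums = Nx.sum(a, axes: [0])

# Collapse the :col axis by name: one sum and one mean per row.
row_sums = Nx.sum(a, axes: [:col])
row_means = Nx.mean(a, axes: [:col])

# Without :axes, every element is reduced into a single scalar.
total_product = Nx.to_number(Nx.product(a))
```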

Matrix Multiplication

Nx.dot/4: Computes the generalized dot product between two tensors, operating on specific contracting axes.

t1 = Nx.tensor([[1, 2], [3, 4]], names: [:x, :y])
t2 = Nx.tensor([[10, 20], [30, 40]], names: [:height, :width])
Nx.dot(t1, [0], t2, [0])
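To relate the contracting-axes form to ordinary matrix multiplication, the sketch below uses the same values without axis names. Nx.dot/2 contracts the last axis of the first tensor with the first axis of the second, while contracting axis 0 of both tensors is equivalent to transposing the first tensor before multiplying.

```elixir
# Assumes Nx is already installed via the Mix.install call in this guide.
t1 = Nx.tensor([[1, 2], [3, 4]])
t2 = Nx.tensor([[10, 20], [30, 40]])

# Standard matrix product t1 * t2.
product = Nx.dot(t1, t2)

# Contract axis 0 of t1 with axis 0 of t2.
contracted = Nx.dot(t1, [0], t2, [0])

# The same result, written as transpose(t1) * t2.
via_transpose = Nx.dot(Nx.transpose(t1), t2)
```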