Nx Quickstart
Prerequisites
To properly use Nx, you will need to know a bit of Elixir. For a refresher, check out the Elixir Getting Started Guide.
To work through the examples, you can open this page in Livebook with the "Run in Livebook" button.
Learning Objectives
This is an overview of Nx tensors. In this section, we'll look at some of the tools for creating and interacting with tensors.
After reading, you should understand:
- How to create 1-, 2-, and N-dimensional tensors in Nx;
- How to index, slice, and iterate through tensors;
- Basic tensor functions;
- How to apply some linear algebra operations to n-dimensional tensors without using for-loops;
- Axis and shape properties for n-dimensional tensors.
The Basics
First, let's install Nx with Mix.install.
Mix.install([
{:nx, "~> 0.9"}
])
The IEx.Helpers module will assist our exploration of core tensor concepts.
import IEx.Helpers
Creating tensors
The argument for Nx.tensor/1 must be one of:
- a tensor;
- a number (which means the tensor is scalar/zero-dimensional);
- a boolean (also scalar/zero-dimensional);
- an arbitrarily nested list of numbers and booleans;
- the special atoms :nan, :infinity, and :neg_infinity, which represent non-finite numbers that Elixir floats do not support.
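To see how Nx interprets each kind of argument, you can inspect the resulting rank, shape, and type. The sketch below assumes Nx has already been installed by the Mix.install call above. The variable names are only illustrative.

```elixir
# Assumes Nx is installed via the Mix.install call at the top of this page.

# A bare number becomes a rank-0 (scalar) tensor with an empty shape.
scalar = Nx.tensor(7)
IO.inspect({Nx.rank(scalar), Nx.shape(scalar)}, label: "scalar rank and shape")

# Booleans are stored as integers: true becomes 1.
boolean = Nx.tensor(true)
IO.inspect(Nx.to_number(boolean), label: "true as a number")

# Mixing integers and floats in one list promotes the whole tensor to a float type.
mixed = Nx.tensor([1, 2.5])
IO.inspect(Nx.type(mixed), label: "mixed list type")

# The special atoms can sit next to ordinary numbers in a list.
special = Nx.tensor([:nan, :infinity, 1.0])
IO.inspect(Nx.shape(special), label: "special values shape")
```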
If a new tensor is allocated, it will be allocated in the backend defined by the :backend option.
If it is not provided, Nx.default_backend/0 will be used instead.
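Here is a small sketch of the :backend option, assuming Nx is installed as above and no other backend has been configured. Nx.BinaryBackend is the pure-Elixir backend that ships with Nx.

```elixir
# Assumes Nx is installed via the Mix.install call at the top of this page.

# Returns a {backend_module, options} tuple; with no configuration
# this is typically Nx.BinaryBackend.
IO.inspect(Nx.default_backend(), label: "default backend")

# Request a specific backend for a single tensor with the :backend option.
explicit = Nx.tensor([1, 2, 3], backend: Nx.BinaryBackend)
IO.inspect(explicit, label: "allocated on Nx.BinaryBackend")
```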
Examples
A number returns a tensor of zero dimensions, also known as a scalar:
Nx.tensor(0)
Nx.tensor(1.0)
A list returns a one-dimensional tensor, also known as a vector:
Nx.tensor([1, 2, 3])
Nx.tensor([1.2, 2.3, 3.4, 4.5])
Higher-dimensional tensors are also possible:
Nx.tensor([[1, 2, 3], [4, 5, 6]])
Nx.tensor([[1, 2], [3, 4], [5, 6]])
Nx.tensor([[[1, 2], [3, 4], [5, 6]], [[-1, -2], [-3, -4], [-5, -6]]])
Tensors can also be given as inputs, which is useful for functions that don't care about the input kind:
Nx.tensor(Nx.tensor([1, 2, 3]))
Naming dimensions
You can provide names for tensor dimensions. Names are atoms:
Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y])
Names make your code more expressive:
Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch, :height, :width])
We created a tensor of shape {1, 3, 3} with three axes named :batch, :height, and :width.
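Once axes are named, you can look them up by name instead of by position. This is a small sketch using the same {1, 3, 3} tensor, assuming Nx is installed as above. The variable name image is illustrative only.

```elixir
# Assumes Nx is installed via the Mix.install call at the top of this page.
image = Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch, :height, :width])

IO.inspect(Nx.shape(image), label: "shape")
IO.inspect(Nx.names(image), label: "names")

# Look up an axis by name: its position in the shape and its length.
IO.inspect(Nx.axis_index(image, :width), label: "position of :width")
IO.inspect(Nx.axis_size(image, :height), label: "size of :height")
```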
You can also leave dimension names as nil (which is the default):
Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch, nil, nil])
However, you must provide a name (or nil) for every dimension in the tensor. For example, the following snippet raises an error because one name is given but the tensor has three dimensions:
Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch])
Indexing and Slicing tensor values
We can get any cell of the tensor:
tensor = Nx.tensor([[1, 2], [3, 4]], names: [:y, :x])
tensor[[0, 1]]
Negative indices count from the end of the axis: -1 is the last entry, -2 the second to last, and so on.
tensor = Nx.tensor([[1, 2], [3, 4], [5, 6]], names: [:y, :x])
tensor[[-1, -1]]
We can also select an entire slice along one axis, here index 1 of the :x axis:
tensor[x: 1]
or a range of indices along the :y axis:
tensor[y: 0..1]
tensor[[.., 1]] achieves the same result as tensor[x: 1], because in Elixir a standalone .. is syntax sugar for the full range 0..-1//1.
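Two related functions cover cases the access syntax does not. Nx.slice/3 takes explicit start indices and lengths, and Nx.to_batched/2 lets you iterate over the leading axis with Enum. The sketch below reuses the {3, 2} tensor from above and assumes Nx is installed. The batch size of 1 is an arbitrary choice.

```elixir
# Assumes Nx is installed via the Mix.install call at the top of this page.
tensor = Nx.tensor([[1, 2], [3, 4], [5, 6]], names: [:y, :x])

# Start at row 0, column 0 and take 2 rows by 1 column.
first_column = Nx.slice(tensor, [0, 0], [2, 1])
IO.inspect(Nx.to_list(first_column), label: "slice")

# Nx.to_batched/2 returns an enumerable of tensors split along the first axis.
# With a batch size of 1, each element is a {1, 2} tensor holding one row.
rows =
  tensor
  |> Nx.to_batched(1)
  |> Enum.map(&Nx.to_list/1)

IO.inspect(rows, label: "rows")
```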
Tensor shape and reshape
Nx.shape(tensor)
We can also create a new tensor with a given shape using Nx.reshape/2. The new shape must hold the same number of elements as the original, here 3 * 2 = 6:
Nx.reshape(tensor, {1, 6}, names: [:batches, :values])
This operation generally reuses all of the tensor data and only changes the metadata, so it has no notable cost. The new tensor has the same type but a new shape.
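Reshaping keeps the elements in row-major order and only regroups them. Here is a minimal sketch, assuming Nx is installed and the same {3, 2} tensor. The axis names :rows and :cols are illustrative.

```elixir
# Assumes Nx is installed via the Mix.install call at the top of this page.
tensor = Nx.tensor([[1, 2], [3, 4], [5, 6]], names: [:y, :x])

# Any target shape must hold the same number of elements: 3 * 2 = 2 * 3 = 6.
IO.inspect(Nx.size(tensor), label: "number of elements")

# Elements are read in row-major order and regrouped into the new shape.
reshaped = Nx.reshape(tensor, {2, 3}, names: [:rows, :cols])
IO.inspect(Nx.to_list(reshaped), label: "reshaped to {2, 3}")

# Nx.flatten/1 is a shortcut for reshaping into a single dimension.
flat = Nx.flatten(tensor)
IO.inspect(Nx.to_list(flat), label: "flattened")
```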
Floats and Complex numbers
Besides single-precision (32-bit) floats, Nx floating-point numbers can have other precisions, such as half precision (16-bit) or double precision (64-bit):
Nx.tensor([0.0, 0.2, 0.4, 1.0], type: :f16)
Nx.tensor([0.0, 0.2, 0.4, 1.0], type: :f64)
Brain floats are also supported:
Nx.tensor([0.0, 0.2, 0.4, 1.0], type: :bf16)
Certain backends and compilers support 8-bit floats. The exact 8-bit float format may differ between backends, so be careful when transferring data across them. Nx.BinaryBackend implements F8E5M2:
Nx.tensor([1, 2, 3], type: :f8)
In all cases, the non-finite values negative infinity (-Inf), infinity (Inf),
and "not a number" (NaN) can be represented by the atoms :neg_infinity,
:infinity, and :nan, respectively:
Nx.tensor([:neg_infinity, :nan, :infinity])
Finally, tensors also support complex numbers in both 32-bit and 64-bit precision (the :c64 and :c128 types, whose real and imaginary parts are 32-bit and 64-bit floats, respectively):
Nx.tensor(Complex.new(1, -1))
See the documentation for Nx.tensor/2 for more details on the accepted options.
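You can also change the type of an existing tensor and take complex tensors apart. The sketch below assumes Nx is installed. It writes types as {kind, size} tuples, which is the form Nx.type/1 returns.

```elixir
# Assumes Nx is installed via the Mix.install call at the top of this page.

# Convert an existing integer tensor to 64-bit floats.
ints = Nx.tensor([1, 2, 3])
doubles = Nx.as_type(ints, {:f, 64})
IO.inspect(Nx.type(doubles), label: "converted type")

# A complex scalar defaults to {:c, 64}: two 32-bit float components.
z = Nx.tensor(Complex.new(1, -1))
IO.inspect(Nx.type(z), label: "complex type")

# Nx.real/1 and Nx.imag/1 extract the components as float tensors.
IO.inspect({Nx.to_number(Nx.real(z)), Nx.to_number(Nx.imag(z))}, label: "real and imaginary parts")
```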
Basic operations
Nx supports element-wise arithmetic operations for tensors and broadcasting when necessary.
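Broadcasting lets differently shaped operands combine without explicit loops. Nx aligns shapes from the rightmost axis and stretches axes of size 1. This is a minimal sketch assuming Nx is installed. The tensors are arbitrary examples.

```elixir
# Assumes Nx is installed via the Mix.install call at the top of this page.
vector = Nx.tensor([1, 2, 3])

# A plain number is broadcast against every element.
IO.inspect(Nx.to_list(Nx.multiply(vector, 10)), label: "scaled")

# A {3, 1} column and a {3} vector broadcast to a {3, 3} result:
# each entry of the column is added to every entry of the vector.
column = Nx.tensor([[10], [20], [30]])
outer = Nx.add(column, vector)
IO.inspect(Nx.shape(outer), label: "broadcast shape")
IO.inspect(Nx.to_list(outer), label: "outer sum")
```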
Addition
Nx.add/2: Adds corresponding elements of two tensors.
a = Nx.tensor([1, 2, 3])
b = Nx.tensor([0, 1, 2])
Nx.add(a, b)
Subtraction
Nx.subtract/2: Subtracts the elements of the second tensor from the first.
a = Nx.tensor([10, 20, 30])
b = Nx.tensor([0, 1, 2])
Nx.subtract(a, b)
Multiplication
Nx.multiply/2: Multiplies corresponding elements of two tensors.
a = Nx.tensor([2, 3, 4])
b = Nx.tensor([0, 1, 2])
Nx.multiply(a, b)
Division
Nx.divide/2: Divides the elements of the first tensor by the second tensor.
a = Nx.tensor([10, 30, 40])
b = Nx.tensor([5, 6, 8])
Nx.divide(a, b)
Exponentiation
Nx.pow/2: Raises each element of the first tensor to the power of the corresponding element in the second tensor.
a = Nx.tensor([2, 3, 4])
b = Nx.tensor([2])
Nx.pow(a, b)
Quotient
Nx.quotient/2: Performs element-wise integer division (as div/2 does for integers) and returns the results as a new tensor.
a = Nx.tensor([10, 20, 30])
b = Nx.tensor([3, 7, 4])
Nx.quotient(a, b)
Remainder
Nx.remainder/2: Computes the integer division remainder.
a = Nx.tensor([27, 32, 43])
b = Nx.tensor([2, 3, 4])
Nx.remainder(a, b)
Negation
Nx.negate/1: Negates each element of a tensor.
a = Nx.tensor([2, 3, 4])
Nx.negate(a)
Square Root
Nx.sqrt/1: Computes the element-wise square root.
a = Nx.tensor([4, 9, 16])
Nx.sqrt(a)
Element-Wise Comparison
The following operations return a u8 tensor, where 1 represents true and 0 represents false.
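Because comparison results are plain 0/1 tensors, they can feed other operations directly as masks. Here is a small sketch assuming Nx is installed. The inputs match the greater_equal example on this page, and the mask-and-select pattern is one illustrative use.

```elixir
# Assumes Nx is installed via the Mix.install call at the top of this page.
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4, 8, 17])

# 1 where a >= b, 0 elsewhere.
mask = Nx.greater_equal(a, b)
IO.inspect(Nx.to_list(mask), label: "mask")

# Summing a 0/1 mask counts the positions where the comparison holds.
matches = Nx.to_number(Nx.sum(mask))
IO.inspect(matches, label: "matches")

# Nx.select/3 takes from the second argument where the mask is nonzero
# and from the third argument elsewhere.
larger = Nx.select(mask, a, b)
IO.inspect(Nx.to_list(larger), label: "element-wise larger")
```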
Equality and Inequality
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4, 9, -16])
Nx.equal(a, b)
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4.0, 9.0, -16.0])
Nx.not_equal(a, b)
Greater and Less
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4, 8, 17])
Nx.greater(a, b)
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4.2, 9.0, 15.9])
Nx.less(a, b)
Greater_Equal and Less_Equal
Nx.greater_equal/2, Nx.less_equal/2
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4, 8, 17])
Nx.greater_equal(a, b)
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4.2, 9.0, 15.9])
Nx.less_equal(a, b)
Aggregate functions
These operations aggregate values across tensor axes.
See also the aggregation guide for a more in-depth exploration of the subject.
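By default, the functions below reduce every element to a single scalar. They also accept an :axes option to reduce only some axes, and named axes can be passed by name. This is a minimal sketch assuming Nx is installed. The :rows/:cols names are illustrative.

```elixir
# Assumes Nx is installed via the Mix.install call at the top of this page.
a = Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:rows, :cols])

# Without options, every element is reduced to one scalar.
total = Nx.to_number(Nx.sum(a))
IO.inspect(total, label: "total")

# :axes lists the axes to reduce away.
col_sums = Nx.to_list(Nx.sum(a, axes: [:rows]))
row_means = Nx.to_list(Nx.mean(a, axes: [:cols]))
IO.inspect(col_sums, label: "column sums")
IO.inspect(row_means, label: "row means")

# :keep_axes keeps each reduced axis with size 1, handy for broadcasting back.
kept = Nx.shape(Nx.sum(a, axes: [:cols], keep_axes: true))
IO.inspect(kept, label: "shape with keep_axes")
```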
Sum
Nx.sum/1: Sums all elements.
a = Nx.tensor([[4, 9, 16], [4.2, 9.0, 16.7]])
Nx.sum(a)
Mean
Nx.mean/1: Computes the mean value of the elements.
a = Nx.tensor([[4, 9, 16], [4.2, 9.0, 16.7]])
Nx.mean(a)
Product
Nx.product/1: Computes the product of the elements.
a = Nx.tensor([[4, 9, 16], [4.2, 9.0, 16.7]])
Nx.product(a)
Matrix Multiplication
Nx.dot/4: Computes the generalized dot product between two tensors, operating on specific contracting axes.
t1 = Nx.tensor([[1, 2], [3, 4]], names: [:x, :y])
t2 = Nx.tensor([[10, 20], [30, 40]], names: [:height, :width])
Nx.dot(t1, [0], t2, [0])
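Contracting axis 0 of both tensors sums over their first indices, which is the same as multiplying the transpose of t1 by t2. The sketch below checks that equivalence and contrasts it with Nx.dot/2. It assumes Nx is installed and reuses the tensors above.

```elixir
# Assumes Nx is installed via the Mix.install call at the top of this page.
t1 = Nx.tensor([[1, 2], [3, 4]], names: [:x, :y])
t2 = Nx.tensor([[10, 20], [30, 40]], names: [:height, :width])

# result[i][j] = sum over k of t1[k][i] * t2[k][j], i.e. transpose(t1) . t2
contracted = Nx.dot(t1, [0], t2, [0])
via_transpose = Nx.dot(Nx.transpose(t1), t2)
IO.inspect(Nx.to_list(contracted), label: "contract axis 0 with axis 0")
IO.inspect(Nx.to_list(via_transpose), label: "transpose(t1) . t2")

# Nx.dot/2 on two matrices contracts the last axis of t1 with the first
# axis of t2, which is ordinary matrix multiplication.
matmul = Nx.dot(t1, t2)
IO.inspect(Nx.to_list(matmul), label: "t1 . t2")
```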