Nx Quickstart
Prerequisites
To properly use Nx, you will need to know a bit of Elixir. For a refresher, check out the Elixir Getting Started Guide.
To work through the examples, you can run this guide using the "Run in Livebook" button on this page.
Learning Objectives
This is an overview of Nx tensors. In this section, we'll look at some of the various tools for creating and interacting with tensors.
After reading, you should understand:
- How to create 1, 2 and N-dimensional tensors in Nx;
- How to index, slice and iterate through tensors;
- Basic tensor functions;
- How to apply some linear algebra operations to n-dimensional tensors without using for-loops;
- Axis and shape properties for n-dimensional tensors.
The Basics
First, let's install Nx with Mix.install.
Mix.install([
{:nx, "~> 0.9"}
])
The IEx.Helpers module will assist our exploration of core tensor concepts.
import IEx.Helpers
Creating tensors
The argument for Nx.tensor/1 must be one of:
- a tensor;
- a number (which means the tensor is scalar/zero-dimensional);
- a boolean (also scalar/zero-dimensional);
- an arbitrarily nested list of numbers and booleans;
- the special atoms :nan, :infinity, :neg_infinity, which represent non-finite numbers that are not supported by Elixir floats.
If a new tensor is allocated, it will be allocated in the backend defined by the :backend option. If it is not provided, Nx.default_backend/0 will be used instead.
Examples
A number returns a tensor of zero dimensions, also known as a scalar:
Nx.tensor(0)
Nx.tensor(1.0)
A list returns a one-dimensional tensor, also known as a vector:
Nx.tensor([1, 2, 3])
Nx.tensor([1.2, 2.3, 3.4, 4.5])
Higher dimensional tensors are also possible:
Nx.tensor([[1, 2, 3], [4, 5, 6]])
Nx.tensor([[1, 2], [3, 4], [5, 6]])
Nx.tensor([[[1, 2], [3, 4], [5, 6]], [[-1, -2], [-3, -4], [-5, -6]]])
Tensors can also be given as inputs, which is useful for functions that don't care about the input kind:
Nx.tensor(Nx.tensor([1, 2, 3]))
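To check what each of the calls above produces, you can inspect a tensor's shape, rank and type. The following is a minimal sketch using Nx.shape/1, Nx.rank/1 and Nx.type/1; the variable names are illustrative only.

```elixir
Mix.install([{:nx, "~> 0.9"}])

# Illustrative variable names, not part of the guide
scalar = Nx.tensor(1.0)
vector = Nx.tensor([1.2, 2.3, 3.4, 4.5])
matrix = Nx.tensor([[1, 2], [3, 4], [5, 6]])
flag = Nx.tensor(true)

# A scalar has an empty shape and rank 0
IO.inspect(Nx.shape(scalar), label: "scalar shape")
IO.inspect(Nx.rank(scalar), label: "scalar rank")

# A flat list gives one axis; each level of nesting adds another axis
IO.inspect(Nx.shape(vector), label: "vector shape")
IO.inspect(Nx.shape(matrix), label: "matrix shape")

# Floats default to 32 bits; booleans are stored as unsigned 8-bit integers
IO.inspect(Nx.type(vector), label: "vector type")
IO.inspect(Nx.type(flag), label: "flag type")
```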
Naming dimensions
You can provide names for tensor dimensions. Names are atoms:
Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y])
Names make your code more expressive:
Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch, :height, :width])
We created a tensor of the shape {1, 3, 3}, with three axes named batch, height and width.
You can also leave dimension names as nil (which is the default):
Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch, nil, nil])
However, you must provide a name for every dimension in the tensor. For example, the following code snippet raises an error because 1 name is given, but there are 3 dimensions:
Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch])
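Once a tensor has names, they can be read back and used in place of positional axis indices. The sketch below assumes the named tensor from the example above (bound here to a hypothetical variable `image`) and uses Nx.names/1 and Nx.axis_size/2.

```elixir
Mix.install([{:nx, "~> 0.9"}])

# `image` is an illustrative name for the named tensor used in this section
image = Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:batch, :height, :width])

IO.inspect(Nx.shape(image), label: "shape")
IO.inspect(Nx.names(image), label: "names")

# Axis sizes can be looked up by name or by position
IO.inspect(Nx.axis_size(image, :height), label: "height")
IO.inspect(Nx.axis_size(image, 0), label: "axis 0 (batch)")
```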
Indexing and Slicing tensor values
We can get any cell of the tensor:
tensor = Nx.tensor([[1, 2], [3, 4]], names: [:y, :x])
tensor[[0, 1]]
Negative indices will start counting from the end of the axis. -1 is the last entry, -2 the second to last and so on.
tensor = Nx.tensor([[1, 2], [3, 4], [5, 6]], names: [:y, :x])
tensor[[-1, -1]]
We can also get a whole dimension:
tensor[x: 1]
or a range:
tensor[y: 0..1]
tensor[[.., 1]] will achieve the same result as tensor[x: 1]. This is because Elixir has the syntax sugar .. for a 0..-1//1 range.
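The access patterns above can be compared side by side by converting the results back to plain Elixir terms. This is a small sketch that assumes the same 3x2 named tensor as in the examples; Nx.to_number/1 and Nx.to_list/1 are used only to make the results easy to read.

```elixir
Mix.install([{:nx, "~> 0.9"}])

tensor = Nx.tensor([[1, 2], [3, 4], [5, 6]], names: [:y, :x])

# Single cell: row 0, column 1
cell = Nx.to_number(tensor[[0, 1]])

# Negative indices count from the end: last row, last column
last = Nx.to_number(tensor[[-1, -1]])

# Whole column 1, selected by axis name and by the full-range shorthand
by_name = Nx.to_list(tensor[x: 1])
by_range = Nx.to_list(tensor[[.., 1]])

# First two rows, selected with a range on the :y axis
rows = Nx.to_list(tensor[y: 0..1])

IO.inspect({cell, last, by_name, by_range, rows})
```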
Tensor shape and reshape
Nx.shape(tensor)
We can also create a new tensor with a given shape using Nx.reshape/2, as long as the new shape holds the same number of elements (six, for the tensor above):
Nx.reshape(tensor, {1, 6}, names: [:batches, :values])
This operation generally reuses all of the tensor data and simply changes the metadata, so it has no notable cost. The new tensor has the same type, but a new shape.
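Reshaping only works when the element count is preserved. The sketch below, using an unnamed 3x2 tensor as an assumption, shows an explicit target shape and the :auto placeholder, which asks Nx to infer one dimension from the others.

```elixir
Mix.install([{:nx, "~> 0.9"}])

tensor = Nx.tensor([[1, 2], [3, 4], [5, 6]])

# {3, 2} holds 6 values, and so does {1, 6}
flat = Nx.reshape(tensor, {1, 6}, names: [:batches, :values])

# :auto lets Nx infer the missing dimension (here 3)
inferred = Nx.reshape(tensor, {2, :auto})

IO.inspect(Nx.shape(flat), label: "flat shape")
IO.inspect(Nx.names(flat), label: "flat names")
IO.inspect(Nx.shape(inferred), label: "inferred shape")

# Elements keep their row-major order; only the grouping changes
IO.inspect(Nx.to_list(inferred), label: "inferred values")
```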
Floats and Complex numbers
Besides single-precision (32 bits) floats, Nx floating-point numbers can also have other kinds of precision, such as half-precision (16) or double-precision (64):
Nx.tensor([0.0, 0.2, 0.4, 1.0], type: :f16)
Nx.tensor([0.0, 0.2, 0.4, 1.0], type: :f64)
Brain floats are also supported:
Nx.tensor([0.0, 0.2, 0.4, 1.0], type: :bf16)
Certain backends and compilers support 8-bit floats. The precision implementation of 8-bit floats may change per backend, so you must be careful when transferring data across different backends. Nx.BinaryBackend implements F8E5M2:
Nx.tensor([1, 2, 3], type: :f8)
In all cases, the non-finite values negative infinity (-Inf), infinity (Inf), and "not a number" (NaN) can be represented by the atoms :neg_infinity, :infinity, and :nan, respectively:
Nx.tensor([:neg_infinity, :nan, :infinity])
Finally, complex numbers are also supported in tensors, as :c64 and :c128, which store the real and imaginary parts as 32-bit and 64-bit floats respectively:
Nx.tensor(Complex.new(1, -1))
Check out the documentation for Nx.tensor/2 for more details on the accepted options.
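To confirm which type a tensor ended up with, Nx.type/1 returns a {kind, bit_size} tuple. The following sketch reuses the values from the examples above and also shows Nx.as_type/2 for converting an existing tensor; the variable names are illustrative.

```elixir
Mix.install([{:nx, "~> 0.9"}])

values = [0.0, 0.2, 0.4, 1.0]

# Floats default to 32 bits unless a :type option is given
default_type = Nx.type(Nx.tensor(values))
half_type = Nx.type(Nx.tensor(values, type: :f16))
brain_type = Nx.type(Nx.tensor(values, type: :bf16))
double_type = Nx.type(Nx.tensor(values, type: :f64))

# For complex tensors the bit size covers the real and imaginary parts together
z = Nx.tensor(Complex.new(1, -1))
complex_type = Nx.type(z)
wide_complex_type = Nx.type(Nx.as_type(z, :c128))

# An existing tensor can be converted to another type with Nx.as_type/2
converted_type = Nx.type(Nx.as_type(Nx.tensor(values), :f64))

IO.inspect({default_type, half_type, brain_type, double_type})
IO.inspect({complex_type, wide_complex_type, converted_type})
```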
Basic operations
Nx supports element-wise arithmetic operations for tensors and broadcasting when necessary.
Addition
Nx.add/2: Adds corresponding elements of two tensors.
a = Nx.tensor([1, 2, 3])
b = Nx.tensor([0, 1, 2])
Nx.add(a, b)
Subtraction
Nx.subtract/2: Subtracts the elements of the second tensor from the first.
a = Nx.tensor([10, 20, 30])
b = Nx.tensor([0, 1, 2])
Nx.subtract(a, b)
Multiplication
Nx.multiply/2: Multiplies corresponding elements of two tensors.
a = Nx.tensor([2, 3, 4])
b = Nx.tensor([0, 1, 2])
Nx.multiply(a, b)
Division
Nx.divide/2: Divides the elements of the first tensor by the second tensor.
a = Nx.tensor([10, 30, 40])
b = Nx.tensor([5, 6, 8])
Nx.divide(a, b)
Exponentiation
Nx.pow/2: Raises each element of the first tensor to the power of the corresponding element in the second tensor.
a = Nx.tensor([2, 3, 4])
b = Nx.tensor([2])
Nx.pow(a, b)
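The exponentiation example relies on broadcasting: the exponent has shape {1}, so it is stretched to match the three elements of the base. The sketch below shows the same mechanism with a plain number and with a row added to every row of a matrix; the `grid` tensor is an illustrative addition, not part of the guide.

```elixir
Mix.install([{:nx, "~> 0.9"}])

a = Nx.tensor([2, 3, 4])

# A {1}-shaped exponent is broadcast across all elements of `a`
squared = Nx.pow(a, Nx.tensor([2]))

# A plain Elixir number is treated as a scalar tensor and broadcast the same way
also_squared = Nx.pow(a, 2)

# A {3}-shaped row is broadcast over each row of a {2, 3} matrix
grid = Nx.tensor([[1, 2, 3], [4, 5, 6]])
shifted = Nx.add(grid, Nx.tensor([10, 20, 30]))

IO.inspect(Nx.to_list(squared), label: "squared")
IO.inspect(Nx.to_list(also_squared), label: "also_squared")
IO.inspect(Nx.to_list(shifted), label: "shifted")
```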
Quotient
Nx.quotient/2: Returns a new tensor where each element is the integer division (div/2) of the corresponding elements of the two tensors.
a = Nx.tensor([10, 20, 30])
b = Nx.tensor([3, 7, 4])
Nx.quotient(a, b)
Remainder
Nx.remainder/2: Computes the integer division remainder.
a = Nx.tensor([27, 32, 43])
b = Nx.tensor([2, 3, 4])
Nx.remainder(a, b)
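Division, quotient and remainder are easy to mix up, so here is a short sketch contrasting them on the same inputs. Nx.divide/2 produces a floating-point tensor even for integer inputs, while Nx.quotient/2 and Nx.remainder/2 stay in integers and together reconstruct the dividend.

```elixir
Mix.install([{:nx, "~> 0.9"}])

a = Nx.tensor([10, 20, 30])
b = Nx.tensor([3, 7, 4])

# True division: the result is a float tensor even though inputs are integers
ratio = Nx.divide(a, b)

# Integer division and its remainder, element by element
q = Nx.quotient(a, b)
r = Nx.remainder(a, b)

# quotient * divisor + remainder gives back the original values
rebuilt = Nx.add(Nx.multiply(q, b), r)

IO.inspect(Nx.type(ratio), label: "divide type")
IO.inspect(Nx.to_list(q), label: "quotient")
IO.inspect(Nx.to_list(r), label: "remainder")
IO.inspect(Nx.to_list(rebuilt), label: "rebuilt")
```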
Negation
Nx.negate/1: Negates each element of a tensor.
a = Nx.tensor([2, 3, 4])
Nx.negate(a)
Square Root
Nx.sqrt/1: Computes the element-wise square root.
a = Nx.tensor([4, 9, 16])
Nx.sqrt(a)
Element-Wise Comparison
The following operations return a u8 tensor where 1 represents true and 0 represents false.
Equality and Inequality
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4, 9, -16])
Nx.equal(a, b)
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4.0, 9.0, -16.0])
Nx.not_equal(a, b)
Greater and Less
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4, 8, 17])
Nx.greater(a, b)
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4.2, 9.0, 15.9])
Nx.less(a, b)
Greater_Equal and Less_Equal
Nx.greater_equal/2, Nx.less_equal/2
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4, 8, 17])
Nx.greater_equal(a, b)
a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4.2, 9.0, 15.9])
Nx.less_equal(a, b)
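Because comparisons return ordinary u8 tensors of 0s and 1s, the result can be fed into other operations as a mask. The following sketch counts the matches with Nx.sum/1, checks them with Nx.any/1 and Nx.all/1, and uses Nx.select/3 to pick values from one tensor or the other; the variable names are illustrative.

```elixir
Mix.install([{:nx, "~> 0.9"}])

a = Nx.tensor([4, 9, 16])
b = Nx.tensor([4, 8, 17])

# 1 where a > b, 0 elsewhere
mask = Nx.greater(a, b)

# Summing the mask counts how many positions satisfy the comparison
count = Nx.to_number(Nx.sum(mask))

# Nx.any/1 and Nx.all/1 reduce the mask to a single 0 or 1
any_greater = Nx.to_number(Nx.any(mask))
all_greater = Nx.to_number(Nx.all(mask))

# Take from `a` where the mask is 1 and from `b` otherwise (an element-wise maximum)
larger = Nx.select(mask, a, b)

IO.inspect(Nx.to_list(mask), label: "mask")
IO.inspect({count, any_greater, all_greater}, label: "count/any/all")
IO.inspect(Nx.to_list(larger), label: "larger")
```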
Aggregate functions
These operations aggregate values across tensor axes.
See also the aggregation guide for a more in-depth exploration of the subject.
Sum
Nx.sum/1: Sums all elements.
a = Nx.tensor([[4, 9, 16], [4.2, 9.0, 16.7]])
Nx.sum(a)
Mean
Nx.mean/1: Computes the mean value of the elements.
a = Nx.tensor([[4, 9, 16], [4.2, 9.0, 16.7]])
Nx.mean(a)
Product
Nx.product/1: Computes the product of the elements.
a = Nx.tensor([[4, 9, 16], [4.2, 9.0, 16.7]])
Nx.product(a)
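The examples above reduce over every element. To aggregate along specific axes instead, these functions also accept an :axes option, given by position or by name. The sketch below uses a small integer tensor with hypothetical axis names :rows and :cols so the results are easy to verify by hand.

```elixir
Mix.install([{:nx, "~> 0.9"}])

# :rows and :cols are illustrative axis names
a = Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:rows, :cols])

# Reducing over :rows collapses the rows, leaving one value per column
column_sums = Nx.sum(a, axes: [:rows])

# Reducing over :cols collapses the columns, leaving one value per row
row_sums = Nx.sum(a, axes: [:cols])
row_means = Nx.mean(a, axes: [:cols])

# keep_axes: true keeps the reduced axis with size 1
kept = Nx.sum(a, axes: [:cols], keep_axes: true)

# Without :axes, every element is reduced into a scalar
total_product = Nx.to_number(Nx.product(a))

IO.inspect(Nx.to_list(column_sums), label: "column sums")
IO.inspect(Nx.to_list(row_sums), label: "row sums")
IO.inspect(Nx.to_list(row_means), label: "row means")
IO.inspect(Nx.shape(kept), label: "kept shape")
IO.inspect(total_product, label: "product")
```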
Matrix Multiplication
Nx.dot/4: Computes the generalized dot product between two tensors, operating on specific contracting axes.
t1 = Nx.tensor([[1, 2], [3, 4]], names: [:x, :y])
t2 = Nx.tensor([[10, 20], [30, 40]], names: [:height, :width])
Nx.dot(t1, [0], t2, [0])
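Contracting axis 0 of both tensors is not the usual row-by-column matrix product: it is equivalent to transposing the first tensor before multiplying. The sketch below, using unnamed copies of the same values as an assumption, compares Nx.dot/4 with Nx.dot/2 and Nx.transpose/1.

```elixir
Mix.install([{:nx, "~> 0.9"}])

t1 = Nx.tensor([[1, 2], [3, 4]])
t2 = Nx.tensor([[10, 20], [30, 40]])

# Contract axis 0 of t1 with axis 0 of t2
contracted = Nx.dot(t1, [0], t2, [0])

# Same result: transpose t1, then apply the ordinary matrix product
via_transpose = Nx.dot(Nx.transpose(t1), t2)

# Ordinary matrix product: last axis of t1 against first axis of t2
standard = Nx.dot(t1, t2)

IO.inspect(Nx.to_list(contracted), label: "contracted")
IO.inspect(Nx.to_list(via_transpose), label: "via transpose")
IO.inspect(Nx.to_list(standard), label: "standard")
```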