Introduction to Nx
Mix.install([
{:nx, "~> 0.5"}
])
Numerical Elixir
Elixir's primary numerical datatypes and structures are not optimized for numerical programming. Nx is a library built to bridge that gap.
Nx is a numerical computing library that integrates smoothly with typed, multidimensional data (called tensors) implemented on other platforms. This support extends to the compilers and libraries that support those tensors. Nx has three primary capabilities:
- In Nx, tensors hold typed data in multiple, named dimensions.
- Numerical definitions, known as defn, support custom code with tensor-aware operators and functions.
- Automatic differentiation, also known as autograd or autodiff, supports common computational scenarios such as machine learning, simulations, curve fitting, and probabilistic models.
Here's more about each of those capabilities. Nx tensors can hold unsigned integers (u8, u16, u32, u64), signed integers (s8, s16, s32, s64), floats (f32, f64), brain floats (bf16), and complex numbers (c64, c128). Tensors support backends implemented outside of Elixir, including Google's Accelerated Linear Algebra (XLA) and LibTorch.
Numerical definitions have compiler support for just-in-time compilation, which targets specialized processors such as GPUs and TPUs to speed up numeric computation.
To know Nx, we'll get to know tensors first. This rapid overview will touch on the major features. Future notebooks will then take a deep dive into working with tensors in detail, autograd, and backends. Finally, we'll dive into specific problem spaces such as machine learning with Axon.
Nx and tensors
Systems of equations are a central theme in numerical computing. These equations are often expressed and solved with multidimensional arrays. For example, this is a two-dimensional array:
$$ \begin{bmatrix} 1 & 2 \\\\ 3 & 4 \end{bmatrix} $$
Elixir programmers typically express a similar data structure using a list of lists, like this:
[
  [1, 2],
  [3, 4]
]
This data structure works fine within many functional programming algorithms, but breaks down with deep nesting and random access.
On top of that, Elixir numeric types lack optimization for many numerical applications. They work fine when programs need hundreds or even thousands of calculations. They tend to break down with traditional STEM applications when a typical problem needs millions of calculations.
In Nx, we express multi-dimensional data using typed tensors. Simply put, a tensor is a multi-dimensional array with a predetermined shape and type. To interact with them, Nx relies on tensor-aware operators rather than Enum.map/2 and Enum.reduce/3.
In this section, we'll look at some of the various tools for creating and interacting with tensors. The IEx helpers will assist our exploration of the core tensor concepts.
import IEx.Helpers
Now, everything is set up, so we're ready to create some tensors.
Creating tensors
Start out by getting a feel for Nx through its documentation. Do so through the IEx helpers, like this:
h Nx
Immediately, you can see that tensors are at the center of the API. The main API for creating tensors is Nx.tensor/2:
h Nx.tensor
We use it to create tensors from raw Elixir lists of numbers, like this:
tensor =
  1..4
  |> Enum.chunk_every(2)
  |> Nx.tensor(names: [:y, :x])
The result shows all of the major fields that make up a tensor:
- The data, presented as the list of lists [[1, 2], [3, 4]].
- The type of the tensor, a signed integer 64 bits long, with the type s64.
- The shape of the tensor, going left to right, with the outside dimensions listed first.
- The names of each dimension.
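As a small sketch, each of those fields can also be read back with an accessor function. The variable name t is ours, and the type is set explicitly here on the assumption that the default integer type may differ between Nx versions.

```elixir
# Assumes :nx is installed by the setup cell at the top of this notebook.
t = Nx.tensor([[1, 2], [3, 4]], names: [:y, :x], type: :s64)

Nx.type(t)          # the element type, as a {kind, bits} tuple
Nx.shape(t)         # the size of each dimension, outermost first
Nx.names(t)         # the dimension names
Nx.to_flat_list(t)  # the data as a flat Elixir list
```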
We can easily convert it to a binary:
binary = Nx.to_binary(tensor)
A tensor of type s64 uses eight bytes for each integer. The binary shows the individual bytes that make up the tensor, so you can see the integers 1..4 interspersed among the zeros that make up the tensor. If all of our data consists of integers in the range 0..255, we could save space with a different type:
Nx.tensor([[1, 2], [3, 4]], type: :u8) |> Nx.to_binary()
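To make the saving concrete, here is a minimal sketch that compares the size of the two binaries. The variable names wide and narrow are ours.

```elixir
# Assumes :nx is installed by the setup cell at the top of this notebook.
wide = Nx.tensor([[1, 2], [3, 4]], type: :s64)
narrow = Nx.tensor([[1, 2], [3, 4]], type: :u8)

byte_size(Nx.to_binary(wide))    # 4 elements * 8 bytes = 32
byte_size(Nx.to_binary(narrow))  # 4 elements * 1 byte = 4
```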
If you already have a binary, you can directly convert it to a tensor by passing the binary and the type:
Nx.from_binary(<<0, 1, 2>>, :u8)
This function comes in handy when working with published datasets because they must often be processed. Elixir binaries make quick work of dealing with numerical data structured for platforms other than Elixir.
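Nx.from_binary/2 returns a flat, one-dimensional tensor, so a common follow-up step is to give the data a shape. The sketch below uses a made-up six-byte binary (raw) standing in for data read from a file.

```elixir
# Assumes :nx is installed by the setup cell at the top of this notebook.
# `raw` is a hypothetical stand-in for bytes read from a dataset.
raw = <<1, 2, 3, 4, 5, 6>>

grid =
  raw
  |> Nx.from_binary(:u8)
  |> Nx.reshape({2, 3})

Nx.shape(grid)         # two rows of three values
Nx.to_flat_list(grid)  # the same six values, in the original order
```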
We can get any cell of the tensor:
tensor[0][1]
Now, try getting the first row of the tensor:
# ...your code here...
We can also get a whole dimension:
tensor[x: 1]
or a range:
tensor[y: 0..1]
Now,
- create your own {3, 3} tensor with named dimensions
- return a {2, 2} tensor containing the first two columns of the first two rows
We can get information about this most recent term with the IEx helper i, like this:
i tensor
The tensor is a struct that supports the usual Inspect protocol. The struct has keys, but we typically treat the Nx.Tensor as an opaque data type (meaning we typically access the contents and shape of a tensor using the tensor's API instead of the struct).
Primarily, a tensor is a struct, and the functions to access it go through a specific backend. We'll get to the backend details in a moment. For now, use the IEx h helper to get more documentation about tensors. We could also open a Code cell, type Nx.tensor, and hover the cursor over the word tensor to see the help about that function.
We can get the shape of the tensor with Nx.shape/1:
Nx.shape(tensor)
We can also create a new tensor with a new shape using Nx.reshape/2:
Nx.reshape(tensor, {1, 4}, names: [:batches, :values])
This operation reuses all of the tensor data and simply changes the metadata, so it has no notable cost.
The new tensor has the same type, but a new shape.
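A quick way to convince yourself of this is to flatten both tensors and compare them. This is only an illustrative sketch, and the names t and flat are ours.

```elixir
# Assumes :nx is installed by the setup cell at the top of this notebook.
t = Nx.tensor([[1, 2], [3, 4]])
flat = Nx.reshape(t, {4})

Nx.shape(flat)                               # {4}
Nx.to_flat_list(flat) == Nx.to_flat_list(t)  # same values in the same order
Nx.type(flat) == Nx.type(t)                  # same element type
```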
Now, reshape the tensor to contain three dimensions with one batch, one row, and four columns.
# ...your code here...
We can create a tensor with named dimensions, a type, a shape, and our target data. A dimension is called an axis, and axes can have names. We can specify the tensor type and dimension names with options, like this:
Nx.tensor([[1, 2, 3]], names: [:rows, :cols], type: :u8)
We created a tensor of the shape {1, 3}, with the type u8, the values [1, 2, 3], and two axes named rows and cols.
Now we know how to create tensors, so it's time to do something with them.
Tensor aware functions
In the last section, we created an s64[2][2] tensor. In this section, we'll use Nx functions to work with it. Here's the value of tensor:
tensor
We can use IEx.Helpers.exports/1 or code completion to find some functions in the Nx module that operate on tensors:
exports Nx
You might recognize that many of those functions have names that suggest that they would work on primitive values, called scalars. Indeed, a tensor can be a scalar:
pi = Nx.tensor(3.1415, type: :f32)
Take the cosine:
Nx.cos(pi)
That function took the cosine of pi. We can also call such functions on a whole tensor, like this:
Nx.cos(tensor)
We can also call a function that aggregates the contents of a tensor. For example, to get a sum of the numbers in tensor, we can do this:
Nx.sum(tensor)
That's 1 + 2 + 3 + 4, and Nx summed across both dimensions to get that result.
To get the sum of values along the x axis instead, we'd do this:
Nx.sum(tensor, axes: [:x])
Nx sums the values across the x dimension: 1 + 2 in the first row and 3 + 4 in the second row.
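For comparison, here is a short sketch that reduces the same {2, 2} data along each axis in turn. The tensor is rebuilt under the name t so that the cell runs on its own.

```elixir
# Assumes :nx is installed by the setup cell at the top of this notebook.
t = Nx.tensor([[1, 2], [3, 4]], names: [:y, :x])

Nx.sum(t, axes: [:x]) |> Nx.to_flat_list()  # row sums: [1 + 2, 3 + 4]
Nx.sum(t, axes: [:y]) |> Nx.to_flat_list()  # column sums: [1 + 3, 2 + 4]
Nx.sum(t) |> Nx.to_number()                 # everything: 1 + 2 + 3 + 4
```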
Now,
- create a {2, 2, 2} tensor
- with the values 1..8
- with dimension names [:z, :y, :x]
- calculate the sums along the y axis
# ...your code here...
Sometimes, we need to combine two tensors together with an operator. Let's say we wanted to subtract one tensor from another. Mathematically, the expression looks like this:
$$ \begin{bmatrix} 5 & 6 \\\\ 7 & 8 \end{bmatrix} - \begin{bmatrix} 1 & 2 \\\\ 3 & 4 \end{bmatrix} = \begin{bmatrix} 4 & 4 \\\\ 4 & 4 \end{bmatrix} $$
To solve this problem, subtract each right-hand integer from the corresponding left-hand integer. Unfortunately, we cannot use Elixir's built-in subtraction operator as it is not tensor-aware. Luckily, we can use the Nx.subtract/2 function to solve the problem:
tensor2 = Nx.tensor([[5, 6], [7, 8]])
Nx.subtract(tensor2, tensor)
We get a {2, 2} shaped tensor full of fours, exactly as we expected.
When calling Nx.subtract/2, both operands had the same shape. Sometimes, you might want to apply an operation to tensors whose shapes don't match. To solve this problem, Nx takes advantage of a concept called broadcasting.
Broadcasts
Often, the dimensions of tensors in an operator don't match. For example, you might want to subtract a 1 from every element of a {2, 2} tensor, like this:
$$ \begin{bmatrix} 1 & 2 \\\\ 3 & 4 \end{bmatrix} - 1 = \begin{bmatrix} 0 & 1 \\\\ 2 & 3 \end{bmatrix} $$
Mathematically, it's the same as this:
$$ \begin{bmatrix} 1 & 2 \\\\ 3 & 4 \end{bmatrix} - \begin{bmatrix} 1 & 1 \\\\ 1 & 1 \end{bmatrix} = \begin{bmatrix} 0 & 1 \\\\ 2 & 3 \end{bmatrix} $$
That means we need a way to convert 1 to a {2, 2} tensor. Nx.broadcast/2 solves that problem. This function takes a tensor or a scalar and a shape.
Nx.broadcast(1, {2, 2})
This broadcast takes the scalar 1 and translates it to a compatible shape by copying it. Sometimes, it's easier to provide a tensor as the second argument, and let broadcast/2 extract its shape:
Nx.broadcast(1, tensor)
The code broadcasts 1 to the shape of tensor. In many operators and functions, the broadcast happens automatically:
Nx.subtract(tensor, 1)
This result is possible because Nx broadcasts both arguments of subtract/2 to compatible shapes. That means you can provide scalar values as either argument:
Nx.subtract(10, tensor)
Or subtract a row or column. Mathematically, it would look like this:
$$ \begin{bmatrix} 1 & 2 \\\\ 3 & 4 \end{bmatrix} - \begin{bmatrix} 1 & 2 \end{bmatrix} = \begin{bmatrix} 0 & 0 \\\\ 2 & 2 \end{bmatrix} $$
which is the same as this:
$$ \begin{bmatrix} 1 & 2 \\\\ 3 & 4 \end{bmatrix} - \begin{bmatrix} 1 & 2 \\\\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 0 & 0 \\\\ 2 & 2 \end{bmatrix} $$
This rewrite happens in Nx too, also through a broadcast. We want to broadcast the tensor [1, 2] to match the {2, 2} shape, like this:
Nx.broadcast(Nx.tensor([1, 2]), {2, 2})
The subtract function in Nx takes care of that broadcast implicitly, as before:
Nx.subtract(tensor, Nx.tensor([1, 2]))
The broadcast worked as advertised, copying the [1, 2] row enough times to fill a {2, 2} tensor. A tensor with a dimension of size 1 will broadcast along that dimension to fill the target shape:
[[1], [2]] |> Nx.tensor() |> Nx.broadcast({1, 2, 2})
[[[1, 2, 3]]]
|> Nx.tensor()
|> Nx.broadcast({4, 2, 3})
Both of these examples copy parts of the tensor enough times to fill out the broadcast shape. You can check out the Nx broadcasting documentation for more details:
h Nx.broadcast
Much of the time, you won't have to broadcast yourself. Many of the functions and operators Nx supports will do so automatically.
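The earlier example subtracted a row. As a sketch of the column case, the {2, 1} tensor below (named column here, our choice) is copied across the x dimension, both explicitly and implicitly.

```elixir
# Assumes :nx is installed by the setup cell at the top of this notebook.
t = Nx.tensor([[1, 2], [3, 4]])
column = Nx.tensor([[1], [2]])  # shape {2, 1}

# Explicit broadcast: each row repeats its single value.
Nx.broadcast(column, {2, 2}) |> Nx.to_flat_list()  # [1, 1, 2, 2]

# Implicit broadcast inside subtract/2.
Nx.subtract(t, column) |> Nx.to_flat_list()        # [0, 1, 1, 2]
```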
We can use tensor-aware operators via various Nx functions and many of them implicitly broadcast tensors.
Throughout this section, we have been invoking Nx.subtract/2 and our code would be more expressive if we could use its equivalent mathematical operator. Fortunately, Nx provides a way. Next, we'll dive into numerical definitions using defn.
Numerical definitions (defn)
The defn macro simplifies the expression of mathematical formulas containing tensors. Numerical definitions have two primary benefits over classic Elixir functions.
- They are tensor-aware. Nx replaces operators like Kernel.-/2 with the Defn counterparts, which in turn use Nx functions optimized for tensors, so the formulas we express can use tensors out of the box.
- defn definitions allow for building a computation graph of all the individual operations and using a just-in-time (JIT) compiler to emit highly specialized native code for the desired computation unit.
We don't have to do anything special to get tensor awareness beyond importing Nx.Defn and writing our code within a defn block.
To use Nx in a Mix project or a notebook, we need to include the :nx dependency and import the Nx.Defn module. The dependency is already included, so import it in a Code cell, like this:
import Nx.Defn
Just as the Elixir language supports def, defmacro, and defp, Nx supports defn. There are a few restrictions. It accepts only numerical arguments and return values, in the form of primitives or tensors, and supports only a subset of the language.
The subset of Elixir allowed within defn is quite broad, though. We can use macros, pipes, and even conditionals, so we're not giving up much when we're declaring mathematical functions.
Additionally, despite these small concessions, defn provides huge benefits. Code in a defn block uses tensor-aware operators and types, so the math beneath your functions has a better chance to shine through. Numerical definitions can also run on accelerated numerical processors like GPUs and TPUs. Here's an example numerical definition:
defmodule TensorMath do
  import Nx.Defn

  defn subtract(a, b) do
    a - b
  end
end
This module has a numerical definition that will be compiled. If we wanted to specify a compiler for this module, we could add a module attribute before the defn clause. One such compiler is EXLA. You'd add the Mix dependency for EXLA and do this:
@defn_compiler EXLA
defn subtract(a, b) do
  a - b
end
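To see the tensor-aware operators at work without any compiler configuration, here is a minimal sketch. The module DefnSketch and its center function are hypothetical names of ours, not part of Nx. The function subtracts the mean from every element, relying on the same implicit broadcasting described earlier.

```elixir
# Assumes :nx is installed by the setup cell at the top of this notebook.
defmodule DefnSketch do
  import Nx.Defn

  # Inside defn, - is the tensor-aware operator, and the scalar mean
  # is broadcast across every element of t.
  defn center(t) do
    t - Nx.mean(t)
  end
end

centered = DefnSketch.center(Nx.tensor([1.0, 2.0, 3.0]))
Nx.to_flat_list(centered)  # values shifted so that their mean is zero
```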
Now, it's your turn. Add a defn to TensorMath that accepts two tensors representing the lengths of sides of a right triangle and uses the Pythagorean theorem to return the length of the hypotenuse. Add your function directly to the previous Code cell.
The last major feature we'll cover is automatic differentiation, or autograd.
Automatic differentiation (autograd)
An important mathematical property of a function is its rate of change, or gradient. Gradients are critical for solving systems of equations and building probabilistic models. In calculus, we find the gradient of a function by taking its derivative. Nx can compute these derivatives automatically through a feature called automatic differentiation, or autograd.
Here's how it works.
h Nx.Defn.grad
We'll build a module with a few functions, and then create another function to create the gradients of those functions. The function grad/1 takes a function and returns a new function that computes its gradient. We have two functions: poly/1 is a simple numerical definition, and poly_slope_at/1 returns its gradient:
$$ poly: f(x) = 3x^2 + 2x + 1 $$
$$ polySlopeAt: g(x) = 6x + 2 $$
Here's the Elixir equivalent of those functions:
defmodule Funs do
  import Nx.Defn

  defn poly(x) do
    3 * Nx.pow(x, 2) + 2 * x + 1
  end

  defn poly_slope_at(x) do
    grad(&poly/1).(x)
  end
end
Notice the second defn. It uses grad/1 to take the derivative of poly/1 using autograd. It uses the intermediate defn AST and mathematical composition to compute the derivative. You can see it at work here:
Funs.poly_slope_at(2)
Nice. If you plug the number 2 into the function $6x + 2$, you get 14! Said another way, if you look at the graph at exactly 2, the rate of increase is 14 units of poly(x) for every unit of x, precisely at x = 2.
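One way to sanity-check an automatic gradient is to compare it with a numerical estimate. The sketch below is our own illustration, not an Nx feature: the module name GradCheckSketch and the step size h are assumptions, and the module repeats the two definitions so that the cell runs on its own.

```elixir
# Assumes :nx is installed by the setup cell at the top of this notebook.
defmodule GradCheckSketch do
  import Nx.Defn

  defn poly(x) do
    3 * Nx.pow(x, 2) + 2 * x + 1
  end

  defn poly_slope_at(x) do
    grad(&poly/1).(x)
  end
end

x = 2.0
h = 1.0e-2

# Gradient computed by autograd.
auto = GradCheckSketch.poly_slope_at(Nx.tensor(x)) |> Nx.to_number()

# Central difference estimate: (f(x + h) - f(x - h)) / (2 * h).
above = GradCheckSketch.poly(Nx.tensor(x + h)) |> Nx.to_number()
below = GradCheckSketch.poly(Nx.tensor(x - h)) |> Nx.to_number()
numeric = (above - below) / (2 * h)

# Both values should be close to 6 * 2 + 2 = 14.
abs(auto - numeric) < 0.01
```

The two numbers will not match to the last digit, because the numerical estimate is computed with 32-bit floats and a finite step, while autograd differentiates the expression itself.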
Nx also has helpers to get gradients corresponding to a number of inputs. These come into play when solving systems of equations.
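As a hedged sketch of one such helper, Nx.Defn.grad/2 accepts a tuple of variables and returns a tuple of gradients, one per input. The module MultiGradSketch and the function f below are hypothetical examples of ours.

```elixir
# Assumes :nx is installed by the setup cell at the top of this notebook.
defmodule MultiGradSketch do
  import Nx.Defn

  # f(a, b) = a^2 * b, so df/da = 2ab and df/db = a^2.
  defn f(a, b) do
    a * a * b
  end

  # Differentiate f with respect to both inputs at once.
  defn grads(a, b) do
    grad({a, b}, fn {a, b} -> f(a, b) end)
  end
end

{da, db} = MultiGradSketch.grads(Nx.tensor(3.0), Nx.tensor(2.0))

Nx.to_number(da)  # 2 * 3 * 2 = 12
Nx.to_number(db)  # 3 * 3 = 9
```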
Now, you try. Write a function that computes the gradient of a sine wave.
# your code here