Introduction to Nx

Mix.install([{:nx, "~> 0.5"}])

Numerical Elixir

Elixir's primary numerical datatypes and structures are not optimized for numerical programming. Nx is a library built to bridge that gap.

Elixir Nx is a numerical computing library that integrates smoothly with typed, multidimensional data (called tensors) implemented on other platforms. This support extends to the compilers and libraries that work with those tensors. Nx has three primary capabilities:

  • In Nx, tensors hold typed data in multiple, named dimensions.
  • Numerical definitions, known as defn, support custom code with tensor-aware operators and functions.
  • Automatic differentiation, also known as autograd or autodiff, supports common computational scenarios such as machine learning, simulations, curve fitting, and probabilistic models.

Here's more about each of those capabilities. Nx tensors can hold unsigned integers (u8, u16, u32, u64), signed integers (s8, s16, s32, s64), floats (f32, f64), brain floats (bf16), and complex (c64, c128). Tensors support backends implemented outside of Elixir, including Google's Accelerated Linear Algebra (XLA) and LibTorch.

Numerical definitions have compiler support for just-in-time compilation, which lets them run on specialized processors such as GPUs and TPUs to speed up numerical computation.

To get to know Nx, we'll start with tensors. This rapid overview touches on the major capabilities. Future notebooks will take a deep dive into tensors, autograd, and backends, and then into specific problem spaces like Axon, the machine learning library.

Nx and tensors

Systems of equations are a central theme in numerical computing. These equations are often expressed and solved with multidimensional arrays. For example, this is a two-dimensional array:

$$ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} $$

Elixir programmers typically express a similar data structure using a list of lists, like this:

[
  [1, 2],
  [3, 4]
]

This data structure works fine within many functional programming algorithms, but breaks down with deep nesting and random access.
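
To make the random-access point concrete, here is a small plain-Elixir sketch (the name matrix is illustrative). Elixir lists are singly linked, so Enum.at/2 walks each list from its head, and every lookup costs time proportional to the index being read:

```elixir
# A 2x2 "matrix" stored as a list of lists.
matrix = [[1, 2], [3, 4]]

# Reading row 1, column 0 takes two linear walks: one over the rows,
# then one over the selected row.
element = matrix |> Enum.at(1) |> Enum.at(0)
IO.inspect(element)
```

With deeper nesting, each extra level adds another linear walk.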

On top of that, Elixir numeric types lack optimization for many numerical applications. They work fine when programs need hundreds or even thousands of calculations. They tend to break down with traditional STEM applications when a typical problem needs millions of calculations.

In Nx, we express multi-dimensional data using typed tensors. Simply put, a tensor is a multi-dimensional array with a predetermined shape and type. To interact with them, Nx relies on tensor-aware operators rather than Enum functions such as Enum.map/2 and Enum.reduce/3.

In this section, we'll look at some of the various tools for creating and interacting with tensors. The IEx helpers will assist our exploration of the core tensor concepts.

import IEx.Helpers

Now, everything is set up, so we're ready to create some tensors.

Creating tensors

Start by getting a feel for Nx through its documentation, using the IEx helpers like this:

h Nx

Immediately, you can see that tensors are at the center of the API. The main API for creating tensors is Nx.tensor/2:

h Nx.tensor

We use it to create tensors from raw Elixir lists of numbers, like this:

tensor =
  1..4
  |> Enum.chunk_every(2)
  |> Nx.tensor(names: [:y, :x])

The result shows all of the major fields that make up a tensor:

  • The data, presented as the list of lists [[1, 2], [3, 4]].
  • The type of the tensor, a signed integer 64 bits long, with the type s64.
  • The shape of the tensor, going left to right, with the outside dimensions listed first.
  • The names of each dimension.
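
You can also read each of those fields through the Nx API instead of the struct. This is a minimal sketch that assumes Nx ~> 0.5, where integer tensors default to the s64 type:

```elixir
# Only needed when running this sketch as a standalone script.
Mix.install([{:nx, "~> 0.5"}])

tensor =
  1..4
  |> Enum.chunk_every(2)
  |> Nx.tensor(names: [:y, :x])

# Read each field through the public Nx API rather than the struct.
data = Nx.to_list(tensor)   # the data as nested lists
type = Nx.type(tensor)      # {:s, 64}
shape = Nx.shape(tensor)    # {2, 2}, outermost dimension first
names = Nx.names(tensor)    # [:y, :x]
```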

We can easily convert it to a binary:

binary = Nx.to_binary(tensor)

A tensor of type s64 uses eight bytes for each integer. The binary shows the individual bytes of the tensor, so you can see the integers 1..4 interspersed among the zero bytes that fill out each eight-byte value. If all of our data fits in the range 0..255, we could save space with a smaller type such as u8:

Nx.tensor([[1, 2], [3, 4]], type: :u8) |> Nx.to_binary()
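
To check the bytes-per-element arithmetic, the following sketch compares the binary sizes of the same 2x2 data stored as s64 and as u8. The expected sizes are 4 elements times 8 bytes, or 32 bytes, versus 4 elements times 1 byte, or 4 bytes:

```elixir
# Only needed when running this sketch as a standalone script.
Mix.install([{:nx, "~> 0.5"}])

data = [[1, 2], [3, 4]]

# Kernel.byte_size/1 measures the raw binary produced by Nx.to_binary/1.
s64_bytes = data |> Nx.tensor(type: :s64) |> Nx.to_binary() |> byte_size()
u8_bytes = data |> Nx.tensor(type: :u8) |> Nx.to_binary() |> byte_size()

{s64_bytes, u8_bytes}
```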

If you already have a binary, you can directly convert it to a tensor by passing the binary and the type:

Nx.from_binary(<<0, 1, 2>>, :u8)

This function comes in handy when working with published datasets, which often need to be processed before use. Elixir binaries make quick work of numerical data structured for platforms other than Elixir.
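
Raw bytes usually arrive as a flat sequence. A common first step is to reinterpret them with Nx.from_binary/2 and then restore the intended shape with Nx.reshape/2. This sketch assumes a hypothetical four-byte payload that holds a 2x2 matrix of u8 values in row-major order:

```elixir
# Only needed when running this sketch as a standalone script.
Mix.install([{:nx, "~> 0.5"}])

# Hypothetical raw payload: four unsigned bytes, row-major.
raw = <<1, 2, 3, 4>>

matrix =
  raw
  |> Nx.from_binary(:u8)   # flat tensor of shape {4}
  |> Nx.reshape({2, 2})    # reinterpret as 2 rows x 2 columns

Nx.to_list(matrix)
```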

We can get any cell of the tensor:

tensor[0][1]

Now, try getting the first row of the tensor:

# ...your code here...

We can also index by dimension name. This returns the second column, that is, every element whose x index is 1:

tensor[x: 1]

or a range:

tensor[y: 0..1]
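
Integer indexes and ranges behave differently. An integer selects one position and removes that axis, while a range keeps the axis. The following short sketch shows the difference on the same tensor:

```elixir
# Only needed when running this sketch as a standalone script.
Mix.install([{:nx, "~> 0.5"}])

tensor = Nx.tensor([[1, 2], [3, 4]], names: [:y, :x])

# Integer index on :x drops that axis: the second column, shape {2}.
column = tensor[x: 1]

# A one-element range on :y keeps the axis: shape {1, 2}.
second_row_block = tensor[y: 1..1]

{Nx.to_list(column), Nx.shape(second_row_block)}
```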


Now, try to:

  • create your own {3, 3} tensor with named dimensions
  • return a {2, 2} tensor containing the first two columns of the first two rows

We can get information about this most recent term with the IEx helper i, like this:

i tensor

The tensor is a struct that implements the Inspect protocol. Although the struct has keys, we typically treat Nx.Tensor as an opaque data type, accessing a tensor's contents and shape through the Nx API instead of the struct fields.

Primarily, a tensor is a struct, and the functions that access it go through a specific backend. We'll cover backends in more detail later. For now, use the IEx h helper to get more documentation about tensors. We could also open a Code cell, type Nx.tensor, and hover the cursor over the word tensor to see the help for that function.

We can get the shape of the tensor with Nx.shape/1:

Nx.shape(tensor)

We can also create a new tensor with a new shape using Nx.reshape/2:

Nx.reshape(tensor, {1, 4}, names: [:batches, :values])

This operation reuses all of the tensor data and simply changes the metadata, so it has no notable cost.

The new tensor has the same type, but a new shape.
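
Reshaping only changes the metadata, so the underlying binary should be identical before and after. A quick sketch that checks this:

```elixir
# Only needed when running this sketch as a standalone script.
Mix.install([{:nx, "~> 0.5"}])

tensor = Nx.tensor([[1, 2], [3, 4]], names: [:y, :x])
reshaped = Nx.reshape(tensor, {1, 4}, names: [:batches, :values])

# Same bytes in the same order; only the shape and names differ.
same_data? = Nx.to_binary(reshaped) == Nx.to_binary(tensor)

{same_data?, Nx.shape(reshaped), Nx.type(reshaped)}
```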

Now, reshape the tensor to contain three dimensions with one batch, one row, and four columns.

# ...your code here...

We can create a tensor with named dimensions, a type, a shape, and our target data. A dimension is called an axis, and axes can have names. We can specify the tensor type and dimension names with options, like this:

Nx.tensor([[1, 2, 3]], names: [:rows, :cols], type: :u8)

We created a tensor of the shape {1, 3}, with the type u8, the values [1, 2, 3], and two axes named rows and cols.

Now we know how to create tensors, so it's time to do something with them.

Tensor-aware functions

In the last section, we created an s64[2][2] tensor. In this section, we'll use Nx functions to work with it. Here's the value of tensor:

tensor

We can use IEx.Helpers.exports/1 or code completion to find some functions in the Nx module that operate on tensors:

exports Nx

You might recognize that many of those functions have names that suggest that they would work on primitive values, called scalars. Indeed, a tensor can be a scalar:

pi = Nx.tensor(3.1415, type: :f32)

Take the cosine:

Nx.cos(pi)

That function took the cosine of pi. We can also call such functions on a whole tensor, where they apply to each element:

Nx.cos(tensor)

We can also call a function that aggregates the contents of a tensor. For example, to get a sum of the numbers in tensor, we can do this:

Nx.sum(tensor)

That's 1 + 2 + 3 + 4: by default, Nx sums across all dimensions. To get the sum of values along the x axis instead, we'd do this:

Nx.sum(tensor, axes: [:x])

Nx sums the values across the x dimension: 1 + 2 in the first row and 3 + 4 in the second row, giving [3, 7].
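
The choice of axis decides which values get added together. This sketch reduces the same tensor along each named axis. It also uses the keep_axes option, which keeps the reduced axis with size 1:

```elixir
# Only needed when running this sketch as a standalone script.
Mix.install([{:nx, "~> 0.5"}])

tensor = Nx.tensor([[1, 2], [3, 4]], names: [:y, :x])

# Collapse :y (add down each column): [1 + 3, 2 + 4]
sum_y = Nx.sum(tensor, axes: [:y])

# Collapse :x (add across each row): [1 + 2, 3 + 4]
sum_x = Nx.sum(tensor, axes: [:x])

# keep_axes: true leaves :x in place with size 1, giving shape {2, 1}.
sum_x_kept = Nx.sum(tensor, axes: [:x], keep_axes: true)

{Nx.to_list(sum_y), Nx.to_list(sum_x), Nx.shape(sum_x_kept)}
```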


Now, try to:

  • create a {2, 2, 2} tensor
  • with the values 1..8
  • with dimension names [:z, :y, :x]
  • calculate the sums along the y axis

# ...your code here...

Sometimes, we need to combine two tensors together with an operator. Let's say we wanted to subtract one tensor from another. Mathematically, the expression looks like this:

$$ \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} - \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} = \begin{bmatrix} 4 & 4 \\ 4 & 4 \end{bmatrix} $$

To solve this problem, subtract each right-hand integer from the corresponding left-hand integer. Unfortunately, we cannot use Elixir's built-in subtraction operator as it is not tensor-aware. Luckily, we can use the Nx.subtract/2 function to solve the problem:

tensor2 = Nx.tensor([[5, 6], [7, 8]])
Nx.subtract(tensor2, tensor)

We get a {2, 2} shaped tensor full of fours, exactly as we expected. When calling Nx.subtract/2, both operands had the same shape. Sometimes, you might want to combine tensors whose shapes don't match. To handle this case, Nx takes advantage of a concept called broadcasting.


Broadcasting

Often, the shapes of the tensors passed to an operator don't match. For example, you might want to subtract 1 from every element of a {2, 2} tensor, like this:

$$ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} - 1 = \begin{bmatrix} 0 & 1 \\ 2 & 3 \end{bmatrix} $$

Mathematically, it's the same as this:

$$ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} - \begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix} = \begin{bmatrix} 0 & 1 \\ 2 & 3 \end{bmatrix} $$

That means we need a way to convert 1 to a {2, 2} tensor. Nx.broadcast/2 solves that problem. This function takes a tensor or a scalar and a shape.

Nx.broadcast(1, {2, 2})

This broadcast takes the scalar 1 and translates it to a compatible shape by copying it. Sometimes, it's easier to provide a tensor as the second argument, and let broadcast/2 extract its shape:

Nx.broadcast(1, tensor)

The code broadcasts 1 to the shape of tensor. In many operators and functions, the broadcast happens automatically:

Nx.subtract(tensor, 1)

This result is possible because Nx broadcasts both tensors in subtract/2 to compatible shapes. That means you can provide scalar values as either argument:

Nx.subtract(10, tensor)

Or subtract a row or column. Mathematically, it would look like this:

$$ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} - \begin{bmatrix} 1 & 2 \end{bmatrix} = \begin{bmatrix} 0 & 0 \\ 2 & 2 \end{bmatrix} $$

which is the same as this:

$$ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} - \begin{bmatrix} 1 & 2 \\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 0 & 0 \\ 2 & 2 \end{bmatrix} $$

This rewrite happens in Nx too, also through a broadcast. We want to broadcast the tensor [1, 2] to match the {2, 2} shape, like this:

Nx.broadcast(Nx.tensor([1, 2]), {2, 2})

The subtract function in Nx takes care of that broadcast implicitly, as before:

Nx.subtract(tensor, Nx.tensor([1, 2]))

The broadcast worked as advertised, copying the [1, 2] row enough times to fill a {2, 2} tensor. More generally, a dimension of size 1 is copied to fill the corresponding dimension of the target shape, and missing leading dimensions are added:

[[1], [2]] |> Nx.tensor() |> Nx.broadcast({1, 2, 2})
[[[1, 2, 3]]]
|> Nx.tensor()
|> Nx.broadcast({4, 2, 3})

Both of these examples copy parts of the tensor enough times to fill out the broadcast shape. You can check out the Nx broadcasting documentation for more details:

h Nx.broadcast
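
To see the rule in action, this sketch checks the results of the two examples above. It also tries one shape that cannot be broadcast, because its last dimension is neither equal to the target's nor 1. The sketch assumes Nx raises ArgumentError in that case and does not rely on the exact message:

```elixir
# Only needed when running this sketch as a standalone script.
Mix.install([{:nx, "~> 0.5"}])

# {2, 1} -> {1, 2, 2}: the size-1 axis is copied and a leading axis is added.
a = [[1], [2]] |> Nx.tensor() |> Nx.broadcast({1, 2, 2})

# {1, 1, 3} -> {4, 2, 3}: both size-1 axes are copied.
b = [[[1, 2, 3]]] |> Nx.tensor() |> Nx.broadcast({4, 2, 3})

# {3} cannot become {2, 2}: the last dimension is 3, neither 2 nor 1.
incompatible =
  try do
    Nx.broadcast(Nx.tensor([1, 2, 3]), {2, 2})
  rescue
    ArgumentError -> :cannot_broadcast
  end

{Nx.to_list(a), Nx.shape(b), incompatible}
```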

Much of the time, you won't have to broadcast yourself. Many of the functions and operators Nx supports will do so automatically.

We can use tensor-aware operators via various Nx functions and many of them implicitly broadcast tensors.
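
For instance, this sketch confirms that the implicit broadcast in Nx.subtract/2 gives the same result as broadcasting by hand first. It also repeats the row case from earlier:

```elixir
# Only needed when running this sketch as a standalone script.
Mix.install([{:nx, "~> 0.5"}])

tensor = Nx.tensor([[1, 2], [3, 4]])

# Explicit: expand the scalar to the shape of tensor, then subtract element-wise.
explicit = Nx.subtract(tensor, Nx.broadcast(1, tensor))

# Implicit: Nx.subtract/2 broadcasts the scalar itself.
implicit = Nx.subtract(tensor, 1)

# The row [1, 2] is copied across both rows before subtracting.
row_result = Nx.subtract(tensor, Nx.tensor([1, 2]))

{Nx.to_list(explicit) == Nx.to_list(implicit), Nx.to_list(row_result)}
```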

Throughout this section, we have been invoking Nx.subtract/2, and our code would be more expressive if we could use its equivalent mathematical operator. Fortunately, Nx provides a way. Next, we'll dive into numerical definitions using defn.

Numerical definitions (defn)

The defn macro simplifies the expression of mathematical formulas containing tensors. Numerical definitions have two primary benefits over classic Elixir functions.

  • They are tensor-aware. Nx replaces operators like Kernel.-/2 with their Defn counterparts, which in turn use Nx functions optimized for tensors, so the formulas we express can use tensors out of the box.

  • defn definitions allow Nx to build a computation graph of all the individual operations and use a just-in-time (JIT) compiler to emit highly specialized native code for the desired computation unit.

We don't have to do anything special to get tensor awareness beyond importing Nx.Defn and writing our code within a defn block.

To use Nx in a Mix project or a notebook, we need to include the :nx dependency and import the Nx.Defn module. The dependency is already included, so import Nx.Defn in a Code cell, like this:

import Nx.Defn

Just as the Elixir language supports def, defmacro, and defp, Nx supports defn. There are a few restrictions: a numerical definition accepts and returns only numerical values, in the form of primitives or tensors, and supports only a subset of the language.

The subset of Elixir allowed within defn is quite broad, though. We can use macros, pipes, and even conditionals, so we're not giving up much when declaring mathematical functions.

Despite these small concessions, defn provides huge benefits. Code in a defn block uses tensor-aware operators and types, so the math beneath your functions has a better chance to shine through. Numerical definitions can also run on accelerated numerical processors like GPUs and TPUs. Here's an example numerical definition:

defmodule TensorMath do
  import Nx.Defn

  defn subtract(a, b) do
    a - b
  end
end

This module has a numerical definition that will be compiled. If we wanted to specify a compiler for this module, we could add a module attribute before the defn clause. One such compiler is EXLA. You'd add the mix dependency for EXLA and do this:

@defn_compiler EXLA
defn subtract(a, b) do
  a - b
end
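
Nx ~> 0.5 also lets you compile a numerical definition explicitly with Nx.Defn.jit/2, which takes a function and an options list and returns a compiled function. Without a :compiler option, it uses the default evaluator, so the sketch below runs without EXLA. With EXLA installed, you could pass compiler: EXLA in the options instead. The module name TensorMathSketch is hypothetical and avoids clashing with TensorMath:

```elixir
# Only needed when running this sketch as a standalone script.
Mix.install([{:nx, "~> 0.5"}])

defmodule TensorMathSketch do
  import Nx.Defn

  # Inside defn, `-` is the tensor-aware operator, so broadcasting applies.
  defn subtract(a, b) do
    a - b
  end
end

# Calling a defn directly works like calling any other function.
direct = TensorMathSketch.subtract(Nx.tensor([[5, 6], [7, 8]]), Nx.tensor([[1, 2], [3, 4]]))

# jit/2 returns a compiled function; no :compiler option means the default is used.
jitted_subtract = Nx.Defn.jit(&TensorMathSketch.subtract/2)
via_jit = jitted_subtract.(Nx.tensor([10, 20, 30]), 1)

{Nx.to_list(direct), Nx.to_list(via_jit)}
```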

Now, it's your turn. Add a defn to TensorMath that accepts two tensors representing the lengths of sides of a right triangle and uses the Pythagorean theorem to return the length of the hypotenuse. Add your function directly to the previous Code cell.

The last major feature we'll cover is called auto-differentiation, or autograd.

Automatic differentiation (autograd)

An important mathematical property of a function is its rate of change, or gradient. These gradients are critical for solving systems of equations and building probabilistic models. In calculus, gradients are computed from a function's derivatives. Nx can compute these derivatives automatically through a feature called automatic differentiation, or autograd.

Here's how it works.

h Nx.Defn.grad

We'll build a module with a few functions, and then create another function that computes the gradients of those functions. The function grad/1 takes a function and returns a function returning the gradient. We have two functions: poly/1 is a simple numerical definition, and poly_slope_at/1 returns its gradient:

$$ poly: f(x) = 3x^2 + 2x + 1 $$

$$ polySlopeAt: g(x) = 6x + 2 $$
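
Applying the power rule term by term shows where g comes from:

$$ g(x) = f'(x) = \frac{d}{dx}\left(3x^2 + 2x + 1\right) = 3 \cdot 2x + 2 + 0 = 6x + 2 $$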

Here's the Elixir equivalent of those functions:

defmodule Funs do
  import Nx.Defn

  defn poly(x) do
    3 * Nx.pow(x, 2) + 2 * x + 1
  end

  defn poly_slope_at(x) do
    grad(&poly/1).(x)
  end
end

Notice the second defn. It uses grad/1 to take the derivative of poly/1 with autograd. It works on the intermediate defn AST, composing the derivatives of each individual operation to compute the derivative of the whole function. You can see it at work here:

Funs.poly_slope_at(2)

Nice. If you plug the number 2 into the function $6x + 2$, you get 14! Said another way, at exactly x = 2, poly(x) increases at a rate of 14 units for every unit of x.
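
You can also check that value without autograd by using a central finite difference, which approximates the slope numerically. This plain-Elixir sketch uses an arbitrary step size of h = 1.0e-3. Because poly is quadratic, the central difference is exact apart from floating-point rounding:

```elixir
# The same polynomial as Funs.poly/1, written as a plain Elixir function.
poly = fn x -> 3 * x * x + 2 * x + 1 end

# Central difference: (f(x + h) - f(x - h)) / 2h approximates f'(x).
h = 1.0e-3
x = 2.0
slope = (poly.(x + h) - poly.(x - h)) / (2 * h)

IO.inspect(slope)
```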

Nx also has helpers to get gradients corresponding to a number of inputs. These come into play when solving systems of equations.
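
Here is a sketch of the multi-input case. It assumes grad/2 accepts a tuple of inputs together with a function that pattern-matches on that tuple, and returns one gradient per input. The module and function names are hypothetical. For f(x, y) = x²y, the partial derivatives are 2xy and x², so at (3, 2) the expected gradients are 12 and 9:

```elixir
# Only needed when running this sketch as a standalone script.
Mix.install([{:nx, "~> 0.5"}])

defmodule MultiGradSketch do
  import Nx.Defn

  # f(x, y) = x^2 * y; grad needs a scalar output, which this is for scalar inputs.
  defn f(x, y) do
    x * x * y
  end

  # Differentiate f with respect to both inputs at once.
  defn grads(x, y) do
    grad({x, y}, fn {x, y} -> f(x, y) end)
  end
end

{dx, dy} = MultiGradSketch.grads(3.0, 2.0)
{Nx.to_number(dx), Nx.to_number(dy)}
```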

Now, you try. Find a function computing the gradient of a sine wave.

# your code here