Broadcasting

The shapes of the tensors passed to an operator don't always match. For example, you might want to subtract 1 from every element of a {2, 2}-shaped tensor, like this:

$$ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} - 1 = \begin{bmatrix} 0 & 1 \\ 2 & 3 \end{bmatrix} $$

Mathematically, this is the same as:

$$ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} - \begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix} = \begin{bmatrix} 0 & 1 \\ 2 & 3 \end{bmatrix} $$

This means we need a way to convert 1 to a {2, 2}-shaped tensor. Nx.broadcast/2 solves that problem. This function takes a tensor or a scalar and a shape.

Mix.install([
  {:nx, "~> 0.9"}
])


Nx.broadcast(1, {2, 2})

This call takes the scalar 1 and translates it to a compatible shape by copying it. Sometimes, it's easier to provide a tensor as the second argument, and let broadcast/2 extract its shape:

tensor = Nx.tensor([[1, 2], [3, 4]])
Nx.broadcast(1, tensor)

The code broadcasts 1 to the shape of tensor. In many operators and functions, the broadcast happens automatically:

Nx.subtract(tensor, 1)

This result is possible because Nx broadcasts both tensors in subtract/2 to compatible shapes. That means you can provide scalar values as either argument:

Nx.subtract(10, tensor)
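To make the implicit step concrete, here is a minimal sketch comparing an explicit broadcast followed by a subtraction with the implicit version. The variable names `explicit` and `implicit` are illustrative only, and the snippet assumes Nx has already been installed with the Mix.install call shown earlier.

```elixir
tensor = Nx.tensor([[1, 2], [3, 4]])

# Expand the scalar to the tensor's shape by hand, then subtract.
explicit = Nx.subtract(Nx.broadcast(10, tensor), tensor)

# Let subtract/2 broadcast the scalar on its own.
implicit = Nx.subtract(10, tensor)

# Both routes produce the same values: [[9, 8], [7, 6]]
Nx.to_list(explicit) == Nx.to_list(implicit)
```

The comparison goes through Nx.to_list/1 so that it checks plain Elixir lists rather than tensor structs.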

You can also subtract a row or a column. Mathematically, subtracting a row looks like this:

$$ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} - \begin{bmatrix} 1 & 2 \end{bmatrix} = \begin{bmatrix} 0 & 0 \\ 2 & 2 \end{bmatrix} $$

which is the same as this:

$$ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} - \begin{bmatrix} 1 & 2 \\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 0 & 0 \\ 2 & 2 \end{bmatrix} $$

This rewrite happens in Nx as well, through a broadcast operation. We want to broadcast the tensor [1, 2] to match the {2, 2} shape:

Nx.broadcast(Nx.tensor([1, 2]), {2, 2})

The subtract function in Nx takes care of that broadcast implicitly, as discussed above:

Nx.subtract(tensor, Nx.tensor([1, 2]))

The broadcast worked as expected, copying the [1, 2] row enough times to fill a {2, 2}-shaped tensor. More generally, any dimension of size 1 is copied along its axis until it matches the target shape:

[[1], [2]] |> Nx.tensor() |> Nx.broadcast({1, 2, 2})

[[[1, 2, 3]]]
|> Nx.tensor()
|> Nx.broadcast({4, 2, 3})

Both of these examples copy parts of the tensor enough times to fill out the broadcast shape. You can check out the Nx broadcasting documentation for more details:

h Nx.broadcast
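The column case mentioned earlier works the same way as the row case. As a minimal sketch (the names `column`, `expanded`, and `difference` are introduced here for illustration, and Nx is assumed to be installed as above), a {2, 1}-shaped tensor is copied sideways to fill a {2, 2} shape:

```elixir
tensor = Nx.tensor([[1, 2], [3, 4]])

# A column is a {2, 1}-shaped tensor: two rows with one element each.
column = Nx.tensor([[1], [2]])

# Explicit broadcast copies the column across: [[1, 1], [2, 2]]
expanded = Nx.broadcast(column, {2, 2})

# subtract/2 performs the same broadcast implicitly: [[0, 1], [1, 2]]
difference = Nx.subtract(tensor, column)
```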

Much of the time, you won't have to broadcast yourself: many Nx functions and operators are tensor-aware and broadcast their arguments implicitly.
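Implicit broadcasting can stretch both operands at once, and it fails when the shapes cannot be reconciled. The sketch below is illustrative rather than exhaustive: it uses Nx.add/2 in place of subtract/2, the variable names are made up for the example, and it assumes that incompatible shapes surface as an ArgumentError.

```elixir
column = Nx.tensor([[1], [2]])        # shape {2, 1}
row = Nx.tensor([[10, 20, 30]])       # shape {1, 3}

# Each size-1 axis is stretched, so both operands behave as {2, 3} tensors.
sum = Nx.add(column, row)
Nx.shape(sum)

# A last axis of size 3 cannot be matched with one of size 2, so Nx raises.
result =
  try do
    Nx.add(Nx.tensor([1, 2, 3]), Nx.tensor([[1, 2], [3, 4]]))
  rescue
    error in ArgumentError -> {:error, Exception.message(error)}
  end
```

Here `sum` holds [[11, 21, 31], [12, 22, 32]], while `result` is an `{:error, message}` tuple describing the shape mismatch.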

Throughout this section, we have been invoking Nx.subtract/2, and our code would be more expressive if we could use the equivalent mathematical operator instead. Fortunately, Nx provides a way to do that. Next, we'll dive into numerical definitions using defn.