Broadcasting
The shapes of the tensors in an operation don't always match.
For example, you might want to subtract 1 from every
element of a {2, 2}-shaped tensor, like this:
$$ \begin{bmatrix} 1 & 2 \\\\ 3 & 4 \end{bmatrix} - 1 = \begin{bmatrix} 0 & 1 \\\\ 2 & 3 \end{bmatrix} $$
Mathematically, this is the same as:
$$ \begin{bmatrix} 1 & 2 \\\\ 3 & 4 \end{bmatrix} - \begin{bmatrix} 1 & 1 \\\\ 1 & 1 \end{bmatrix} = \begin{bmatrix} 0 & 1 \\\\ 2 & 3 \end{bmatrix} $$
This means we need a way to convert 1 to a {2, 2}-shaped tensor.
Nx.broadcast/2 solves that problem. This function takes
a tensor or a scalar and a shape.
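To make "convert 1 to a {2, 2}-shaped tensor" concrete, here is a hypothetical plain-Elixir sketch. It uses nested lists instead of tensors and copies the scalar into every position of the target shape. The module name ScalarBroadcastSketch is made up for this illustration. The code is not Nx's implementation and does not call Nx at all.

```elixir
# Hypothetical sketch (not part of Nx): "broadcast" a scalar into nested
# lists of the given shape by copying it into every position.
defmodule ScalarBroadcastSketch do
  # The shape is a tuple such as {2, 2}; {} means "just the scalar".
  def fill(value, shape) when is_tuple(shape) do
    do_fill(value, Tuple.to_list(shape))
  end

  # No dimensions left: the element itself.
  defp do_fill(value, []), do: value

  # For a dimension of size n, repeat the filled inner block n times.
  defp do_fill(value, [n | rest]) do
    List.duplicate(do_fill(value, rest), n)
  end
end

ones = ScalarBroadcastSketch.fill(1, {2, 2})
# => [[1, 1], [1, 1]]

# Once both operands have the same shape, subtraction is element by element.
matrix = [[1, 2], [3, 4]]
Enum.zip_with(matrix, ones, fn row, one_row -> Enum.zip_with(row, one_row, &-/2) end)
# => [[0, 1], [2, 3]]
```

The result matches the second equation above: the scalar 1 is first expanded to a matrix of ones, then subtracted element-wise.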
```elixir
Mix.install([
  {:nx, "~> 0.9"}
])
```

```elixir
Nx.broadcast(1, {2, 2})
```

This call takes the scalar 1 and translates it
to a compatible shape by copying it. Sometimes it's easier
to provide a tensor as the second argument and let broadcast/2
extract its shape:
```elixir
tensor = Nx.tensor([[1, 2], [3, 4]])
Nx.broadcast(1, tensor)
```

The code broadcasts 1 to the shape of tensor. In many operators
and functions, the broadcast happens automatically:
```elixir
Nx.subtract(tensor, 1)
```

This result is possible because Nx broadcasts both tensors
in subtract/2 to compatible shapes. That means you can provide
scalar values as either argument:

```elixir
Nx.subtract(10, tensor)
```

You can also subtract a row or a column. Subtracting a row looks like this mathematically:
$$ \begin{bmatrix} 1 & 2 \\\\ 3 & 4 \end{bmatrix} - \begin{bmatrix} 1 & 2 \end{bmatrix} = \begin{bmatrix} 0 & 0 \\\\ 2 & 2 \end{bmatrix} $$
which is the same as this:
$$ \begin{bmatrix} 1 & 2 \\\\ 3 & 4 \end{bmatrix} - \begin{bmatrix} 1 & 2 \\\\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 0 & 0 \\\\ 2 & 2 \end{bmatrix} $$
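The rewrite in these two equations can be sketched with plain Elixir lists. The approach is to copy the row once per matrix row and then subtract element by element. RowBroadcastSketch is a hypothetical helper written only for this illustration. It handles just the row case shown above and is not part of Nx.

```elixir
# Hypothetical sketch (not part of Nx): subtract a row from every row of a
# matrix by first "broadcasting" the row, i.e. repeating it once per row.
defmodule RowBroadcastSketch do
  def subtract_row(matrix, row) do
    # Shapes must line up: every matrix row needs as many entries as `row`.
    Enum.each(matrix, fn m_row ->
      if length(m_row) != length(row),
        do: raise(ArgumentError, "row length does not match the matrix columns")
    end)

    # Broadcast step: repeat `row` so it has as many rows as `matrix`.
    repeated = List.duplicate(row, length(matrix))

    # Element-wise step: both operands now have the same shape.
    Enum.zip_with(matrix, repeated, fn m_row, r_row ->
      Enum.zip_with(m_row, r_row, &-/2)
    end)
  end
end

RowBroadcastSketch.subtract_row([[1, 2], [3, 4]], [1, 2])
# => [[0, 0], [2, 2]]
```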
Nx performs the same rewrite through a broadcast operation. We want to
broadcast the tensor [1, 2] to match the {2, 2} shape:

```elixir
Nx.broadcast(Nx.tensor([1, 2]), {2, 2})
```

As discussed above, Nx.subtract/2 takes care of that broadcast
implicitly:

```elixir
Nx.subtract(tensor, Nx.tensor([1, 2]))
```

The broadcast worked as expected, copying the [1, 2] row
enough times to fill a {2, 2}-shaped tensor. Broadcasting also works
across more dimensions: a dimension of size 1 is stretched to the target size,
and missing leading dimensions are added:
```elixir
[[1], [2]] |> Nx.tensor() |> Nx.broadcast({1, 2, 2})
```

```elixir
[[[1, 2, 3]]]
|> Nx.tensor()
|> Nx.broadcast({4, 2, 3})
```

Both of these examples copy parts of the tensor enough times to fill out the broadcast shape. You can check out the Nx broadcasting documentation for more details:

```elixir
h Nx.broadcast
```

Much of the time, you won't have to broadcast yourself. Many of the functions and operators Nx supports will do so automatically.
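The rule behind the examples in this section can be written as a small shape calculation. The sketch below assumes the NumPy-style rule, which as we understand it also applies to Nx's binary operators. Under that rule, you compare the two shapes from the rightmost dimension and treat missing leading dimensions as 1. A pair of dimensions is compatible only if the two are equal or one of them is 1. ShapeBroadcastSketch is a hypothetical module, not an Nx API. It models the two-sided case used by operators such as Nx.subtract/2, and h Nx.broadcast remains the authoritative description.

```elixir
# Hypothetical sketch (not an Nx API) of the NumPy-style broadcasting rule.
defmodule ShapeBroadcastSketch do
  # Returns {:ok, shape} with the broadcast result, or :error when the
  # shapes are incompatible.
  def broadcast_shapes(a, b) when is_tuple(a) and is_tuple(b) do
    # Reverse both shapes so we walk from the rightmost dimension.
    do_broadcast(Enum.reverse(Tuple.to_list(a)), Enum.reverse(Tuple.to_list(b)), [])
  end

  # Both shapes consumed: `acc` already holds the result left to right.
  defp do_broadcast([], [], acc), do: {:ok, List.to_tuple(acc)}
  # One shape has fewer dimensions: its missing leading dimensions act as 1.
  defp do_broadcast([], [d | rest], acc), do: do_broadcast([], rest, [d | acc])
  defp do_broadcast([d | rest], [], acc), do: do_broadcast(rest, [], [d | acc])
  # Equal dimensions are compatible as they are.
  defp do_broadcast([d | ra], [d | rb], acc), do: do_broadcast(ra, rb, [d | acc])
  # A dimension of size 1 is stretched to match the other side.
  defp do_broadcast([1 | ra], [d | rb], acc), do: do_broadcast(ra, rb, [d | acc])
  defp do_broadcast([d | ra], [1 | rb], acc), do: do_broadcast(ra, rb, [d | acc])
  # Any other pair of dimensions cannot be broadcast.
  defp do_broadcast(_, _, _), do: :error
end

ShapeBroadcastSketch.broadcast_shapes({2, 2}, {})         # => {:ok, {2, 2}}  (scalar)
ShapeBroadcastSketch.broadcast_shapes({2, 2}, {2})        # => {:ok, {2, 2}}  (row)
ShapeBroadcastSketch.broadcast_shapes({2, 1}, {1, 2, 2})  # => {:ok, {1, 2, 2}}
ShapeBroadcastSketch.broadcast_shapes({2, 3}, {2})        # => :error
```

The sketch returns :error instead of raising to keep it easy to inspect. Nx itself raises an error when two shapes cannot be broadcast.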
These tensor-aware operations are exposed as Nx functions, and
many of them broadcast their arguments implicitly.
Throughout this section, we have been invoking Nx.subtract/2, but
our code would be more expressive if we could use the equivalent
mathematical operator, -. Fortunately, Nx provides a way. Next, we'll
dive into numerical definitions using defn.