Broadcasting
The shapes of the tensors passed to an operator don't always match. For example, you might want to subtract `1` from every element of a `{2, 2}`-shaped tensor, like this:
$$ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} - 1 = \begin{bmatrix} 0 & 1 \\ 2 & 3 \end{bmatrix} $$

Mathematically, this is the same as:

$$ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} - \begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix} = \begin{bmatrix} 0 & 1 \\ 2 & 3 \end{bmatrix} $$
This means we need a way to convert `1` to a `{2, 2}`-shaped tensor. `Nx.broadcast/2` solves that problem: it takes a tensor or a scalar, along with a target shape.
```elixir
Mix.install([
  {:nx, "~> 0.9"}
])

Nx.broadcast(1, {2, 2})
```
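To make the result of `Nx.broadcast(1, {2, 2})` concrete, here is a rough plain-Elixir sketch that builds the same values as nested lists. `ScalarFillSketch.fill/2` is a hypothetical helper written for illustration. It is not part of Nx, and it only mimics the values the broadcast produces, not how Nx stores or computes tensors.

```elixir
defmodule ScalarFillSketch do
  # Hypothetical helper (not an Nx function): builds nested lists of the
  # given shape with every element set to `value`, mimicking the values
  # that Nx.broadcast/2 returns when the first argument is a scalar.
  def fill(value, shape) when is_tuple(shape) do
    shape
    |> Tuple.to_list()
    |> do_fill(value)
  end

  # No dimensions left: return the element itself.
  defp do_fill([], value), do: value

  # Repeat the inner structure `size` times along the outermost dimension.
  defp do_fill([size | rest], value) do
    List.duplicate(do_fill(rest, value), size)
  end
end

ScalarFillSketch.fill(1, {2, 2})
# => [[1, 1], [1, 1]]
```

Each dimension of the shape simply repeats the structure below it, which is the "copying" described next.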
This call takes the scalar `1` and copies it into every position of the requested shape. Sometimes it's easier to provide a tensor as the second argument and let `Nx.broadcast/2` extract its shape:

```elixir
tensor = Nx.tensor([[1, 2], [3, 4]])
Nx.broadcast(1, tensor)
```
The code broadcasts `1` to the shape of `tensor`. In many operators and functions, the broadcast happens automatically:

```elixir
Nx.subtract(tensor, 1)
```
This result is possible because `Nx.subtract/2` broadcasts both of its arguments to compatible shapes. That means you can provide a scalar as either argument:

```elixir
Nx.subtract(10, tensor)
```
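Subtraction is not commutative, so broadcasting keeps the operand order: `Nx.subtract(10, tensor)` computes `10 - x` for every element `x`. The plain-Elixir sketch below makes this explicit for a scalar combined with a 2-D list of rows. The `ScalarSubtractSketch` module is hypothetical and covers only this one case, not the general behavior of `Nx.subtract/2`.

```elixir
defmodule ScalarSubtractSketch do
  # Hypothetical helper mimicking Nx.subtract/2 for a number and a 2-D
  # list of rows. The scalar is paired with every element, and the order
  # of the operands is preserved.
  def subtract(a, matrix) when is_number(a) and is_list(matrix) do
    for row <- matrix, do: for(x <- row, do: a - x)
  end

  def subtract(matrix, b) when is_list(matrix) and is_number(b) do
    for row <- matrix, do: for(x <- row, do: x - b)
  end
end

matrix = [[1, 2], [3, 4]]

ScalarSubtractSketch.subtract(10, matrix)
# => [[9, 8], [7, 6]]

ScalarSubtractSketch.subtract(matrix, 1)
# => [[0, 1], [2, 3]]
```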
You can also subtract a row or a column. Mathematically, subtracting a row looks like this:

$$ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} - \begin{bmatrix} 1 & 2 \end{bmatrix} = \begin{bmatrix} 0 & 0 \\ 2 & 2 \end{bmatrix} $$

which is the same as this:

$$ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} - \begin{bmatrix} 1 & 2 \\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 0 & 0 \\ 2 & 2 \end{bmatrix} $$
This rewrite happens in Nx as well, through a broadcast operation. We want to broadcast the tensor `[1, 2]` to match the `{2, 2}` shape:

```elixir
Nx.broadcast(Nx.tensor([1, 2]), {2, 2})
```
The `Nx.subtract/2` function takes care of that broadcast implicitly, as discussed above:

```elixir
Nx.subtract(tensor, Nx.tensor([1, 2]))
```
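The same row broadcast can be sketched over plain Elixir lists, where each row of the matrix is paired element by element with `[1, 2]`. `RowSubtractSketch.subtract_row/2` is a hypothetical helper. It assumes every row has the same length as the vector and, unlike Nx, performs no shape validation.

```elixir
defmodule RowSubtractSketch do
  # Hypothetical helper: subtracts `vector` from every row of `matrix`,
  # the effect of broadcasting a {n}-shaped row to an {m, n} shape.
  # Assumes each row has the same length as `vector`; Enum.zip_with/3
  # silently stops at the shorter list, so mismatches are not reported.
  def subtract_row(matrix, vector) do
    Enum.map(matrix, fn row -> Enum.zip_with(row, vector, &-/2) end)
  end
end

RowSubtractSketch.subtract_row([[1, 2], [3, 4]], [1, 2])
# => [[0, 0], [2, 2]]
```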
The broadcast worked as expected, copying the `[1, 2]` row enough times to fill a `{2, 2}`-shaped tensor. More generally, a dimension of size `1` is copied along that axis to fill the target shape, and missing leading dimensions are added:

```elixir
[[1], [2]] |> Nx.tensor() |> Nx.broadcast({1, 2, 2})
```

```elixir
[[[1, 2, 3]]]
|> Nx.tensor()
|> Nx.broadcast({4, 2, 3})
```
Both of these examples copy parts of the tensor enough times to fill out the broadcast shape. For more details, read the documentation for `Nx.broadcast/2` with the `h` helper:

```elixir
h Nx.broadcast
```
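The shapes in these examples follow a simple compatibility rule, sketched below as a small shape check. This is a simplified model that assumes the default behavior seen above: the dimensions are aligned from the right, and each source dimension must either equal the matching target dimension or be `1`. `BroadcastShapeSketch.check/2` and its return values are hypothetical. The real `Nx.broadcast/2` raises its own errors, and its documentation gives the full rules.

```elixir
defmodule BroadcastShapeSketch do
  # Hypothetical check mirroring the right-aligned rule used in the examples:
  # the source may have fewer dimensions than the target, and each source
  # dimension must equal the matching target dimension or be 1.
  def check(source, target) when is_tuple(source) and is_tuple(target) do
    src = source |> Tuple.to_list() |> Enum.reverse()
    tgt = target |> Tuple.to_list() |> Enum.reverse()

    cond do
      length(src) > length(tgt) ->
        {:error, :too_many_dimensions}

      # Enum.zip/2 stops at the shorter list, so extra leading target
      # dimensions are always accepted (they are the "added" dimensions).
      Enum.all?(Enum.zip(src, tgt), fn {s, t} -> s == t or s == 1 end) ->
        {:ok, target}

      true ->
        {:error, :incompatible}
    end
  end
end

BroadcastShapeSketch.check({2, 1}, {1, 2, 2})
# => {:ok, {1, 2, 2}}

BroadcastShapeSketch.check({1, 1, 3}, {4, 2, 3})
# => {:ok, {4, 2, 3}}

BroadcastShapeSketch.check({2}, {2, 3})
# => {:error, :incompatible}
```

The last call fails because, aligned from the right, the source dimension `2` meets the target dimension `3`.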
Much of the time, you won't have to broadcast yourself: many Nx functions and operators, such as `Nx.subtract/2`, broadcast their arguments implicitly.
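In element-wise functions, either operand can be stretched, so the result shape depends on both input shapes. The sketch below assumes a NumPy-style rule: align the shapes from the right, treat missing leading dimensions as `1`, and require each pair of sizes to be equal or to contain a `1`. It is an illustrative model with a hypothetical module name, not Nx's implementation. Check the Nx documentation for the exact behavior.

```elixir
defmodule BinaryShapeSketch do
  # Hypothetical model of the result shape of an element-wise binary
  # operation under an assumed NumPy-style broadcasting rule.
  def result_shape(left, right) when is_tuple(left) and is_tuple(right) do
    l = Tuple.to_list(left)
    r = Tuple.to_list(right)
    rank = max(length(l), length(r))

    # Pad the shorter shape with leading 1s so both have the same rank.
    l = List.duplicate(1, rank - length(l)) ++ l
    r = List.duplicate(1, rank - length(r)) ++ r

    Enum.zip(l, r)
    |> Enum.reduce_while({:ok, []}, fn
      # Equal sizes: keep the size.
      {a, a}, {:ok, acc} -> {:cont, {:ok, [a | acc]}}
      # A size of 1 on either side stretches to the other size.
      {1, b}, {:ok, acc} -> {:cont, {:ok, [b | acc]}}
      {a, 1}, {:ok, acc} -> {:cont, {:ok, [a | acc]}}
      # Anything else cannot be broadcast.
      _pair, _acc -> {:halt, {:error, :incompatible}}
    end)
    |> case do
      {:ok, dims} -> {:ok, dims |> Enum.reverse() |> List.to_tuple()}
      error -> error
    end
  end
end

BinaryShapeSketch.result_shape({2, 2}, {2})
# => {:ok, {2, 2}}

BinaryShapeSketch.result_shape({2, 1}, {1, 3})
# => {:ok, {2, 3}}
```

The second call shows both operands stretching at once: the `{2, 1}` column and the `{1, 3}` row each grow to `{2, 3}`.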
Throughout this section, we have been calling `Nx.subtract/2`, but our code would be more expressive if we could use the equivalent mathematical operator, `-`. Fortunately, Nx provides a way. Next, we'll dive into numerical definitions using `defn`.