Introduction to Vectorization

Mix.install([{:nx, github: "elixir-nx/nx", sparse: "nx"}])

What is vectorization?

Vectorization in Nx is the concept of imposing "for-each" semantics onto leading tensor axes. This lets us write tensor operations in a simpler way, avoiding while loops, and it often also leads to more efficient code, as we'll see in this guide.

There are a few functions related to vectorization. First we'll look into Nx.vectorize/2, Nx.devectorize/2, and Nx.revectorize/3.

For these examples, we'll use the normalize function defined below, which only works with 1D tensors. The function subtracts the minimum value from the tensor and then divides the result by its maximum.

defmodule Example do
  import Nx.Defn

  defn normalize(t) do
    case Nx.shape(t) do
      {_} -> :ok
      _ -> raise "invalid shape"
    end

    min = Nx.reduce_min(t)
    zero_min = Nx.subtract(t, min)
    Nx.divide(zero_min, Nx.reduce_max(zero_min))
  end
end

We can invoke it as:

Example.normalize(Nx.tensor([1, 2, 3]))
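For this input, the result is [0.0, 0.5, 1.0]: we subtract the minimum of 1 and then divide by the new maximum of 2.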

However, if we attempt to call it with a tensor that has more than one dimension, it won't work. Let's first define a 2D tensor:

t =
  Nx.tensor([
    [1, 2, 3],
    [10, 20, 30],
    [4, 5, 6]
  ])

Now if we attempt to invoke it, the shape check fails and the function raises:

Example.normalize(t)

To address this, we can vectorize the tensor:

vec = Nx.vectorize(t, :rows)
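We can confirm what happened by inspecting the vectorized axes and the remaining (inner) shape:

# the :rows axis is now vectorized with size 3; the inner shape is just {3}
{vec.vectorized_axes, Nx.shape(vec)}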

As we can see above, the newly vectorized vec holds the same data as t, but its first axis is now a vectorized axis called :rows, with size 3. This means that, for all intents and purposes, we can treat it as a 1D tensor, on which our normalize function will now work as if we passed those 3 rows separately (thus the "for-each" semantics mentioned above). Let's give it a try:

normalized_vec = Example.normalize(vec)
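To see the "for-each" behavior more concretely, here is a quick comparison sketch: normalizing each row by hand should produce the same values as the entries of normalized_vec.

# normalize each row separately and stack the results; the values
# should match the vectorized result above
Nx.stack([
  Example.normalize(t[0]),
  Example.normalize(t[1]),
  Example.normalize(t[2])
])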

While the tensor is vectorized, we can't treat it as a matrix (2D tensor):

right =
  Nx.tensor([
    [1, 2, 3],
    [2, 3, 4],
    [3, 4, 5]
  ])

# The results might be unexpected, but they behave the same as
# Nx.add(Example.normalize(Nx.tensor([1, 2, 3])), right) and so on,
# resulting in a vectorized tensor with 3 inner matrices.
Nx.add(normalized_vec, right)

You can devectorize the tensor to get its original shape:

# If we want to keep the vectorized axes' names
Nx.devectorize(normalized_vec)
# If we want to drop the vectorized axes' names
devec = Nx.devectorize(normalized_vec, keep_names: false)

Once devectorized, we can effectively add the two matrices together:

Nx.add(devec, right)

Revectorization

Now that we have the basics down, let's discuss Nx.revectorize/3 with multi-dimensional tensors. This is especially useful in cases where we have multiple vectorized axes that we want to collapse and then re-expand.

# This version of vectorize also asserts on the size of the dimensions being vectorized
t = Nx.iota({2, 1, 3, 2, 2}) |> Nx.vectorize(x: 2, y: 1, z: 3)

Let's imagine we want to operate on the vectorized iota t above, adding a specific constant for each vectorized entry. We can do this through a bit of tensor introspection in conjunction with Nx.revectorize/3:

collapsed_t = Nx.revectorize(t, [vectors: :auto], target_shape: t.shape)

[vectors: vectorized_size] = collapsed_t.vectorized_axes
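Here vectorized_size is 6, the product of the collapsed vectorized axes (2 * 1 * 3).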

constants =
  Nx.iota({vectorized_size})
  |> Nx.multiply(100)
  |> Nx.vectorize(vectors: vectorized_size)

{constants, collapsed_t}

From the output above, we can see that we have a vectorized tensor of scalars, as well as a vectorized tensor with the same vectorized axis containing matrices.

Now we can add them together as we discussed above, and then re-expand the vectorized axes to the original shape.

collapsed_t
|> Nx.add(constants)
|> Nx.revectorize(t.vectorized_axes)
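The result once again has the vectorized axes x: 2, y: 1, and z: 3, along with the original inner shape of {2, 2}.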

Before we move on to the last two vectorization functions, let's see how we can replace a while loop with vectorization. The following module implements the same computation twice, once with each approach.

defmodule WhileExample do
  import Nx.Defn

  defn while_sum_and_multiply(x, y) do
    out =
      case Nx.shape(x) do
        {n, _} ->
          Nx.broadcast(0, {n})

        _ ->
          raise "expected x to have rank 2"
      end

    n = Nx.axis_size(out, 0)

    case Nx.shape(y) do
      {^n} -> nil
      _ -> raise "expected y to have rank 1 and the same number of elements as x"
    end

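    # for each row i, store Nx.sum(x[i]) + y[i] at position i of out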
    {out, _} =
      while {out, {x, y}}, i <- 0..(n - 1) do
        update = Nx.sum(x[i]) + y[i]

        updated = Nx.indexed_put(out, Nx.reshape(i, {1}), update)
        {updated, {x, y}}
      end

    out
  end

  defn vectorized_sum_and_multiply(x, y) do
    # for the sake of equivalence, we'll keep the limitation
    # of accepting only rank-2 tensors for x
    {n, m} =
      case Nx.shape(x) do
        {n, m} ->
          {n, m}

        _ ->
          raise "expected x to have rank 2"
      end

    case Nx.shape(y) do
      {^n} -> nil
      _ -> raise "expected y to have rank 1 and the same number of elements as x"
    end

    # this enables us to accept vectorized tensors for x and y
    vectorized_axes = x.vectorized_axes

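    # collapse any pre-existing vectorized axes together with the rows of x,
    # so that each vectorized entry holds a single row of x (and a single scalar of y)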
    x = Nx.revectorize(x, [collapsed: :auto, vectors: n], target_shape: {m})
    y = Nx.revectorize(y, [collapsed: :auto, vectors: n], target_shape: {})

    out = Nx.sum(x) + y

    Nx.revectorize(out, vectorized_axes, target_shape: {n})
  end
end

Let's give those definitions a try:

x =
  Nx.tensor([
    [1, 2, 3],
    [1, 0, 0],
    [0, 1, 3]
  ])

y = Nx.tensor([4, 5, 6])

x_vec = Nx.tile(x, [2, 1]) |> Nx.reshape({2, 3, 3}) |> Nx.vectorize(:rows)
# use a different set of vectorized axes for y
y_vec = Nx.concatenate([y, Nx.multiply(y, 2)]) |> Nx.reshape({2, 3}) |> Nx.vectorize(:cols)
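Note that x_vec has a single vectorized axis, rows: 2, with an inner shape of {3, 3}, while y_vec has a vectorized axis cols: 2 with an inner shape of {3}.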

{
  WhileExample.while_sum_and_multiply(x, y),
  WhileExample.vectorized_sum_and_multiply(x, y),
  WhileExample.vectorized_sum_and_multiply(x_vec, y),
  WhileExample.vectorized_sum_and_multiply(x_vec, y_vec)
}

The advantage of vectorization is that the underlying compilers and hardware may do a much better job of optimizing tensor operations than while loops, which are executed sequentially.

The eagle-eyed reader will have noticed that we inadvertently threw away the vectorized axes that we received from y. This is because our usage of revectorize disregards the possibility that the axes could be different in each input tensor. Luckily, we can tackle this problem too.

Reshape and Broadcast Vectors

The bug introduced in the previous section can be solved with Nx.reshape_vectors/1 and Nx.broadcast_vectors/1. Both functions receive a list of tensors and operate on their vectorized axes to ensure the shapes are compatible with one another.

Nx functions will natively do this for us, as we see below:

base = Nx.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
vec_i = Nx.vectorize(Nx.reshape(base, {10, 1}), i: 10, j: 1)
vec_j = Nx.vectorize(base, :j)

# easy way to build a multiplication table from 0 to 9
Nx.multiply(vec_i, vec_j)
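The result is a tensor with vectorized axes i: 10 and j: 10, where each entry is the product of the corresponding values: Nx automatically broadcasts the j: 1 axis of vec_i against the j: 10 axis of vec_j.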

Ok, now that we know that broadcasting also works on vectorized axes, adding entries to the resulting tensor, we can look into the aforementioned functions.

a = Nx.iota({2, 1}, vectorized_axes: [x: 2, y: 1])
b = Nx.iota({}, vectorized_axes: [z: 2, x: 1])

Nx.reshape_vectors([a, b], align_ranks: true)

In the example above, we can see that both tensors end up containing vectorized axes :x, :y, and :z. Furthermore, we can also see that the b tensor was rearranged so that the :x axis comes first. All tensors end up with the same ordering of axes, but the axes aren't resized in any way.

Finally, the align_ranks: true option is passed so that the inner shape (the non-vectorized part!) of both tensors ends up in a broadcastable shape with the same rank across all tensors. The example uses only 2 tensors, but the list can have arbitrary size.
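For illustration, we can also bind the results and inspect the vectorized axes of each tensor directly (the variable names below are just for this sketch):

# reshape_vectors aligns the axis names and their order, adding size-1 axes where needed
[a_reshaped, b_reshaped] = Nx.reshape_vectors([a, b], align_ranks: true)
{a_reshaped.vectorized_axes, b_reshaped.vectorized_axes}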

Nx.broadcast_vectors/2 works similarly, except it also broadcasts the dimensions instead of simply reshaping:

x = Nx.iota({2, 1}, vectorized_axes: [x: 2, y: 1])
y = Nx.iota({}, vectorized_axes: [z: 2, x: 1])

Nx.broadcast_vectors([x, y])

The key difference is that all vectorized axes will end up with the same size in all resulting tensors, effectively behaving the same as Nx.broadcast does for non-vectorized shapes. Specifically, we can see that the x tensor gets the new z: 2 axis and y gets both x: 2, y: 1, where the :x axis was already present, but has now been resized.
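Again for illustration, inspecting the vectorized axes of the broadcast results shows that both tensors now share the same axes and sizes (variable names are just for this sketch):

# after broadcasting, both tensors have identical vectorized axes
[x_broadcast, y_broadcast] = Nx.broadcast_vectors([x, y])
{x_broadcast.vectorized_axes, y_broadcast.vectorized_axes}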

With this knowledge we can rewrite our function without the bug, as follows:

defmodule BroadcastVectorsExample do
  import Nx.Defn

  defn vectorized_sum_and_multiply(x, y) do
    # for the sake of equivalence, we'll keep the limitation
    # of accepting only rank-2 tensors for x

    [x, y] = Nx.broadcast_vectors([x, y])

    {n, m} =
      case Nx.shape(x) do
        {n, m} ->
          {n, m}

        _ ->
          raise "expected x to have rank 2"
      end

    case Nx.shape(y) do
      {^n} -> nil
      _ -> raise "expected y to have rank 1 and the same number of elements as x"
    end

    # this enables us to accept vectorized tensors for x and y
    vectorized_axes = x.vectorized_axes

    x = Nx.revectorize(x, [collapsed: :auto, vectors: n], target_shape: {m})
    y = Nx.revectorize(y, [collapsed: :auto, vectors: n], target_shape: {})

    out = Nx.sum(x) + y

    Nx.revectorize(out, vectorized_axes, target_shape: {n})
  end
end

Let's give it another try:

x =
  Nx.tensor([
    [1, 2, 3],
    [1, 0, 0],
    [0, 1, 3]
  ])

y = Nx.tensor([4, 5, 6])

x_vec = Nx.tile(x, [2, 1]) |> Nx.reshape({2, 3, 3}) |> Nx.vectorize(:rows)
# use a different set of vectorized axes for y
y_vec = Nx.concatenate([y, Nx.multiply(y, 2)]) |> Nx.reshape({2, 3}) |> Nx.vectorize(:cols)

{
  BroadcastVectorsExample.vectorized_sum_and_multiply(x_vec, y),
  BroadcastVectorsExample.vectorized_sum_and_multiply(x_vec, y_vec)
}
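This time the second result keeps both the :rows and :cols vectorized axes, instead of silently discarding the axes that came from y_vec.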