Tensor Aggregation 101
Mix.install([
{:nx, "~> 0.5"}
])
Aggregation operations on tensors
import Nx, only: :sigils
You can apply aggregation functions on any tensor. The functions can be applied to the tensor as a whole or to a subsection of the tensor taken in an axis-wise fashion.
As a first example, let's take a 2D tensor of shape {2, 3}. Notice that we can name the axes. The elements of the tensor m are the m[i][j], where $i$ is the row and $j$ the column.
m = Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y])
You can get the maximum number in a tensor with Nx.reduce_max(m), which returns a 0D tensor. With a reduction, we lose one dimension per reduced axis, and since we applied the reduction globally, we lose all the dimensions. It should return a scalar tensor holding 6 here.
We can get the maximum number for each row with Nx.reduce_max(m, axes: [:y]), returning a 1D tensor of size 2. Why :y? Because for each row, we reduce along the :y axis.
We can get the maximum number for each column with Nx.reduce_max(m, axes: [:x]), returning a 1D tensor of size 3. For each column, we reduce along the :x axis.
max = Nx.reduce_max(m)
max_x = Nx.reduce_max(m, axes: [:x])
max_y = Nx.reduce_max(m, axes: [:y])
%{max: max, max_x: max_x, max_y: max_y}
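The rule "one dimension lost per reduced axis" can be checked by inspecting the shape and names of each result. This is a minimal sketch that assumes Nx is already installed by the setup cell at the top; the variable names (global, per_row, per_col, summary) are only illustrative.

```elixir
# Assumes the Mix.install cell at the top of the notebook has already run.
m = Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y])

# Reducing over every axis leaves a scalar (0D) tensor.
global = Nx.reduce_max(m)
# Reducing over :y removes :y and keeps :x: one value per row.
per_row = Nx.reduce_max(m, axes: [:y])
# Reducing over :x removes :x and keeps :y: one value per column.
per_col = Nx.reduce_max(m, axes: [:x])

summary = %{
  global: {Nx.shape(global), Nx.to_number(global)},
  per_row: {Nx.names(per_row), Nx.to_flat_list(per_row)},
  per_col: {Nx.names(per_col), Nx.to_flat_list(per_col)}
}
# => %{global: {{}, 6}, per_col: {[:y], [4, 5, 6]}, per_row: {[:x], [3, 6]}}
```

The name of the axis that survives tells you what each value stands for: the per-row maxima are indexed by :x, the per-column maxima by :y.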
Let's consider another example with Nx.weighted_mean. It supports full-tensor and per-axis operations. The example below shows how to compute the weighted mean of a 2D tensor of shape {2, 2} labeled m:
m = ~M(1 2
3 4)
First, we'll compute the full-tensor aggregation. The calculations are developed below. We calculate an "array product" (aka Hadamard product, an element-wise product) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights.
w = ~M[
10 20
30 40
]
nx_w_avg = Nx.weighted_mean(m, w)
man_w_avg = (1 * 10 + 2 * 20 + 3 * 30 + 4 * 40) / (10 + 20 + 30 + 40)
# 300/100
%{
nx_weighted_avg: Nx.to_number(nx_w_avg),
manual_weighted_avg: man_w_avg
}
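The same full-tensor computation can also be spelled out with tensor operations instead of plain arithmetic. The sketch below is one possible way to write it, assuming Nx from the setup cell; it builds the tensors with Nx.tensor rather than the sigil so that it does not depend on the import, and the names by_hand and result are illustrative.

```elixir
# Assumes the Mix.install cell at the top of the notebook has already run.
m = Nx.tensor([[1, 2], [3, 4]])
w = Nx.tensor([[10, 20], [30, 40]])

# Hadamard (element-wise) product, summed over all elements,
# then divided by the sum of the weights.
by_hand = Nx.divide(Nx.sum(Nx.multiply(m, w)), Nx.sum(w))

result = %{
  by_hand: Nx.to_number(by_hand),
  weighted_mean: Nx.to_number(Nx.weighted_mean(m, w))
}
# => %{by_hand: 3.0, weighted_mean: 3.0}
```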
The weighted mean can also be computed per axis. Let's compute it along the first axis (axes: [0]): you calculate "by column", so you aggregate/reduce along the first axis:
w = ~M[
10 20
30 40
]
w_avg_x = Nx.weighted_mean(m, w, axes: [0])
man_w_avg_x = [(1 * 10 + 3 * 30) / (10 + 30), (2 * 20 + 4 * 40) / (20 + 40)]
# [100/40, 200/60]
{
w_avg_x,
man_w_avg_x
}
Now we calculate the weighted mean of the same square matrix along the second axis (axes: [1]): you calculate per row, so you aggregate/reduce along the second axis.
w = ~M[
10 20
30 40
]
nx_w_avg_y = Nx.weighted_mean(m, w, axes: [1])
man_w_avg_y = [(1 * 10 + 2 * 20) / (10 + 20), (3 * 30 + 4 * 40) / (30 + 40)]
# [ 50/30, 250/70]
{
nx_w_avg_y,
man_w_avg_y
}
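Both per-axis results follow the same recipe: sum the weighted values along the chosen axis and divide by the sum of the weights along that axis. The sketch below assumes Nx from the setup cell; weighted_mean_by_hand is a hypothetical helper written for this illustration, not part of Nx.

```elixir
# Assumes the Mix.install cell at the top of the notebook has already run.
m = Nx.tensor([[1, 2], [3, 4]])
w = Nx.tensor([[10, 20], [30, 40]])

# Hypothetical helper: weighted mean along `axes`, built from sum/multiply/divide.
weighted_mean_by_hand = fn tensor, weights, axes ->
  numerator = Nx.sum(Nx.multiply(tensor, weights), axes: axes)
  denominator = Nx.sum(weights, axes: axes)
  Nx.divide(numerator, denominator)
end

per_axis =
  for axis <- [0, 1], into: %{} do
    by_hand = weighted_mean_by_hand.(m, w, [axis])
    from_nx = Nx.weighted_mean(m, w, axes: [axis])
    # all_close returns a 0/1 scalar tensor; 1 means both results agree.
    {axis, {Nx.to_flat_list(by_hand), Nx.to_number(Nx.all_close(by_hand, from_nx))}}
  end
```

For axis 0 the numerators are [100, 200] and the denominators [40, 60]; for axis 1 they are [50, 250] and [30, 70], matching the manual lists above.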
Example with higher rank
The example below will be used throughout the documentation.
Take a list of numbers of length $n=36$ and turn it into a tensor with the shape {2,3,2,3}; this is a tensor of rank 4. We will name the axes with names: [:x, :y, :z, :t].
t = Nx.tensor(
[
[
[
[1, 2, 3],
[1, -2, 3]
],
[
[4, -3, 2],
[-4, 3, 2]
],
[
[5, -1, 3],
[-5, 1, 3]
]
],
[
[
[4, 6, 1],
[4, -6, 1]
],
[
[1, 2, 3],
[1, -2, 3]
],
[
[4, -3, 2],
[-4, 3, 2]
]
]
],
names: [:x, :y, :z, :t]
)
With the shape {2,3,2,3}, you obtain slices of length $2$, then $3$, then $2$ and, finally, $3$.
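One way to see these nested slices is to index into the tensor one axis at a time and look at the shape that remains. This is a small sketch, assuming Nx from the setup cell; it rebuilds the same tensor from a flat list of 36 numbers so that the cell runs on its own, and the name slices is illustrative.

```elixir
# Assumes the Mix.install cell at the top of the notebook has already run.
t =
  Nx.tensor([
    1, 2, 3, 1, -2, 3, 4, -3, 2, -4, 3, 2, 5, -1, 3, -5, 1, 3,
    4, 6, 1, 4, -6, 1, 1, 2, 3, 1, -2, 3, 4, -3, 2, -4, 3, 2
  ])
  |> Nx.reshape({2, 3, 2, 3}, names: [:x, :y, :z, :t])

slices = %{
  whole: Nx.shape(t),
  # Fixing x leaves the axes y, z, t.
  x_slice: {Nx.names(t[0]), Nx.shape(t[0])},
  # Fixing x and y leaves z, t.
  y_slice: Nx.shape(t[0][0]),
  # Fixing x, y and z leaves a single row along t.
  z_slice: Nx.shape(t[0][0][0]),
  row_x1_y0_z1: Nx.to_flat_list(t[1][0][1])
}
# slices.x_slice      => {[:y, :z, :t], {3, 2, 3}}
# slices.row_x1_y0_z1 => [4, -6, 1]
```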
The picture below will help to understand the aggregations.
$$
\begin{bmatrix}
x=0 & \begin{bmatrix}
y=0 & \begin{bmatrix} z=0 & a_{0,0,0,0} & a_{0,0,0,1} & a_{0,0,0,2} \\ z=1 & a_{0,0,1,0} & a_{0,0,1,1} & a_{0,0,1,2} \end{bmatrix} \\
y=1 & \begin{bmatrix} z=0 & a_{0,1,0,0} & a_{0,1,0,1} & a_{0,1,0,2} \\ z=1 & a_{0,1,1,0} & a_{0,1,1,1} & a_{0,1,1,2} \end{bmatrix} \\
y=2 & \begin{bmatrix} z=0 & a_{0,2,0,0} & a_{0,2,0,1} & a_{0,2,0,2} \\ z=1 & a_{0,2,1,0} & a_{0,2,1,1} & a_{0,2,1,2} \end{bmatrix}
\end{bmatrix} \\
x=1 & \begin{bmatrix}
y=0 & \begin{bmatrix} z=0 & a_{1,0,0,0} & a_{1,0,0,1} & a_{1,0,0,2} \\ z=1 & a_{1,0,1,0} & a_{1,0,1,1} & a_{1,0,1,2} \end{bmatrix} \\
y=1 & \begin{bmatrix} z=0 & a_{1,1,0,0} & a_{1,1,0,1} & a_{1,1,0,2} \\ z=1 & a_{1,1,1,0} & a_{1,1,1,1} & a_{1,1,1,2} \end{bmatrix} \\
y=2 & \begin{bmatrix} z=0 & a_{1,2,0,0} & a_{1,2,0,1} & a_{1,2,0,2} \\ z=1 & a_{1,2,1,0} & a_{1,2,1,1} & a_{1,2,1,2} \end{bmatrix}
\end{bmatrix}
\end{bmatrix}
\equiv
\begin{bmatrix}
\begin{bmatrix}
\begin{bmatrix} & 1, 2, 3 & \\ & 1, -2, 3 & \end{bmatrix} \\
\begin{bmatrix} & 4, -3, 2 & \\ & -4, 3, 2 & \end{bmatrix} \\
\begin{bmatrix} & 5, -1, 3 & \\ & -5, 1, 3 & \end{bmatrix}
\end{bmatrix} \\
\begin{bmatrix}
\begin{bmatrix} & 4, 6, 1 & \\ & 4, -6, 1 & \end{bmatrix} \\
\begin{bmatrix} & 1, 2, 3 & \\ & 1, -2, 3 & \end{bmatrix} \\
\begin{bmatrix} & 4, -3, 2 & \\ & -4, 3, 2 & \end{bmatrix}
\end{bmatrix}
\end{bmatrix}
$$
Firstly, let's check some characteristics of this tensor using full-tensor aggregation functions.
%{
dimensions: %{
x: Nx.axis_size(t, :x),
y: Nx.axis_size(t, :y),
z: Nx.axis_size(t, :z),
t: Nx.axis_size(t, :t)
},
n: Nx.size(t),
shape: Nx.shape(t),
rank: Nx.rank(t),
analysis: %{
most_frequent_nb: Nx.mode(t),
smallest_element: %{
value: Nx.reduce_min(t),
position: Nx.argmin(t)
},
greatest_element: %{
value: Nx.reduce_max(t),
position: Nx.argmax(t)
},
no_zero_nb: Nx.all(t)
}
}
Single row aggregation, along an axis
❗ We are going to use the key :axis below. The functions argmin, argmax, median and mode take the singular option axis: <atom or integer>. All of the other aggregating functions take axes: [<atom or integer>]; these have multi-axis aggregation implemented.
When you aggregate along an axis, that axis collapses: the tensor is reshaped and the values along the axis are combined by the aggregating function; in other words, you perform a reduction. It is important to consider the ordering of the axes.
Aggregate along the first axis, axis: :x (equivalently axis: 0).
The shape of our original tensor t is [x: 2, y: 3, z: 2, t: 3]. When we aggregate along the axis :x, it collapses and the shape of the resultant tensor r is [y: 3, z: 2, t: 3]. The rule is that every axis before the selected one remains in the structure of the tensor, as does every axis after the selected one.
How? You collect from each slice of x the elements that have the same remaining indexes "on the right", $y,z,t$. In the case of our tensor, you have 2 x-slices, so you build a sublist of 2 elements and then apply the aggregating function to it. This gives you the value of the resulting tensor $r$ at the location $r[y][z][t]$:
$$ \rm{agg}\big(t[0][y][z][t], t[1][y][z][t]\big) = r[y][z][t] $$
The aggregating function used below is argmin: it returns the index of the smallest element of the sublist. When the smallest value occurs several times, the :tie_break option selects the lowest index by default.
$$
\begin{bmatrix}
\begin{bmatrix}
\begin{bmatrix} & \boxed{1}, 2, 3 & \\ & 1, -2, 3 & \end{bmatrix} \\
\begin{bmatrix} & 4, -3, 2 & \\ & -4, 3, 2 & \end{bmatrix} \\
\begin{bmatrix} & 5, -1, 3 & \\ & -5, 1, 3 & \end{bmatrix}
\end{bmatrix} \\
\begin{bmatrix}
\begin{bmatrix} & \boxed{4}, 6, 1 & \\ & 4, -6, 1 & \end{bmatrix} \\
\begin{bmatrix} & 1, 2, 3 & \\ & 1, -2, 3 & \end{bmatrix} \\
\begin{bmatrix} & 4, -3, 2 & \\ & -4, 3, 2 & \end{bmatrix}
\end{bmatrix}
\end{bmatrix}
\to
\begin{bmatrix}
\begin{bmatrix} & \boxed{\rm{agg}(1,4)}, \rm{agg}(2,6), \rm{agg}(3,1) & \\ & \rm{agg}(1,4), \rm{agg}(-2,-6), \rm{agg}(3,1) & \end{bmatrix} \\
\begin{bmatrix} & \rm{agg}(4,1), \rm{agg}(-3,2), \rm{agg}(2,3) & \\ & \rm{agg}(-4,1), \rm{agg}(3,-2), \rm{agg}(2,3) & \end{bmatrix} \\
\begin{bmatrix} & \rm{agg}(5,4), \rm{agg}(-1,-3), \rm{agg}(3,2) & \\ & \rm{agg}(-5,-4), \rm{agg}(1,3), \rm{agg}(3,2) & \end{bmatrix}
\end{bmatrix}
\xrightarrow[]{\rm{argmin}}
\begin{bmatrix}
\begin{bmatrix} & 0, 0, 1 & \\ & 0, 1, 1 & \end{bmatrix} \\
\begin{bmatrix} & 1, 0, 0 & \\ & 0, 1, 0 & \end{bmatrix} \\
\begin{bmatrix} & 1, 1, 1 & \\ & 0, 0, 1 & \end{bmatrix}
\end{bmatrix}
$$
Nx.argmin(t, axis: :x)
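With only two x-slices, the argmin along :x can be cross-checked by comparing the slices directly: the index is 1 exactly where the x=1 value is strictly smaller, and ties keep index 0. This is a sketch under that observation, assuming Nx from the setup cell; the tensor is rebuilt from its flat list so the cell runs on its own, and by_hand, argmin_x and check are illustrative names.

```elixir
# Assumes the Mix.install cell at the top of the notebook has already run.
t =
  Nx.tensor([
    1, 2, 3, 1, -2, 3, 4, -3, 2, -4, 3, 2, 5, -1, 3, -5, 1, 3,
    4, 6, 1, 4, -6, 1, 1, 2, 3, 1, -2, 3, 4, -3, 2, -4, 3, 2
  ])
  |> Nx.reshape({2, 3, 2, 3}, names: [:x, :y, :z, :t])

# 1 where t[1][y][z][t] < t[0][y][z][t], 0 otherwise (a tie gives 0).
by_hand = Nx.less(t[1], t[0])
argmin_x = Nx.argmin(t, axis: :x)

# The element types differ (comparison mask vs. index), so compare the values as lists.
check = {Nx.shape(argmin_x), Nx.to_flat_list(by_hand) == Nx.to_flat_list(argmin_x)}
# => {{3, 2, 3}, true}
```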
Aggregate along the second axis, axis: :y (equivalently axis: 1).
The axis y will collapse and the shape of the resultant tensor r is [x: 2, z: 2, t: 3].
How? The axis $x$ comes before the selected axis $y$, so it remains. We therefore apply the same procedure as above within each $x$-slice.
More precisely, fix an $x$-slice, say $x=1$. You collect from each sub y-slice the elements with the same remaining indexes "on the right", $z,t$ (and the same $x$, of course). You apply the aggregation and the result is the value at location $r[1][z][t]$ of the resultant tensor $r$. So the elements of the slice $x=1$ of the resultant tensor $r$ will be:
$$ r[1][z][t] = \rm{agg}\big(t[1][0][z][t], t[1][1][z][t], t[1][2][z][t]\big) $$
You repeat this operation for each $x$-slice.
In the example below, we again use the argmin function.
$$
\begin{bmatrix}
\begin{bmatrix}
\begin{bmatrix} & \boxed{1}, 2, 3 & \\ & 1, -2, 3 & \end{bmatrix} \\
\begin{bmatrix} & \boxed{4}, -3, 2 & \\ & -4, 3, 2 & \end{bmatrix} \\
\begin{bmatrix} & \boxed{5}, -1, 3 & \\ & -5, 1, 3 & \end{bmatrix}
\end{bmatrix} \\
\begin{bmatrix}
\begin{bmatrix} & 4, 6, 1 & \\ & 4, \boxed{-6}, 1 & \end{bmatrix} \\
\begin{bmatrix} & 1, 2, 3 & \\ & 1, \boxed{-2}, 3 & \end{bmatrix} \\
\begin{bmatrix} & 4, -3, 2 & \\ & -4, \boxed{3}, 2 & \end{bmatrix}
\end{bmatrix}
\end{bmatrix}
\to
\begin{bmatrix}
\begin{bmatrix} & \boxed{\rm{agg}(1,4,5)}, \rm{agg}(2,-3,-1), \rm{agg}(3,2,3) & \\ & \rm{agg}(1,-4,-5), \rm{agg}(-2,3,1), \rm{agg}(3,2,3) & \end{bmatrix} \\
\begin{bmatrix} & \rm{agg}(4,1,4), \rm{agg}(6,2,-3), \rm{agg}(1,3,2) & \\ & \rm{agg}(4,1,-4), \boxed{\rm{agg}(-6,-2,3)}, \rm{agg}(1,3,2) & \end{bmatrix}
\end{bmatrix}
\xrightarrow[]{\rm{argmin}}
\begin{bmatrix}
\begin{bmatrix} & 0, 1, 1 & \\ & 2, 0, 1 & \end{bmatrix} \\
\begin{bmatrix} & 1, 2, 0 & \\ & 2, 0, 0 & \end{bmatrix}
\end{bmatrix}
$$
Nx.argmin(t, axis: 1)
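A single entry of this result can be rebuilt by hand: fix the indexes that remain, collect the values along :y into a sublist, and look up the position of the smallest one. The sketch below does this for the fixed indexes x = 0, z = 0, t = 1, assuming Nx from the setup cell; the tensor is rebuilt from its flat list, and sublist, by_hand and from_nx are illustrative names.

```elixir
# Assumes the Mix.install cell at the top of the notebook has already run.
t =
  Nx.tensor([
    1, 2, 3, 1, -2, 3, 4, -3, 2, -4, 3, 2, 5, -1, 3, -5, 1, 3,
    4, 6, 1, 4, -6, 1, 1, 2, 3, 1, -2, 3, 4, -3, 2, -4, 3, 2
  ])
  |> Nx.reshape({2, 3, 2, 3}, names: [:x, :y, :z, :t])

# Values along :y for the fixed indexes x = 0, z = 0, t = 1.
sublist = for y <- 0..2, do: Nx.to_number(t[0][y][0][1])

# Position of the first occurrence of the smallest value (lowest index on ties).
by_hand = Enum.find_index(sublist, &(&1 == Enum.min(sublist)))

# The same entry read from the aggregated tensor r[x][z][t].
from_nx = Nx.to_number(Nx.argmin(t, axis: :y)[0][0][1])

entry = {sublist, by_hand, from_nx}
# => {[2, -3, -1], 1, 1}
```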
Aggregate along the third axis, axis: :z (equivalently axis: 2).
The z axis collapses and the shape of the resultant tensor r is [x: 2, y: 3, t: 3]. As you now understand, for each $x$-slice and each $y$-slice "on the left", we collect the elements with the same remaining indexes "on the right", here $t$.
$$ r[x][y][t] = \rm{agg}\big(t[x][y][0][t], t[x][y][1][t]\big) $$
In the example below, we use the argmax function.
$$
\begin{bmatrix}
x=0 & \begin{bmatrix}
y=0 & \begin{bmatrix} & \boxed{1}, 2, 3 & \\ & \boxed{1}, -2, 3 & \end{bmatrix} \\
y=1 & \begin{bmatrix} & 4, -3, 2 & \\ & -4, 3, 2 & \end{bmatrix} \\
y=2 & \begin{bmatrix} & 5, -1, 3 & \\ & -5, 1, 3 & \end{bmatrix}
\end{bmatrix} \\
x=1 & \begin{bmatrix}
y=0 & \begin{bmatrix} & 4, \boxed{6}, 1 & \\ & 4, \boxed{-6}, 1 & \end{bmatrix} \\
y=1 & \begin{bmatrix} & 1, 2, 3 & \\ & 1, -2, 3 & \end{bmatrix} \\
y=2 & \begin{bmatrix} & 4, -3, 2 & \\ & -4, 3, 2 & \end{bmatrix}
\end{bmatrix}
\end{bmatrix}
\to
\begin{bmatrix}
x=0 & \begin{bmatrix}
& \boxed{\rm{agg}(1,1)}, \rm{agg}(2,-2), \rm{agg}(3,3) & \\
& \rm{agg}(4,-4), \rm{agg}(-3,3), \rm{agg}(2,2) & \\
& \rm{agg}(5,-5), \rm{agg}(-1,1), \rm{agg}(3,3) &
\end{bmatrix} \\
x=1 & \begin{bmatrix}
& \rm{agg}(4,4), \boxed{\rm{agg}(6,-6)}, \rm{agg}(1,1) & \\
& \rm{agg}(1,1), \rm{agg}(2,-2), \rm{agg}(3,3) & \\
& \rm{agg}(4,-4), \rm{agg}(-3,3), \rm{agg}(2,2) &
\end{bmatrix}
\end{bmatrix}
\xrightarrow[]{\rm{argmax}}
\begin{bmatrix}
\begin{bmatrix} 0, 0, 0 \\ 0, 1, 0 \\ 0, 1, 0 \end{bmatrix} \\
\begin{bmatrix} 0, 0, 0 \\ 0, 0, 0 \\ 0, 1, 0 \end{bmatrix}
\end{bmatrix}
$$
Nx.argmax(t, axis: :z)
Aggregate along the last axis, axis: :t (equivalently axis: 3 or axis: -1).
This will reshape the tensor to [x: 2, y: 3, z: 2]. Since this is the last index, having in mind the matrix picture above, you aggregate along each row:
$$ r[x][y][z] = \rm{agg}\big(t[x][y][z][0], t[x][y][z][1], t[x][y][z][2]\big) $$
and repeat this for each $x$-slice, each sub $y$-slice, each sub-sub $z$-slice.
Below is the result for the function argmin.
$$
\begin{bmatrix}
x=0 & \begin{bmatrix}
y=0 & \begin{bmatrix} & \boxed{1, 2, 3} & \\ & 1, -2, 3 & \end{bmatrix} \\
y=1 & \begin{bmatrix} & 4, -3, 2 & \\ & -4, 3, 2 & \end{bmatrix} \\
y=2 & \begin{bmatrix} & 5, -1, 3 & \\ & -5, 1, 3 & \end{bmatrix}
\end{bmatrix} \\
x=1 & \begin{bmatrix}
y=0 & \begin{bmatrix} & 4, 6, 1 & \\ & 4, -6, 1 & \end{bmatrix} \\
y=1 & \begin{bmatrix} & 1, 2, 3 & \\ & \boxed{1, -2, 3} & \end{bmatrix} \\
y=2 & \begin{bmatrix} & 4, -3, 2 & \\ & -4, 3, 2 & \end{bmatrix}
\end{bmatrix}
\end{bmatrix}
\xrightarrow[]{\rm{argmin}}
\begin{bmatrix}
\begin{bmatrix} & \boxed{0}, 1 & \\ & 1, 0 & \\ & 1, 0 & \end{bmatrix} \\
\begin{bmatrix} & 2, 1 & \\ & 0, \boxed{1} & \\ & 1, 0 & \end{bmatrix}
\end{bmatrix}
$$
Nx.argmin(t, axis: 3)
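The three spellings of the last axis (by name, by position, by negative position) should select the same axis. A short sketch to confirm it, assuming Nx from the setup cell; the tensor is rebuilt from its flat list, and the variable names are illustrative.

```elixir
# Assumes the Mix.install cell at the top of the notebook has already run.
t =
  Nx.tensor([
    1, 2, 3, 1, -2, 3, 4, -3, 2, -4, 3, 2, 5, -1, 3, -5, 1, 3,
    4, 6, 1, 4, -6, 1, 1, 2, 3, 1, -2, 3, 4, -3, 2, -4, 3, 2
  ])
  |> Nx.reshape({2, 3, 2, 3}, names: [:x, :y, :z, :t])

by_name = Nx.argmin(t, axis: :t)
by_index = Nx.argmin(t, axis: 3)
by_negative = Nx.argmin(t, axis: -1)

# One index per row of three values, so 12 indexes in total.
last_axis = %{
  shape: Nx.shape(by_name),
  values: Nx.to_flat_list(by_name),
  all_equal:
    Nx.to_flat_list(by_name) == Nx.to_flat_list(by_index) and
      Nx.to_flat_list(by_index) == Nx.to_flat_list(by_negative)
}
# last_axis.values => [0, 1, 1, 0, 1, 0, 2, 1, 0, 1, 1, 0]
```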
Option :tie_break
You have the :tie_break option to decide which index to return when the result occurs several times. It defaults to tie_break: :low.
t4 = ~V[2 0 0 0 1]
%{
argmin_with_default: Nx.argmin(t4) |> Nx.to_number(),
argmin_with_tie_break_high_option: Nx.argmin(t4, tie_break: :high) |> Nx.to_number()
}
Option :keep_axis
Its default value is false. When this option is set to keep_axis: true, you reduce but keep the working axis with a dimension of $1$. For example, with t of shape {2,3,2,3}, when you reduce along the third axis, :z, we saw that the shape is {2,3,3}, but when you keep the axis, the shape is {2,3,1,3}:
Nx.argmax(t, axis: 2, keep_axis: true) ==
Nx.argmax(t, axis: 2) |> Nx.reshape({2, 3, 1, 3}, names: [:x, :y, :z, :t])
$$
\begin{bmatrix}
\begin{bmatrix} 0, 0, 0 \\ 0, 1, 0 \\ 0, 1, 0 \end{bmatrix} \\
\begin{bmatrix} 0, 0, 0 \\ 0, 0, 0 \\ 0, 1, 0 \end{bmatrix}
\end{bmatrix}
\to
\begin{bmatrix}
\begin{bmatrix}
\begin{bmatrix} 0, 0, 0 \end{bmatrix} \\
\begin{bmatrix} 0, 1, 0 \end{bmatrix} \\
\begin{bmatrix} 0, 1, 0 \end{bmatrix}
\end{bmatrix} \\
\begin{bmatrix}
\begin{bmatrix} 0, 0, 0 \end{bmatrix} \\
\begin{bmatrix} 0, 0, 0 \end{bmatrix} \\
\begin{bmatrix} 0, 1, 0 \end{bmatrix}
\end{bmatrix}
\end{bmatrix}
$$
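A kept axis is mostly useful because the result still lines up with the original tensor and can be broadcast against it. The sketch below, assuming Nx from the setup cell, compares the shapes with and without the kept axis and then subtracts the per-z maximum from t. Note that reductions such as Nx.reduce_max spell the option keep_axes (plural), while argmin/argmax use keep_axis; the names dropped, kept, reinserted and centered are illustrative.

```elixir
# Assumes the Mix.install cell at the top of the notebook has already run.
t =
  Nx.tensor([
    1, 2, 3, 1, -2, 3, 4, -3, 2, -4, 3, 2, 5, -1, 3, -5, 1, 3,
    4, 6, 1, 4, -6, 1, 1, 2, 3, 1, -2, 3, 4, -3, 2, -4, 3, 2
  ])
  |> Nx.reshape({2, 3, 2, 3}, names: [:x, :y, :z, :t])

dropped = Nx.argmax(t, axis: :z)
kept = Nx.argmax(t, axis: :z, keep_axis: true)

# Re-inserting a size-1 axis named :z at position 2 gives the same layout as keep_axis: true.
reinserted = Nx.new_axis(dropped, 2, :z)

# The kept axis broadcasts against t: every value minus the maximum of its z-pair.
centered = Nx.subtract(t, Nx.reduce_max(t, axes: [:z], keep_axes: true))

shapes = %{
  dropped: Nx.shape(dropped),
  kept: {Nx.names(kept), Nx.shape(kept)},
  same_values: Nx.to_flat_list(kept) == Nx.to_flat_list(reinserted),
  centered: Nx.shape(centered)
}
# shapes.dropped => {2, 3, 3}
# shapes.kept    => {[:x, :y, :z, :t], {2, 3, 1, 3}}
```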
Multi-row aggregation
Suppose you want to aggregate along the axes x and z. Then you should get a tensor of shape [y: 3, t: 3]. For each pair of indexes $(y,t)$, you aggregate all the numbers obtained by letting x and z vary.
$$ r[y][t] = \rm{agg}\big( t[0][y][0][t], t[0][y][1][t], t[1][y][0][t], t[1][y][1][t]\big) $$
$$
\begin{bmatrix}
\begin{bmatrix}
\begin{bmatrix} & \boxed{1}, 2, 3 & \\ & \boxed{1}, -2, 3 & \end{bmatrix} \\
\begin{bmatrix} & 4, -3, 2 & \\ & -4, 3, 2 & \end{bmatrix} \\
\begin{bmatrix} & 5, -1, 3 & \\ & -5, 1, 3 & \end{bmatrix}
\end{bmatrix} \\
\begin{bmatrix}
\begin{bmatrix} & \boxed{4}, 6, 1 & \\ & \boxed{4}, -6, 1 & \end{bmatrix} \\
\begin{bmatrix} & 1, 2, 3 & \\ & 1, -2, 3 & \end{bmatrix} \\
\begin{bmatrix} & 4, -3, 2 & \\ & -4, 3, 2 & \end{bmatrix}
\end{bmatrix}
\end{bmatrix}
\to
\begin{bmatrix}
\rm{agg}(1,1,4,4), \rm{agg}(2,-2,6,-6), \rm{agg}(3,3,1,1) \\
\rm{agg}(4,-4,1,1), \rm{agg}(-3,3,2,-2), \rm{agg}(2,2,3,3) \\
\rm{agg}(5,-5,4,-4), \rm{agg}(-1,1,-3,3), \rm{agg}(3,3,2,2)
\end{bmatrix}
$$
From this, it is easier to understand what the aggregation returns:
Nx.reduce_min(t, axes: [0, -1])
Nx.reduce_max(t, axes: [:x, :t])
Nx.mean(t, axes: [:x, :y])
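The formula for $r[y][t]$ above uses the pair of axes x and z, so here is that exact case as well. This is a sketch assuming Nx from the setup cell, with the tensor rebuilt from its flat list; it also shows that, for min and max, reducing both axes at once matches reducing them one after the other. The names min_xz, max_xz, max_x_then_z and multi are illustrative.

```elixir
# Assumes the Mix.install cell at the top of the notebook has already run.
t =
  Nx.tensor([
    1, 2, 3, 1, -2, 3, 4, -3, 2, -4, 3, 2, 5, -1, 3, -5, 1, 3,
    4, 6, 1, 4, -6, 1, 1, 2, 3, 1, -2, 3, 4, -3, 2, -4, 3, 2
  ])
  |> Nx.reshape({2, 3, 2, 3}, names: [:x, :y, :z, :t])

# Each entry r[y][t] aggregates the four values t[x][y][z][t] for x in 0..1 and z in 0..1.
min_xz = Nx.reduce_min(t, axes: [:x, :z])
max_xz = Nx.reduce_max(t, axes: [:x, :z])

# Reducing :x first and :z second gives the same maxima.
max_x_then_z = t |> Nx.reduce_max(axes: [:x]) |> Nx.reduce_max(axes: [:z])

multi = %{
  names: Nx.names(min_xz),
  shape: Nx.shape(min_xz),
  min: Nx.to_flat_list(min_xz),
  max: Nx.to_flat_list(max_xz),
  sequential_matches: Nx.to_flat_list(max_xz) == Nx.to_flat_list(max_x_then_z)
}
# multi.min => [1, -6, 1, -4, -3, 2, -5, -3, 2]
# multi.max => [4, 6, 3, 4, 3, 3, 5, 3, 3]
```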