<!-- livebook:{"persist_outputs":true} -->

# Aggregation

```elixir
Mix.install([
  {:nx, "~> 0.9"}
])
```

```elixir
import Nx, only: :sigils
```

<!-- livebook:{"output":true} -->

```
Nx
```

## What is aggregation?

Aggregation is the process of reducing a tensor to a single value or a smaller tensor by applying specific operations across its dimensions. You can apply [aggregation functions](https://hexdocs.pm/nx/Nx.html#functions-aggregates) on any tensor. The functions can be applied to the tensor as a whole or to a subsection of the tensor taken in an axis-wise fashion.

As a first example, let's take a 2D tensor of shape `{2, 3}`. Notice that we can _name_ the axes. The elements of the tensor `m` are the `m[i][j]`, where $i$ is the row index and $j$ the column index.

```elixir
m = Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y])
```

<!-- livebook:{"output":true} -->

```
#Nx.Tensor<
  s64[x: 2][y: 3]
  [
    [1, 2, 3],
    [4, 5, 6]
  ]
>
```

You can get the maximum number in a tensor with [Nx.reduce_max(m)](https://hexdocs.pm/nx/Nx.html#reduce_max/2), which returns a 0D (scalar) tensor. A reduction removes one dimension per reduced axis, and since we applied the reduction globally, we lose all the dimensions. Here it returns a scalar tensor holding `6`.

We can get the maximum number for each _row_ with `Nx.reduce_max(m, axes: [:y])`, returning a 1D tensor of size 2. Why `:y`? Because for each row, we reduce along the `:y` axis.

We can get the maximum number for each _column_ with `Nx.reduce_max(m, axes: [:x])`, returning a 1D tensor of size 3. For each column, we reduce along the axis `:x`.

```elixir
max = Nx.reduce_max(m)
max_x = Nx.reduce_max(m, axes: [:x])
max_y = Nx.reduce_max(m, axes: [:y])
%{max: max, max_x: max_x, max_y: max_y}
```

<!-- livebook:{"output":true} -->

```
%{
  max: #Nx.Tensor<
    s64
    6
  >,
  max_x: #Nx.Tensor<
    s64[y: 3]
    [4, 5, 6]
  >,
  max_y: #Nx.Tensor<
    s64[x: 2]
    [3, 6]
  >
}
```
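As a quick cross-check of the row and column reductions, the sketch below compares `Nx.reduce_max/2` with the same computation on plain Elixir lists. It assumes the setup cell of this notebook has installed Nx; the names `rows` and `grid` are introduced here for illustration only.

```elixir
rows = [[1, 2, 3], [4, 5, 6]]
grid = Nx.tensor(rows, names: [:x, :y])

# Reducing along :y keeps one value per row (the :x axis survives).
per_row = grid |> Nx.reduce_max(axes: [:y]) |> Nx.to_list()
# Same idea with plain lists: the maximum of each inner list.
per_row_manual = Enum.map(rows, &Enum.max/1)

# Reducing along :x keeps one value per column (the :y axis survives).
per_col = grid |> Nx.reduce_max(axes: [:x]) |> Nx.to_list()
# Columns are obtained by zipping the rows together.
per_col_manual =
  rows
  |> Enum.zip()
  |> Enum.map(fn column -> column |> Tuple.to_list() |> Enum.max() end)

%{
  per_row_matches: per_row == per_row_manual,
  per_col_matches: per_col == per_col_manual,
  # A global reduction leaves no axis at all.
  global_rank: Nx.rank(Nx.reduce_max(grid))
}
```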

Let's consider another example with [Nx.weighted_mean](https://hexdocs.pm/nx/Nx.html#weighted_mean/3). It supports full-tensor and per-axis operations. The example below computes the _weighted mean aggregate_ of a 2D tensor of shape `{2, 2}` labeled `m`:

```elixir
m = ~MAT[
  1 2
  3 4
]
```

<!-- livebook:{"output":true} -->

```
#Nx.Tensor<
  s64[2][2]
  [
    [1, 2],
    [3, 4]
  ]
>
```

First, we'll compute the full-tensor aggregation, with the calculation spelled out below. We take the element-wise product (also known as the [Hadamard product](<https://en.wikipedia.org/wiki/Hadamard_product_(matrices)>)) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights.

```elixir
w = ~MAT[
  10 20
  30 40
]

nx_w_avg = Nx.weighted_mean(m, w)

man_w_avg = (1 * 10 + 2 * 20 + 3 * 30 + 4 * 40) / (10 + 20 + 30 + 40)
# 300/100

%{
  nx_weighted_avg: Nx.to_number(nx_w_avg),
  manual_weighted_avg: man_w_avg
}
```

<!-- livebook:{"output":true} -->

```
%{nx_weighted_avg: 3.0, manual_weighted_avg: 3.0}
```
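The same full-tensor weighted mean can also be rebuilt from Nx primitives, which makes the three steps (element-wise product, sum, division) explicit. This is only a sketch of the arithmetic described above, not a claim about how `Nx.weighted_mean/3` is implemented; `values` and `weights` are illustrative names.

```elixir
values = Nx.tensor([[1, 2], [3, 4]])
weights = Nx.tensor([[10, 20], [30, 40]])

# Element-wise (Hadamard) product, followed by a full-tensor sum.
weighted_sum = values |> Nx.multiply(weights) |> Nx.sum()
# Divide by the total weight.
by_hand = Nx.divide(weighted_sum, Nx.sum(weights))

%{
  by_hand: Nx.to_number(by_hand),
  nx: Nx.to_number(Nx.weighted_mean(values, weights))
}
```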

The weighted mean can be computed _per axis_. Let's compute it along the _first_ axis (`axes: [0]`): you calculate "by column", so you aggregate/reduce along the first axis:

```elixir
w = ~MAT[
  10 20
  30 40
]

w_avg_x = Nx.weighted_mean(m, w, axes: [0])

man_w_avg_x = [(1 * 10 + 3 * 30) / (10 + 30), (2 * 20 + 4 * 40) / (20 + 40)]
# [100/40, 200/60]
{
  w_avg_x,
  man_w_avg_x
}
```

<!-- livebook:{"output":true} -->

```
{#Nx.Tensor<
   f32[2]
   [2.5, 3.3333332538604736]
 >, [2.5, 3.3333333333333335]}
```

Now we calculate the weighted mean along the _second_ axis (`axes: [1]`): you calculate "by row", so you aggregate/reduce along the second axis.

```elixir
w = ~MAT[
  10 20
  30 40
]

nx_w_avg_y = Nx.weighted_mean(m, w, axes: [1])

man_w_avg_y = [(1 * 10 + 2 * 20) / (10 + 20), (3 * 30 + 4 * 40) / (30 + 40)]
# [ 50/30, 250/70]
{
  nx_w_avg_y,
  man_w_avg_y
}
```

<!-- livebook:{"output":true} -->

```
{#Nx.Tensor<
   f32[2]
   [1.6666666269302368, 3.5714285373687744]
 >, [1.6666666666666667, 3.5714285714285716]}
```
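The per-axis variants follow the same recipe, with both sums restricted to the chosen axis. The anonymous function `weighted_mean_along` below is a hypothetical helper written for this comparison; it is checked against `Nx.weighted_mean/3` with `Nx.all_close/2` because the results are floats.

```elixir
values = Nx.tensor([[1, 2], [3, 4]])
weights = Nx.tensor([[10, 20], [30, 40]])

# Hypothetical helper: weighted mean along a single axis, from primitives.
weighted_mean_along = fn axis ->
  numerator = values |> Nx.multiply(weights) |> Nx.sum(axes: [axis])
  denominator = Nx.sum(weights, axes: [axis])
  Nx.divide(numerator, denominator)
end

# 1 means "all elements are close" for the given axis.
axis_checks =
  for axis <- [0, 1] do
    expected = Nx.weighted_mean(values, weights, axes: [axis])
    weighted_mean_along.(axis) |> Nx.all_close(expected) |> Nx.to_number()
  end
```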

#### Example with higher rank

The example below will be used throughout the documentation.

Take a list of numbers of length $n=36$, and turn it into a tensor with the shape `{2,3,2,3}`; this is a tensor of rank 4. We will _name_ the axes, with `names: [:x, :y, :z, :t]`.

```elixir
t =
  Nx.tensor(
    [
      [
        [
          [1, 2, 3],
          [1, -2, 3]
        ],
        [
          [4, -3, 2],
          [-4, 3, 2]
        ],
        [
          [5, -1, 3],
          [-5, 1, 3]
        ]
      ],
      [
        [
          [4, 6, 1],
          [4, -6, 1]
        ],
        [
          [1, 2, 3],
          [1, -2, 3]
        ],
        [
          [4, -3, 2],
          [-4, 3, 2]
        ]
      ]
    ],
    names: [:x, :y, :z, :t]
  )
```

<!-- livebook:{"output":true} -->

```
#Nx.Tensor<
  s64[x: 2][y: 3][z: 2][t: 3]
  [
    [
      [
        [1, 2, 3],
        [1, -2, 3]
      ],
      [
        [4, -3, 2],
        [-4, 3, 2]
      ],
      [
        [5, -1, 3],
        [-5, 1, 3]
      ]
    ],
    [
      [
        [4, 6, 1],
        [4, -6, 1]
      ],
      [
        [1, 2, 3],
        [1, -2, 3]
      ],
      [
        [4, -3, 2],
        [-4, 3, 2]
      ]
    ]
  ]
>
```
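An equivalent way to obtain this tensor, sketched below, is to start from the flat list of 36 numbers and reshape it. The names `flat` and `reshaped` are illustrative; the indices `[1][0][1][1]` pick the element at $x=1, y=0, z=1, t=1$.

```elixir
flat = [
  1, 2, 3, 1, -2, 3, 4, -3, 2, -4, 3, 2, 5, -1, 3, -5, 1, 3,
  4, 6, 1, 4, -6, 1, 1, 2, 3, 1, -2, 3, 4, -3, 2, -4, 3, 2
]

# Row-major reshape: the last axis (:t) varies fastest.
reshaped =
  flat
  |> Nx.tensor()
  |> Nx.reshape({2, 3, 2, 3}, names: [:x, :y, :z, :t])

%{
  length: length(flat),
  shape: Nx.shape(reshaped),
  element_at_1_0_1_1: Nx.to_number(reshaped[1][0][1][1])
}
```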

With the shape `{2,3,2,3}`, you obtain slices of length $2$, then $3$, then $2$ and, finally, $3$.
The picture below will help to understand the aggregations.

<!-- livebook:{"break_markdown":true} -->

$$
\begin{bmatrix}
x=0 &
\begin{bmatrix}
y=0 &
\begin{bmatrix}
z=0 & a_{0,0,0,0} & a_{0,0,0,1} & a_{0,0,0,2} \\
z=1 & a_{0,0,1,0} & a_{0,0,1,1} & a_{0,0,1,2}
\end{bmatrix}\\
y=1 &
\begin{bmatrix}
z=0 & a_{0,1,0,0} & a_{0,1,0,1} & a_{0,1,0,2} \\
z=1 & a_{0,1,1,0} & a_{0,1,1,1} & a_{0,1,1,2} \\
\end{bmatrix}\\
y=2 & \begin{bmatrix}
z=0 & a_{0,2,0,0} & a_{0,2,0,1} & a_{0,2,0,2} \\
z=1 & a_{0,2,1,0} & a_{0,2,1,1} & a_{0,2,1,2} \\
\end{bmatrix}
\end{bmatrix} \\
x=1 &
\begin{bmatrix}
y=0 &
\begin{bmatrix}
z=0 & a_{1,0,0,0} & a_{1,0,0,1} & a_{1,0,0,2}  \\
z=1 & a_{1,0,1,0} & a_{1,0,1,1} & a_{1,0,1,2} \\
\end{bmatrix}\\
y=1 &
\begin{bmatrix}
z=0 & a_{1,1,0,0} & a_{1,1,0,1} & a_{1,1,0,2}  \\
z=1 & a_{1,1,1,0} & a_{1,1,1,1} & a_{1,1,1,2} \\
\end{bmatrix}\\
y=2 &
\begin{bmatrix}
z=0 & a_{1,2,0,0} & a_{1,2,0,1} & a_{1,2,0,2}  \\
z=1 & a_{1,2,1,0} & a_{1,2,1,1} & a_{1,2,1,2} \\
\end{bmatrix}
\end{bmatrix}
\end{bmatrix}
\equiv
\begin{bmatrix}
\begin{bmatrix}
\begin{bmatrix}
& 1, 2, 3, & \\
& 1, -2, 3, &
\end{bmatrix} \\
\begin{bmatrix}
& 4, -3, 2 & \\
& -4, 3, 2 &
\end{bmatrix}\\
\begin{bmatrix}
& 5, -1, 3 &\\
& -5, 1, 3 &
\end{bmatrix}
\end{bmatrix} \\
\begin{bmatrix}
\begin{bmatrix}
& 4, 6, 1 & \\
& 4, -6, 1 &
\end{bmatrix}\\
\begin{bmatrix}
& 1, 2, 3 & \\
& 1, -2, 3 &
\end{bmatrix}\\
\begin{bmatrix}
& 4, -3, 2 & \\
& -4, 3, 2 &
\end{bmatrix}
\end{bmatrix}
\end{bmatrix}
$$

<!-- livebook:{"break_markdown":true} -->

Firstly, let's check some characteristics of this tensor using full-tensor aggregation functions.

```elixir
%{
  dimensions: %{
    x: Nx.axis_size(t, :x),
    y: Nx.axis_size(t, :y),
    z: Nx.axis_size(t, :z),
    t: Nx.axis_size(t, :t)
  },
  n: Nx.size(t),
  shape: Nx.shape(t),
  rank: Nx.rank(t),
  analysis: %{
    most_frequent_nb: Nx.mode(t),
    smallest_element: %{
      value: Nx.reduce_min(t),
      position: Nx.argmin(t)
    },
    greatest_element: %{
      value: Nx.reduce_max(t),
      position: Nx.argmax(t)
    },
    no_zero_nb: Nx.all(t)
  }
}
```

<!-- livebook:{"output":true} -->

```
%{
  n: 36,
  shape: {2, 3, 2, 3},
  rank: 4,
  dimensions: %{y: 3, x: 2, t: 3, z: 2},
  analysis: %{
    most_frequent_nb: #Nx.Tensor<
      s64
      3
    >,
    smallest_element: %{
      position: #Nx.Tensor<
        s64
        22
      >,
      value: #Nx.Tensor<
        s64
        -6
      >
    },
    greatest_element: %{
      position: #Nx.Tensor<
        s64
        19
      >,
      value: #Nx.Tensor<
        s64
        6
      >
    },
    no_zero_nb: #Nx.Tensor<
      u8
      1
    >
  }
}
```
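Without an `:axis` option, `argmin` and `argmax` return a position in the flattened tensor (`22` and `19` above). The sketch below converts such a flat position back into per-axis indices, assuming row-major ordering; `unravel` is a hypothetical helper written in plain Elixir, not an Nx function.

```elixir
# Hypothetical helper: turn a flat, row-major position into per-axis indices.
unravel = fn flat_index, shape ->
  shape
  |> Tuple.to_list()
  |> Enum.reverse()
  # Peel off the fastest-varying axis first: remainder is the index on that
  # axis, quotient is what is left for the slower axes.
  |> Enum.map_reduce(flat_index, fn dim, rest -> {rem(rest, dim), div(rest, dim)} end)
  |> elem(0)
  |> Enum.reverse()
end

smallest_at = unravel.(22, {2, 3, 2, 3})
greatest_at = unravel.(19, {2, 3, 2, 3})

%{smallest_at: smallest_at, greatest_at: greatest_at}
```

The first result reads as $x=1, y=0, z=1, t=1$, which is where `-6` sits in the tensor; the second reads as $x=1, y=0, z=0, t=1$, the location of `6`.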

#### Single row aggregation, along an axis

> ❗ We are going to use the key `:axis` below. The functions `argmin`, `argmax`, `median` and `mode` take a single axis, written `axis: <atom or integer>`. All of the other aggregating functions take a list, written `axes: [<atom or integer>, ...]`; these have multi-axis aggregation implemented.

When you aggregate along an axis, that axis is removed from the shape of the tensor and the elements that lie along it are combined with a function; in other words, you perform a reduction. It is important to consider the _ordering_ of the axes.

<!-- livebook:{"break_markdown":true} -->

###### Aggregate along the first axis `axis: :x` (equivalently `axis: 0`).

The shape of our original tensor `t` is `[x: 2, y: 3, z: 2, t: 3]`. When we aggregate along the axis `:x`, it collapses and the shape of the resultant tensor `r` is `[y: 3, z: 2, t: 3]`. The rule is that every axis before the selected one will remain in the structure of the tensor as well as every axis after the selected one.

How? You collect from each slice of `x` the elements that have the _same_ remaining indexes "on the right" $y,z,t$. In the case of our tensor, you have 2 `x`-slices so you build a sublist of 2 elements, and then apply the aggregating function on it. This gives you the value of the resulting tensor $r$ at the location $r[y][z][t]$:

$$
\rm{agg}\big(t[0][y][z][t], t[1][y][z][t]\big) = r[y][z][t]
$$

The aggregating function used below is [argmin](https://hexdocs.pm/nx/Nx.html#argmin/2): it returns the index of the smallest element of the sublist. In case of several occurrences, the `:tie_break` option takes the lowest index by default.

<!-- livebook:{"break_markdown":true} -->

$$
\begin{bmatrix}
\begin{bmatrix}
\begin{bmatrix}
& \boxed{1}, 2, 3, & \\
& 1, -2, 3, &
\end{bmatrix} \\
\begin{bmatrix}
& 4, -3, 2 & \\
& -4, 3, 2 &
\end{bmatrix} \\
\begin{bmatrix}
& 5, -1, 3 &\\
& -5, 1, 3 &
\end{bmatrix}
\end{bmatrix} \\
\begin{bmatrix}
\begin{bmatrix}
& \boxed{4}, 6, 1 & \\
& 4, -6, 1 &
\end{bmatrix} \\
\begin{bmatrix}
& 1, 2, 3 & \\
& 1, -2, 3 &
\end{bmatrix} \\
\begin{bmatrix}
& 4, -3, 2 & \\
& -4, 3, 2 &
\end{bmatrix}
\end{bmatrix}
\end{bmatrix}
\to
\begin{bmatrix}
\begin{bmatrix}
& \boxed{\rm{agg}(1,4)}, \rm{agg}(2,6), \rm{agg}(3, 1)& \\
& \rm{agg}(1,4), \rm{agg}(-2,-6), \rm{agg}(3,1) &
\end{bmatrix} \\
\begin{bmatrix}
& \rm{agg}(4,1), \rm{agg}(-3,2), \rm{agg}(2,3) & \\
& \rm{agg}(-4,1), \rm{agg}(3,-2), \rm{agg}(2,3) &
\end{bmatrix} \\
\begin{bmatrix}
& \rm{agg}(5,4), \rm{agg}(-1,-3), \rm{agg}(3,2) &\\
& \rm{agg}(-5,-4), \rm{agg}(1,3), \rm{agg}(3,2) &
\end{bmatrix}\\
\end{bmatrix}
\xrightarrow[]{\rm{argmin}}
\begin{bmatrix}
\begin{bmatrix}
& 0,0, 1 & \\
& 0, 1,1 &
\end{bmatrix} \\
\begin{bmatrix}
& 1, 0,0 & \\
& 0,1,0 &
\end{bmatrix} \\
\begin{bmatrix}
& 1, 1,1 &\\
& 0, 0,1 &
\end{bmatrix}
\end{bmatrix}
$$

```elixir
Nx.argmin(t, axis: :x)
```

<!-- livebook:{"output":true} -->

```
#Nx.Tensor<
  s64[y: 3][z: 2][t: 3]
  [
    [
      [0, 0, 1],
      [0, 1, 1]
    ],
    [
      [1, 0, 0],
      [0, 1, 0]
    ],
    [
      [1, 1, 1],
      [0, 0, 1]
    ]
  ]
>
```
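Because the `:x` axis has only two slices, this result can be cross-checked by comparing them directly. The sketch below assumes `t` is the rank-4 tensor defined earlier in this notebook; `x0`, `x1` and `manual_argmin_x` are illustrative names.

```elixir
# Assumes `t` is the rank-4 tensor defined earlier in this notebook.
x0 = t[0]
x1 = t[1]

# With two candidates and the default `tie_break: :low`, index 1 wins only
# when the x=1 element is strictly smaller than the x=0 element.
manual_argmin_x = Nx.greater(x0, x1)

# 1 means every position agrees with Nx.argmin/2.
x_check =
  manual_argmin_x
  |> Nx.equal(Nx.argmin(t, axis: :x))
  |> Nx.all()
  |> Nx.to_number()
```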

###### Aggregate along the second axis `axis: :y` (equivalently `axis: 1`).

The axis `y` will collapse and the shape of the resultant tensor `r` is `[x: 2, z: 2, t: 3]`.

How? The axis $x$ is before the selected one $y$ so it will remain. We therefore apply the procedure as above for _each_ sub $x$-slice.

More precisely, fix the $x$-slice, and let's consider the $x=1$ one. You collect in each sub `y`-slice the elements with the same remaining indexes "on the right", $z,t$ (and $x$ of course). You apply the aggregation and the result will be the value at location $r[1][z][t]$ of the resultant tensor $r$. So the elements of the slice $x=1$ of the resultant tensor $r$ will be:

$$
r[x_{=1}][z][t] = \rm{agg}\big(t[x_{=1}][0][z][t], t[x_{=1}][1][z][t], t[x_{=1}][2][z][t]\big)
$$

You repeat this operation for each $x$-slice.

In the example below, we use the `argmin` function again.

<!-- livebook:{"break_markdown":true} -->

$$
\begin{bmatrix}
\begin{bmatrix}
\begin{bmatrix}
& \boxed{1}, 2, 3, & \\
& 1, -2, 3, &
\end{bmatrix}\\
\begin{bmatrix}
& \boxed{4}, -3, 2 & \\
& -4, 3, 2 &
\end{bmatrix}\\
\begin{bmatrix}
& \boxed{5}, -1, 3 &\\
& -5, 1, 3 &
\end{bmatrix}
\end{bmatrix}\\
\begin{bmatrix}
\begin{bmatrix}
& 4, 6, 1 & \\
& 4, \boxed{-6}, 1 &
\end{bmatrix}\\
\begin{bmatrix}
& 1, 2, 3 & \\
& 1, \boxed{-2}, 3 &
\end{bmatrix}\\
\begin{bmatrix}
& 4, -3, 2 & \\
& -4, \boxed{3}, 2 &
\end{bmatrix}
\end{bmatrix}
\end{bmatrix}
\to
\begin{bmatrix}
\begin{bmatrix}
&\boxed{\rm{agg}(1,4,5)}, \rm{agg}(2,-3,-1), \rm{agg}(3, 2,3)& \\
& \rm{agg}(1,-4,-5), \rm{agg}(-2,3,1), \rm{agg}(3,2,3) &
\end{bmatrix}\\
\begin{bmatrix}
& \rm{agg}(4,1,4), \rm{agg}(6,2,-3), \rm{agg}(1,3,2) & \\
& \rm{agg}(4,1,-4), \boxed{\rm{agg}(-6,-2,3)}, \rm{agg}(1,3,2) &
\end{bmatrix}
\end{bmatrix}
\xrightarrow[]{\rm{argmin}}
\begin{bmatrix}
\begin{bmatrix}
& 0, 1, 1 & \\
& 2, 0,1 &
\end{bmatrix}\\
\begin{bmatrix}
& 1, 2,0 & \\
& 2,0,0 &
\end{bmatrix}
\end{bmatrix}
$$

```elixir
Nx.argmin(t, axis: 1)
```

<!-- livebook:{"output":true} -->

```
#Nx.Tensor<
  s64[x: 2][z: 2][t: 3]
  [
    [
      [0, 1, 1],
      [2, 0, 1]
    ],
    [
      [1, 2, 0],
      [2, 0, 0]
    ]
  ]
>
```
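The "repeat this operation for each $x$-slice" step can be written out literally: reduce each $x$-slice on its own, then stack the partial results back together. This is a sketch that assumes `t` is the rank-4 tensor defined earlier; `per_x_slice` and `restacked` are illustrative names.

```elixir
# Assumes `t` is the rank-4 tensor defined earlier in this notebook.
# Each t[x] has shape [y: 3, z: 2, t: 3]; reducing :y leaves [z: 2, t: 3].
per_x_slice = for x <- 0..1, do: Nx.argmin(t[x], axis: :y)

# Stacking the two partial results restores the leading :x axis.
restacked = Nx.stack(per_x_slice, name: :x)

# 1 means the slice-by-slice result equals the one-shot reduction.
y_check =
  restacked
  |> Nx.equal(Nx.argmin(t, axis: :y))
  |> Nx.all()
  |> Nx.to_number()
```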

###### Aggregate along the `axis: :z` axis (equivalently `axis: 2`).

The `z` axis collapses and the shape of the resultant tensor `r` is `[x: 2, y: 3, t: 3]`. As before, for each $x$-slice and each $y$-slice on the left, we collect the elements of the slice with the same remaining indexes "on the right", thus $t$ here.

$$
r[x][y][t] = \rm{agg}\big(t[x][y][0][t], t[x][y][1][t] \big)
$$

In the example below, we used the `argmax` function.

<!-- livebook:{"break_markdown":true} -->

$$
\begin{bmatrix}
x=0 &
\begin{bmatrix}
y=0 &
\begin{bmatrix}
\boxed{1}, 2, 3, & \\
\boxed{1}, -2, 3, &
\end{bmatrix} \\
y=1 &
\begin{bmatrix}
& 4, -3, 2 & \\
& -4, 3, 2 &
\end{bmatrix}\\
y=2 &
\begin{bmatrix}
& 5, -1, 3 &\\
& -5, 1, 3 &
\end{bmatrix}
\end{bmatrix}\\
x=1 &
\begin{bmatrix}
y=0 &
\begin{bmatrix}
& 4, \boxed{6}, 1  \\
& 4,\boxed{-6}, 1
\end{bmatrix}\\
y=1 &
\begin{bmatrix}
& 1, 2, 3 & \\
& 1, -2, 3 &
\end{bmatrix}\\
y=2 &
\begin{bmatrix}
& 4, -3, 2 & \\
& -4, 3, 2 &
\end{bmatrix}
\end{bmatrix}
\end{bmatrix}
\to
\begin{bmatrix}
x=0 &
\begin{bmatrix}
& \boxed{\rm{agg}(1,1)}, \rm{agg}(2,-2), \rm{agg}(3,3)& \\
& \rm{agg}(4,-4), \rm{agg}(-3,3), \rm{agg}(2,2) & \\
& \rm{agg}(5,-5), \rm{agg}(-1,1), \rm{agg}(3,3)
\end{bmatrix}\\
x=1 &
\begin{bmatrix}
& \rm{agg}(4,4), \boxed{\rm{agg}(6,-6)}, \rm{agg}(1,1) & \\
& \rm{agg}(1,1), \rm{agg}(2,-2), \rm{agg}(3,3) & \\
& \rm{agg}(4,-4), \rm{agg}(-3,3), \rm{agg}(2,2) &
\end{bmatrix}
\end{bmatrix}
\xrightarrow[]{\rm{argmax}}
\begin{bmatrix}
\begin{bmatrix}
0, 0, 0  \\
0, 1, 0 \\
0, 1, 0
\end{bmatrix}\\
\begin{bmatrix}
0, 0, 0  \\
0, 0, 0 \\
0, 1, 0
\end{bmatrix}
\end{bmatrix}
$$

<!-- livebook:{"break_markdown":true} -->

```elixir
Nx.argmax(t, axis: :z)
```

<!-- livebook:{"output":true} -->

```
#Nx.Tensor<
  s64[x: 2][y: 3][t: 3]
  [
    [
      [0, 0, 0],
      [0, 1, 0],
      [0, 1, 0]
    ],
    [
      [0, 0, 0],
      [0, 0, 0],
      [0, 1, 0]
    ]
  ]
>
```
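To connect the formula with the output, the sketch below recomputes a single entry, $r[1][0][1]$, by hand: it gathers the two candidates along `:z` and looks for the index of the largest one. It assumes `t` is the rank-4 tensor defined earlier; the other names are illustrative.

```elixir
# Assumes `t` is the rank-4 tensor defined earlier in this notebook.
# Candidates for r[1][0][1]: x=1, y=0, t=1 are fixed, z runs over 0..1.
candidates = for z <- 0..1, do: Nx.to_number(t[1][0][z][1])

# Enum.find_index/2 returns the first match, mirroring `tie_break: :low`.
manual_argmax = Enum.find_index(candidates, &(&1 == Enum.max(candidates)))

z_argmax = Nx.argmax(t, axis: :z)

{candidates, manual_argmax, Nx.to_number(z_argmax[1][0][1])}
```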

###### Aggregate along the last axis, `axis: :t` (equivalently `axis: 3` or `axis: -1`).

This will reshape the tensor to `[x: 2, y: 3, z: 2]`. Since this is the last index, having in mind the matrix picture above, you aggregate along each row:

$$
r[x][y][z] = \rm{agg}\big(t[x][y][z][0],  t[x][y][z][1], t[x][y][z][2]\big)
$$

and repeat this for each $x$-slice, each sub $y$-slice, each sub-sub $z$-slice.

Below is the result for the function `argmin`.

<!-- livebook:{"break_markdown":true} -->

$$
\begin{bmatrix}
x=0 &
\begin{bmatrix}
y=0 &
\begin{bmatrix}
& \boxed{1, 2, 3}, & \\
& 1, -2, 3, &
\end{bmatrix}\\
y=1 &
\begin{bmatrix}
& 4, -3, 2 & \\
& -4, 3, 2 &
\end{bmatrix}\\
y=2 &
\begin{bmatrix}
& 5, -1, 3 &\\
& -5, 1, 3 &
\end{bmatrix}
\end{bmatrix}\\
x=1 &
\begin{bmatrix}
y=0 &
\begin{bmatrix}
& 4, 6, 1 & \\
& 4,-6, 1 &
\end{bmatrix}\\
y=1 &
\begin{bmatrix}
& 1, 2, 3 & \\
& \boxed{1, -2, 3} &
\end{bmatrix}\\
y=2 &
\begin{bmatrix}
& 4, -3, 2 & \\
& -4, 3, 2 &
\end{bmatrix}
\end{bmatrix}
\end{bmatrix}
\xrightarrow[]{\rm{argmin}}
\begin{bmatrix}
\begin{bmatrix}
& \boxed{0}, 1 & \\
& 1, 0 & \\
& 1, 0
\end{bmatrix}\\
\begin{bmatrix}
& 2, 1 & \\
& 0, \boxed{1} & \\
& 1, 0
\end{bmatrix}
\end{bmatrix}
$$

<!-- livebook:{"break_markdown":true} -->

```elixir
Nx.argmin(t, axis: 3)
```

<!-- livebook:{"output":true} -->

```
#Nx.Tensor<
  s64[x: 2][y: 3][z: 2]
  [
    [
      [0, 1],
      [1, 0],
      [1, 0]
    ],
    [
      [2, 1],
      [0, 1],
      [1, 0]
    ]
  ]
>
```
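The four cases above follow one rule: the reduced axis disappears from the shape and every other axis stays in place. The sketch below checks this for each named axis by comparing against `Tuple.delete_at/2`; it assumes `t` is the rank-4 tensor defined earlier, and `full_shape` and `shape_report` are illustrative names.

```elixir
# Assumes `t` is the rank-4 tensor defined earlier in this notebook.
full_shape = Nx.shape(t)

shape_report =
  for {axis, position} <- Enum.with_index(Nx.names(t)), into: %{} do
    reduced = Nx.shape(Nx.argmin(t, axis: axis))
    # The reduced shape is the original shape with one entry removed.
    {axis, {reduced, reduced == Tuple.delete_at(full_shape, position)}}
  end
```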

#### Option [:tie_break](https://hexdocs.pm/nx/Nx.html#argmax/2-tie-breaks)

You have the `:tie_break` option to decide what to return when the result occurs several times. It defaults to `tie_break: :low`.

```elixir
t4 = ~VEC[2 0 0 0 1]

%{
  argmin_with_default: Nx.argmin(t4) |> Nx.to_number(),
  argmin_with_tie_break_high_option: Nx.argmin(t4, tie_break: :high) |> Nx.to_number()
}
```

<!-- livebook:{"output":true} -->

```
%{argmin_with_default: 1, argmin_with_tie_break_high_option: 3}
```
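The two tie-break modes can be reproduced with plain lists: collect every index where the minimum occurs, then take the first (`:low`) or the last (`:high`). This is a sketch for intuition only; `numbers` and `min_indexes` are illustrative names.

```elixir
numbers = [2, 0, 0, 0, 1]
smallest = Enum.min(numbers)

# Every position where the minimum occurs.
min_indexes = for {value, index} <- Enum.with_index(numbers), value == smallest, do: index

vec = Nx.tensor(numbers)

%{
  low: {List.first(min_indexes), Nx.to_number(Nx.argmin(vec, tie_break: :low))},
  high: {List.last(min_indexes), Nx.to_number(Nx.argmin(vec, tie_break: :high))}
}
```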

#### Option [:keep_axis](https://hexdocs.pm/nx/Nx.html#argmax/2-keep-axis)

Its default value is `false`. When this option is set to `keep_axis: true`, you reduce but keep the working axis with a size of $1$. For example, with `t` of shape `{2,3,2,3}`, when you reduce along the third axis, `:z`, we saw that the shape is `{2,3,3}` but when you keep the axis, the shape is `{2,3,1,3}`:

```elixir
# argmax along :z, matching the picture below and the earlier `:z` example
Nx.argmax(t, axis: 2, keep_axis: true) ==
  Nx.argmax(t, axis: 2) |> Nx.reshape({2, 3, 1, 3}, names: [:x, :y, :z, :t])
```

<!-- livebook:{"output":true} -->

```
true
```
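One common reason to keep the reduced axis is broadcasting: a result of shape `{2,3,1,3}` can be combined directly with the original `{2,3,2,3}` tensor. The sketch below uses `Nx.reduce_max/2`, whose option is spelled `keep_axes:` (plural) because it takes `axes:`. It assumes `t` is the rank-4 tensor defined earlier; `max_z` and `shifted` are illustrative names.

```elixir
# Assumes `t` is the rank-4 tensor defined earlier in this notebook.
# Keep :z with size 1 so the result still lines up with `t`.
max_z = Nx.reduce_max(t, axes: [:z], keep_axes: true)

# Broadcasting stretches the size-1 :z axis over both z-slices.
shifted = Nx.subtract(t, max_z)

%{
  max_shape: Nx.shape(max_z),
  shifted_shape: Nx.shape(shifted),
  # Every group now has 0 as its largest value.
  largest_after_shift: Nx.to_number(Nx.reduce_max(shifted))
}
```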

$$
\begin{bmatrix}
\begin{bmatrix}
0, 0, 0  \\
0, 1, 0 \\
0, 1, 0
\end{bmatrix}\\
\begin{bmatrix}
0, 0, 0  \\
0, 0, 0 \\
0, 1, 0
\end{bmatrix}
\end{bmatrix}
\to
\begin{bmatrix}
\begin{bmatrix}
\begin{bmatrix}
0, 0, 0
\end{bmatrix}\\
\begin{bmatrix}
0, 1, 0
\end{bmatrix}\\
\begin{bmatrix}
0, 1, 0
\end{bmatrix}
\end{bmatrix}\\
\begin{bmatrix}
\begin{bmatrix}
0, 0, 0
\end{bmatrix}\\
\begin{bmatrix}
0, 0, 0
\end{bmatrix}\\
\begin{bmatrix}
0, 1, 0
\end{bmatrix}
\end{bmatrix}
\end{bmatrix}
$$

<!-- livebook:{"break_markdown":true} -->

#### Multi-row aggregation

<!-- livebook:{"break_markdown":true} -->

Suppose you want to aggregate along the axes `x` and `z`. Then you should get a tensor of shape `[y: 3, t: 3]`. Given a pair of indexes $(y,t)$, you aggregate all the numbers that share these indexes, across every value of `x` and `z`.

$$
r[y][t] = \rm{agg}\big( t[0][y][0][t], t[0][y][1][t], t[1][y][0][t], t[1][y][1][t]\big)
$$

<!-- livebook:{"break_markdown":true} -->

$$
\begin{bmatrix}
\begin{bmatrix}
\begin{bmatrix}
\boxed{1}, 2, 3 & \\
\boxed{1}, -2, 3 &
\end{bmatrix}\\
\begin{bmatrix}
& 4, -3, 2 & \\
& -4, 3, 2 &
\end{bmatrix}\\
\begin{bmatrix}
& 5, -1, 3 &\\
& -5, 1, 3 &
\end{bmatrix}
\end{bmatrix}\\
\begin{bmatrix}
\begin{bmatrix}
\boxed{4}, 6, 1 & \\
\boxed{4},-6, 1 &
\end{bmatrix}\\
\begin{bmatrix}
& 1, 2, 3 & \\
& 1, -2, 3 &
\end{bmatrix}\\
\begin{bmatrix}
& 4, -3, 2 & \\
& -4, 3, 2 &
\end{bmatrix}
\end{bmatrix}
\end{bmatrix}
\to
\begin{bmatrix}
\rm{agg}(1,1,4,4),\rm{agg}(2,-2,6,-6),\rm{agg}(3,3,1,1) &\\
\rm{agg}(4,-4,1,1), \rm{agg}(-3,3,2,-2), \rm{agg}(2,2,3,3) & \\
\rm{agg}(5,-5,4,-4), \rm{agg}(-1,1,-3,3), \rm{agg}(3,3,2,2)
\end{bmatrix}
$$

<!-- livebook:{"break_markdown":true} -->

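From this, it is easier to understand what a multi-axis aggregation returns. The examples below apply the same idea to other pairs of axes: `[:x, :t]` (written `[0, -1]` in the first call) and `[:x, :y]`.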

```elixir
Nx.reduce_min(t, axes: [0, -1])
```

<!-- livebook:{"output":true} -->

```
#Nx.Tensor<
  s64[y: 3][z: 2]
  [
    [1, -6],
    [-3, -4],
    [-3, -5]
  ]
>
```

```elixir
Nx.reduce_max(t, axes: [:x, :t])
```

<!-- livebook:{"output":true} -->

```
#Nx.Tensor<
  s64[y: 3][z: 2]
  [
    [6, 4],
    [4, 3],
    [5, 3]
  ]
>
```

```elixir
Nx.mean(t, axes: [:x, :y])
```

<!-- livebook:{"output":true} -->

```
#Nx.Tensor<
  f32[z: 2][t: 3]
  [
    [3.1666667461395264, 0.5, 2.3333332538604736],
    [-1.1666666269302368, -0.5, 2.3333332538604736]
  ]
>
```
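Going back to the aggregation along `x` and `z` pictured earlier, the sketch below runs it with `Nx.reduce_min/2` and also checks that, for a minimum, reducing both axes at once gives the same result as reducing them one after the other. It assumes `t` is the rank-4 tensor defined earlier; `both_at_once` and `one_at_a_time` are illustrative names. The equivalence holds for operations such as min, max and sum, and is not claimed here for every aggregation.

```elixir
# Assumes `t` is the rank-4 tensor defined earlier in this notebook.
both_at_once = Nx.reduce_min(t, axes: [:x, :z])

# Same reduction, one axis at a time.
one_at_a_time =
  t
  |> Nx.reduce_min(axes: [:x])
  |> Nx.reduce_min(axes: [:z])

%{
  shape: Nx.shape(both_at_once),
  values: Nx.to_list(both_at_once),
  same_result: Nx.to_number(Nx.all(Nx.equal(both_at_once, one_at_a_time)))
}
```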
