Nearest neighbors

Mix.install([
  {:scholar, "~> 0.3.0"},
  {:explorer, "~> 0.8.1"},
  {:exla, "~> 0.7.2"},
  {:nx, "~> 0.7.2"},
  {:req, "~> 0.4.14"},
  {:kino_vega_lite, "~> 0.1.11"},
  {:kino, "~> 0.12.3"},
  {:kino_explorer, "~> 0.1.18"}
])

Setup

We will use Explorer in this notebook, so let's define aliases for its main modules:

require Explorer.DataFrame, as: DF
require Explorer.Series, as: S

And let's configure EXLA as our default backend (where our tensors are stored) and compiler (which compiles Scholar code) across the notebook and all branched sections:

Nx.global_default_backend(EXLA.Backend)
Nx.Defn.global_default_options(compiler: EXLA)

Definition of nearest neighbors

We introduced k-nearest neighbors (KNN) in a previous notebook alongside some example problems. In that guide, we used the default algorithm for KNN, which uses a brute-force approach. In this document, we will introduce alternative algorithms and their trade-offs, but to do so, we need to define exactly what "nearest neighbors" are.

Consider a set of points $\{p_1, p_2, \ldots, p_N\}$, where each point $p_i$ is represented as a vector of size $M$, with real-number coordinates: $p_i = (p_{i1}, p_{i2}, \ldots, p_{iM})$. We say that the dimension of the space is $M$. For convenience, instead of listing the set of points as $\{p_1, p_2, \ldots, p_N\}$, we can treat it as a matrix $P$ of size $N \times M$, where $P_{ij}$ represents the $j$th coordinate of the $i$th point.

A metric is a function $d$ that takes two points and returns a non-negative real number, subject to the following conditions for arbitrary points $a$, $b$, and $c$:

  • $d(a,b) = 0 \iff a = b$ (identity)
  • $d(a,b) = d(b, a)$ (symmetry)
  • $d(a,b) \leq d(a, c) + d(b, c)$ (triangle inequality)

Examples of such metrics include the Euclidean metric and the Manhattan metric.
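
To make these metrics concrete, here is a small example in plain Nx that evaluates both of them for two 3-dimensional points (the point values are made up for illustration):

a = Nx.tensor([1.0, 2.0, 3.0])
b = Nx.tensor([4.0, 0.0, 3.0])

# Euclidean metric: square root of the sum of squared coordinate differences
euclidean = a |> Nx.subtract(b) |> Nx.pow(2) |> Nx.sum() |> Nx.sqrt()

# Manhattan metric: sum of absolute coordinate differences
manhattan = a |> Nx.subtract(b) |> Nx.abs() |> Nx.sum()

{euclidean, manhattan}

Scholar also provides common distance functions in the Scholar.Metrics.Distance module.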

Given a set of points $P$, a point $q$, and a metric $d$, the $k$ nearest neighbors from $P$ to point $q$ are defined as a set $S \subseteq P$ with $\lvert S \rvert = k$ such that no point in $P \setminus S$ is closer to $q$ (with respect to $d$) than any point in $S$. In simpler terms, $S$ consists of the $k$ points from $P$ with the smallest distance to point $q$.

It's especially interesting to consider the case when point $q$ is part of the set $P$. In the remainder of the notebook, we will assume that $q$ is indeed in $P$, and if this is not the case, we will note it explicitly.

Brute force approach

Now let's consider how we can compute the $k$-nearest neighbors. The simplest method is to compute the distance using a metric $d$ for each pair of points in the dataset. This can be represented as a matrix $D$ of size $N \times N$, where $D_{ij}$ indicates the distance between the $i$th and $j$th points. Given the symmetry of metric $d$, the matrix $D$ is also symmetric: $D_{ij} = D_{ji}$. To find the nearest neighbors, you would sort the distances for each point (sorting the rows of matrix $D$), and then return only the first $k$ columns from each row of the sorted matrix.
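
Here is a minimal sketch of this approach in plain Nx on a tiny, made-up dataset: we build the full distance matrix $D$ by broadcasting, sort each row, and keep the indices of the $k$ closest points. Since $q \in P$, the first entry of each row is the point itself.

points = Nx.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [5.0, 5.0]])
k = 2

# distances[i][j] holds the Euclidean distance between points i and j
diff = Nx.subtract(Nx.new_axis(points, 1), Nx.new_axis(points, 0))
distances = diff |> Nx.pow(2) |> Nx.sum(axes: [-1]) |> Nx.sqrt()

# Sort every row and keep only the first k columns
neighbors = distances |> Nx.argsort(axis: 1) |> then(& &1[[.., 0..(k - 1)]])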

While this method accurately computes the $k$-nearest neighbors, it has a major drawback: time complexity. Computing the distance matrix costs $O(MN^2)$ and sorting its rows costs $O(N^2\log(N))$, where $N$ is the number of points and $M$ is the number of dimensions (features) of each point. Even if we set $M = 1$, when $N$ is very large (like the number of songs on Spotify, which is on the order of $10^8$), the number of operations required becomes infeasible, even if $k$ is relatively small. Given this, for large values of $N$, we need to consider other approaches. However, the brute-force method can still be useful in certain situations. If $N$ is small, the distance matrix $D$ can be computed in a reasonable amount of time. Additionally, if $k \approx N$, this approach is close to optimal, since any method has to produce $O(Nk) = O(N^2)$ results anyway. Computing the distance matrix is heavily optimized in many frameworks, giving the brute-force method an advantage over other approaches when $k$ is of the same order of magnitude as $N$.

Brute force is the default algorithm used by Scholar.Neighbors.KNNClassifier and Scholar.Neighbors.KNNRegressor. However, different algorithms can be given via the :algorithm option. Let's take a look at one of them.
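
For illustration, a minimal sketch of passing a non-default algorithm might look like the snippet below. The tiny tensors are made up, and the exact set of accepted :algorithm values should be checked in the Scholar documentation (here we assume :kd_tree is one of them).

x_train = Nx.tensor([[1.0, 2.0], [2.0, 4.0], [1.0, 3.0], [2.0, 5.0]])
y_train = Nx.tensor([0, 1, 0, 1])

# Same KNN classifier API, but (assumed) backed by a KD-tree instead of brute force
model =
  Scholar.Neighbors.KNNClassifier.fit(x_train, y_train,
    num_classes: 2,
    num_neighbors: 3,
    algorithm: :kd_tree
  )

Scholar.Neighbors.KNNClassifier.predict(model, Nx.tensor([[1.5, 3.0]]))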

KDTrees

For large datasets, we need a much faster approach. It turns out that there is a data structure that allows querying for a single neighbor in $O(\log(N))$ average time, resulting in a total computation time of $O(kN\log(N))$, which is efficient for small values of $k$. But what is this data structure that allows querying faster than the brute-force approach? It's called a KDTree. KDTree, short for "K-Dimensional Tree", is a tree structure that organizes points in a K-dimensional space. More precisely, it's a binary tree where each non-leaf node has an associated splitting hyperplane. All points on the left side of this splitting hyperplane belong to the left subtree, while those on the right side belong to the right subtree. The dimension along which the split occurs is consistent across each level of the tree, and it changes as you move to different levels.

To better understand this, let's look at two graphical examples of KDTrees.

KDTree
Figure 1: A 3-dimensional KDtree.

Here, the first split (a red vertical plane) divides the root cell (white) into two subcells. Then, each of these subcells is further split (by the green horizontal planes) into two additional subcells. Finally, the four resulting subcells are further split (by four blue vertical planes), creating a total of eight subcells. Since there are no more sectors to split, the algorithm terminates with these eight subcells.

Here's another, more abstract example.

KDtree
Figure 2: A different example of 3-dimensional KDtree.

In this example, the dimensions are abstract values like name, salary, and age. However, because we can define a linear order for each feature (such as a natural order for salary and age, and a lexicographic order for names), we can treat them like any other metric space.

Now, let's move on to the code and implementation of KDTrees in Scholar!

KDTrees in action: a practical example

In this example, we'll use data from the National Health and Nutrition Examination Survey (NHANES) 2013-2014. The dataset consists of 7 features that are biomedical indicators. For more information, click the link above.

Now, let's read the data into an Explorer dataframe.

data_url =
  "https://archive.ics.uci.edu/static/public/887/national+health+and+nutrition+health+survey+2013-2014+(nhanes)+age+prediction+subset.zip"

data =
  Req.get!(data_url).body

{:ok, [{_, data}]} = :zip.extract(data, [:memory])
df_data = DF.load_csv!(data)
#Explorer.DataFrame<
  Polars[2278 x 10]
  SEQN f64 [73564.0, 73568.0, 73576.0, 73577.0, 73580.0, ...]
  age_group string ["Adult", "Adult", "Adult", "Adult", "Adult", ...]
  RIDAGEYR f64 [61.0, 26.0, 16.0, 32.0, 38.0, ...]
  RIAGENDR f64 [2.0, 2.0, 1.0, 1.0, 2.0, ...]
  PAQ605 f64 [2.0, 2.0, 2.0, 2.0, 1.0, ...]
  BMXBMI f64 [35.7, 20.3, 23.2, 28.9, 35.9, ...]
  LBXGLU f64 [110.0, 89.0, 89.0, 104.0, 103.0, ...]
  DIQ010 f64 [2.0, 2.0, 2.0, 2.0, 2.0, ...]
  LBXGLT f64 [150.0, 80.0, 68.0, 84.0, 81.0, ...]
  LBXIN f64 [14.91, 3.85, 6.14, 16.15, 10.92, ...]
>

In this analysis, we want to determine if the most similar people, in terms of medical parameters, belong to the same age group. In the dataset, we categorize people into two groups: adults (18-64 years) and seniors (65+ years). Intuitively, we might expect that people with similar medical results are often in the same age group, but we want to quantify the extent to which this is true. To do this, we compute the 5 closest results for each patient and then calculate the average for these two groups.

First, we need to separate the medical results from the data that indicates the age group.

x = df_data |> DF.discard(["age_group", "SEQN", "RIDAGEYR"]) |> Nx.stack(axis: 1)
y = Nx.stack(S.cast(df_data["age_group"], :category), axis: 1)

x = Scholar.Preprocessing.StandardScaler.fit_transform(x)

{x, y}
{#Nx.Tensor<
   f64[2278][7]
   EXLA.Backend<host:0, 0.482371013.981598228.186826>
   [
     [-0.7098659661342087, -0.7098659661342087, -0.028236969791631317, 1.4745830429161875, -0.7098659661342087, 2.2836382314533714, -0.4487434040338326],
     [-0.7098659661342087, -0.7098659661342087, -0.33972321737844713, 1.049829068934166, -0.7098659661342087, 0.8677916515132996, -0.672447163664364],
     [-0.7300923458476383, -0.7098659661342087, -0.2810667162095013, 1.049829068934166, -0.7098659661342087, 0.6250750949521445, -0.6261287541206102],
     [-0.7300923458476383, -0.7098659661342087, -0.16577635184295264, 1.35322476463561, -0.7098659661342087, 0.9486971703670181, -0.42366269318918],
     [-0.7098659661342087, -0.7300923458476383, -0.024191693848945484, 1.3329983849221803, -0.7098659661342087, 0.8880180312267293, -0.5294466590904167],
     [-0.7300923458476383, -0.7098659661342087, -0.2729761643241294, 1.4745830429161875, -0.7098659661342087, 1.2723192457818915, -0.627342336903416],
     [-0.7300923458476383, -0.7098659661342087, 0.032442169348657464, 1.150960967501314, -0.7098659661342087, 3.3354099765517105, -0.3233398498105692],
     ...
   ]
 >,
 #Nx.Tensor<
   u32[2278][1]
   EXLA.Backend<host:0, 0.482371013.981598228.186824>
   [
     [0],
     [0],
     [0],
     [0],
     [0],
     [0],
     [0],
     [0],
     [1],
     [0],
     [0],
     [0],
     [1],
     [0],
     [1],
     [1],
     [0],
     [0],
     [1],
     [0],
     [0],
     [0],
     [0],
     [0],
     [0],
     [0],
     [0],
     [1],
     [0],
     [0],
     [0],
     [0],
     [0],
     [0],
     [1],
     [0],
     [0],
     [0],
     [0],
     [0],
     [0],
     [0],
     [0],
     [0],
     [0],
     [0],
     [0],
     [0],
     ...
   ]
 >}

Once the data is preprocessed, we can move on to the KDTree model. First, we need to set up the model using Scholar.Neighbors.KDTree.fit/2. This method initializes the tree structure for further processing.

num_neighbors = 6

knn_model = Scholar.Neighbors.KDTree.fit(x, num_neighbors: num_neighbors, metric: {:minkowski, 2})
%Scholar.Neighbors.KDTree{
  levels: 12,
  indices: #Nx.Tensor<
    u32[2278]
    EXLA.Backend<host:0, 0.482371013.981598226.187105>
    [257, 914, 1129, 225, 978, 301, 1131, 688, 481, 1786, 2135, 472, 307, 1606, 1498, 476, 716, 361, 367, 1627, 1615, 1562, 1660, 747, 776, 748, 811, 1690, 1803, 1678, 1687, 244, 925, 491, 1816, 206, 695, 229, 977, 1198, 1865, 1291, 1800, 1549, 2255, 1109, 2032, 597, ...]
  >,
  data: #Nx.Tensor<
    f64[2278][7]
    EXLA.Backend<host:0, 0.482371013.981598228.186826>
    [
      [-0.7098659661342087, -0.7098659661342087, -0.028236969791631317, 1.4745830429161875, -0.7098659661342087, 2.2836382314533714, -0.4487434040338326],
      [-0.7098659661342087, -0.7098659661342087, -0.33972321737844713, 1.049829068934166, -0.7098659661342087, 0.8677916515132996, -0.672447163664364],
      [-0.7300923458476383, -0.7098659661342087, -0.2810667162095013, 1.049829068934166, -0.7098659661342087, 0.6250750949521445, -0.6261287541206102],
      [-0.7300923458476383, -0.7098659661342087, -0.16577635184295264, 1.35322476463561, -0.7098659661342087, 0.9486971703670181, -0.42366269318918],
      [-0.7098659661342087, -0.7300923458476383, -0.024191693848945484, 1.3329983849221803, -0.7098659661342087, 0.8880180312267293, -0.5294466590904167],
      [-0.7300923458476383, -0.7098659661342087, -0.2729761643241294, 1.4745830429161875, -0.7098659661342087, 1.2723192457818915, -0.627342336903416],
      [-0.7300923458476383, -0.7098659661342087, 0.032442169348657464, 1.150960967501314, -0.7098659661342087, ...],
      ...
    ]
  >,
  num_neighbors: 6,
  metric: #Function<5.83620768/2 in Scholar.Neighbors.Utils.metric/1>
}

Now, we need to run a query to calculate the nearest neighbors for a given set of data. To compute the nearest neighbors, we call Scholar.Neighbors.KDTree.predict.

{indices, distances} = Scholar.Neighbors.KDTree.predict(knn_model, x)
{#Nx.Tensor<
   s64[2278][6]
   EXLA.Backend<host:0, 0.482371013.981598225.187162>
   [
     [0, 1367, 370, 119, 1158, 1423],
     [1, 432, 1078, 1370, 1252, 358],
     [2, 1932, 2179, 1516, 1089, 1821],
     [3, 823, 510, 770, 552, 1066],
     [4, 1029, 2103, 333, 823, 1063],
     [5, 99, 1830, 1459, 701, 1767],
     [6, 1060, 402, 557, 2274, 1733],
     [7, 1631, 1227, 1493, 2209, 1991],
     [8, ...],
     ...
   ]
 >,
 #Nx.Tensor<
   f64[2278][6]
   EXLA.Backend<host:0, 0.482371013.981598225.187163>
   [
     [0.0, 0.10564987631681173, 0.13647696316587582, 0.16467068264346493, 0.17265036785358587, 0.17355553880586283],
     [0.0, 0.05084833013473313, 0.05862161772775837, 0.059033314891497844, 0.061528101864556185, 0.07051004160599812],
     [0.0, 0.0634183290913442, 0.06834720866562938, 0.06914612071933461, 0.0702263282472859, 0.07657788563767927],
     [0.0, 0.08520661043231542, 0.08772188110415277, 0.09315472121407559, 0.10080775703280971, 0.10357822211490808],
     [0.0, 0.08856498626863332, 0.10911471056995574, 0.11729225102681795, 0.12629118906993017, 0.13448641399541536],
     [0.0, 0.06737206394520054, 0.07237087978074229, 0.09445004662593554, 0.10244796099940952, 0.10666419265073056],
     [0.0, 0.1772284512996439, 0.18087450661041063, 0.22321979638902212, 0.26995822503187694, 0.3492397094272336],
     [0.0, 0.10684507307976371, 0.1450485290844591, 0.15468671853000548, 0.1874949397442543, 0.1922686292168594],
     ...
   ]
 >}

Before we compute the statistics from the results, we need to know the initial proportions of each group.

seniors_original = Nx.divide(Nx.sum(Nx.equal(y, 1)), Nx.size(y))
adults_original = Nx.subtract(1, seniors_original)
{seniors_original, adults_original}
{#Nx.Tensor<
   f32
   EXLA.Backend<host:0, 0.482371013.981598227.187001>
   0.1597892940044403
 >,
 #Nx.Tensor<
   f32
   EXLA.Backend<host:0, 0.482371013.981598227.187003>
   0.8402106761932373
 >}

As we can see, about 16% of the people are seniors, while 84% are adults. Now, let's calculate the average number of people within the same age group among the nearest neighbors. We should drop the first column since we're not interested in considering the point itself when calculating its neighborhood.

knn = indices[[.., 1..-1//1]]
age_groups = Nx.take(y, knn)

seniors =
  Nx.divide(
    Nx.sum(Nx.multiply(Nx.equal(age_groups, 1), Nx.new_axis(Nx.equal(y, 1), 1))),
    Nx.sum(Nx.equal(y, 1))
  )
  |> Nx.divide(num_neighbors - 1)

adults =
  Nx.divide(
    Nx.sum(Nx.multiply(Nx.equal(age_groups, 0), Nx.new_axis(Nx.equal(y, 0), 1))),
    Nx.sum(Nx.equal(y, 0))
  )
  |> Nx.divide(num_neighbors - 1)

{seniors, adults}
{#Nx.Tensor<
   f32
   EXLA.Backend<host:0, 0.482371013.981598225.187182>
   0.28406593203544617
 >,
 #Nx.Tensor<
   f32
   EXLA.Backend<host:0, 0.482371013.981598225.187191>
   0.8579937219619751
 >}

Our intuition was correct: for seniors, about 28% of their nearest neighbors in terms of medical results are also seniors, well above the 16% base rate, while for adults, nearly 86% of their nearest neighbors are also adults, compared with an 84% base rate. This indicates that age should always be taken into account when assessing medical results and drawing conclusions from them.

Approximated nearest neighbors with LargeVis

The approaches we have seen so far have been exact: they always find the true $k$ nearest neighbors. However, there are situations where exact computation is too expensive and we want to perform approximate computations.

In the previous example, KDTree was a perfect choice for computing K-Nearest Neighbors. However, when the number of features grows, the efficiency of KDTree decreases significantly, so it's best suited for data with a smaller number of features (<15). For data with many samples but also many features, approximated versions of nearest neighbors are often preferred. They may not give exact solutions, but the deviation is usually minor and accepted as a tradeoff for faster computation.

LargeVis is an approximated nearest neighbors algorithm that uses the concept of a Random Projection Forest. A Random Projection Forest consists of multiple trees. Each tree is built using a divide-and-conquer approach. The entire dataset is used to start the construction, and at each node, the data is projected onto a random hyperplane and split as follows: points with a projection smaller than or equal to the median go into the left subtree, while those with a projection greater than the median go into the right subtree. This process is then recursively applied to the left and right subtrees.
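
The following is a conceptual sketch in plain Nx of a single split in such a tree (not Scholar's internal code; the random data stands in for one node's points): project the points onto a random direction and divide them at the median projection.

key = Nx.Random.key(42)

# A handful of random 3-dimensional points standing in for one node's data
{points, key} = Nx.Random.uniform(key, 0.0, 1.0, shape: {8, 3}, type: :f64)

# A random hyperplane through the data, described by its normal direction
{direction, _key} = Nx.Random.normal(key, 0.0, 1.0, shape: {3}, type: :f64)

projections = Nx.dot(points, direction)
median = Nx.median(projections)

# Projections <= median go to the left subtree, the rest to the right subtree
left_mask = Nx.less_equal(projections, median)
right_mask = Nx.greater(projections, median)
{left_mask, right_mask}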

This structure allows us to quickly create an approximated nearest neighbors graph, which can then be used to find the nearest neighbors efficiently. To better understand this, below is a figure illustrating the process of creating a tree in a random projection forest. Note that in the Scholar implementation the points are split in half at each step, while the figure below does not satisfy this condition; the general idea of the splitting, however, is the same.

Random Projection Forest
Figure 3: A process of creating a tree in Random Projection Forest (Source: http://www.math.umassd.edu/~dyan/Slides_rpForests.pdf)

The next figure compares how space is partitioned by a KDTree and by a Random Projection Forest. This time the splitting matches the output of the Scholar implementation of Random Projection Forest.

Figure 4: Left: A spatial partitioning of $\mathbf{R}^2$ induced by a KDTree with three levels. The dots are data points; the cross marks a query point $q$ for KNN. Right: Partitioning induced by a Random Projection Tree.

Here's a brief explanation of LargeVis. First, we compute an approximated nearest neighbors graph using a Random Projection Forest. Then, we refine the graph by examining the neighbors of each point's neighbors for a fixed number of iterations, as sketched below.
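
The refinement idea can be illustrated with a conceptual sketch in plain Nx (not Scholar's internal code, and ignoring de-duplication of candidates): given an approximate neighbor matrix, the neighbors of each point's neighbors become candidates, distances to the candidates are recomputed, and only the closest ones are kept.

refine = fn x, indices, k ->
  n = Nx.axis_size(x, 0)

  # Neighbors of neighbors: {n, k, k}, flattened into {n, k * k} candidates
  candidates = indices |> Nx.take(indices) |> Nx.reshape({n, :auto})

  # Squared Euclidean distance from each point to each of its candidates
  diff = Nx.subtract(Nx.take(x, candidates), Nx.new_axis(x, 1))
  dist = diff |> Nx.pow(2) |> Nx.sum(axes: [-1])

  # Keep only the k closest candidates for every point
  order = Nx.argsort(dist, axis: 1)
  Nx.take_along_axis(candidates, order, axis: 1)[[.., 0..(k - 1)]]
end

Now, let's use LargeVis in practice.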

In this example, we will use a dataset on online news popularity. This dataset summarizes a diverse set of features about articles published by Mashable over a two-year period. To read more about this dataset, click here. Usually, the dataset is used to predict the number of shares on social networks, but we'll use it differently. KNN and Approximated Nearest Neighbors can be valuable in data mining, particularly in business scenarios. Imagine you're an editor at an online news agency, and you want to create a series of articles that boost your reach. One strategy could be to examine groups of the most popular articles and analyze them to identify specific patterns. We can use Approximated Nearest Neighbors to do this. For each news article, we'll calculate its nearest neighbors and then sort these groups by the total number of shares they receive on social networks.

The first step is to load the dataset from the Internet archive.

data_url =
  "https://archive.ics.uci.edu/static/public/332/online+news+popularity.zip"

raw_data =
  Req.get!(data_url).body
<<80, 75, 3, 4, 10, 0, 0, 0, 0, 0, 76, 147, 191, 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 0, 28,
  0, 79, 110, 108, 105, 110, 101, 78, 101, 119, 115, 80, 111, 112, 117, 108, 97, 114, 105, 116, 121,
  ...>>

Next, we load the data into a dataframe.

{:ok, [_, {_, data}]} = :zip.extract(raw_data, [:memory])
data = String.replace(data, ", ", ",")
df_data = DF.load_csv!(data, delimiter: ",")
#Explorer.DataFrame<
  Polars[39644 x 61]
  url string ["http://mashable.com/2013/01/07/amazon-instant-video-browser/", "http://mashable.com/2013/01/07/ap-samsung-sponsored-tweets/", "http://mashable.com/2013/01/07/apple-40-billion-app-downloads/", "http://mashable.com/2013/01/07/astronaut-notre-dame-bcs/", "http://mashable.com/2013/01/07/att-u-verse-apps/", ...]
  timedelta f64 [731.0, 731.0, 731.0, 731.0, 731.0, ...]
  n_tokens_title f64 [12.0, 9.0, 9.0, 9.0, 13.0, ...]
  n_tokens_content f64 [219.0, 255.0, 211.0, 531.0, 1072.0, ...]
  n_unique_tokens f64 [0.663594466988, 0.604743080614, 0.575129530699, 0.503787877834, 0.41564561695, ...]
  n_non_stop_words f64 [0.999999992308, 0.999999993289, 0.999999991597, 0.999999996904, 0.999999998565, ...]
  n_non_stop_unique_tokens f64 [0.815384609112, 0.79194630341, 0.66386554064, 0.665634672862, 0.540889525766, ...]
  num_hrefs f64 [4.0, 3.0, 3.0, 9.0, 19.0, ...]
  num_self_hrefs f64 [2.0, 1.0, 1.0, 0.0, 19.0, ...]
  num_imgs f64 [1.0, 1.0, 1.0, 1.0, 20.0, ...]
  num_videos f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  average_token_length f64 [4.6803652968, 4.9137254902, 4.39336492891, 4.40489642185, 4.6828358209, ...]
  num_keywords f64 [5.0, 4.0, 6.0, 7.0, 7.0, ...]
  data_channel_is_lifestyle f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  data_channel_is_entertainment f64 [1.0, 0.0, 0.0, 1.0, 0.0, ...]
  data_channel_is_bus f64 [0.0, 1.0, 1.0, 0.0, 0.0, ...]
  data_channel_is_socmed f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  data_channel_is_tech f64 [0.0, 0.0, 0.0, 0.0, 1.0, ...]
  data_channel_is_world f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  kw_min_min f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  kw_max_min f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  kw_avg_min f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  kw_min_max f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  kw_max_max f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  kw_avg_max f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  kw_min_avg f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  kw_max_avg f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  kw_avg_avg f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  self_reference_min_shares f64 [496.0, 0.0, 918.0, 0.0, 545.0, ...]
  self_reference_max_shares f64 [496.0, 0.0, 918.0, 0.0, 1.6e4, ...]
  self_reference_avg_sharess f64 [496.0, 0.0, 918.0, 0.0, 3151.15789474, ...]
  weekday_is_monday f64 [1.0, 1.0, 1.0, 1.0, 1.0, ...]
  weekday_is_tuesday f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  weekday_is_wednesday f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  weekday_is_thursday f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  weekday_is_friday f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  weekday_is_saturday f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  weekday_is_sunday f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  is_weekend f64 [0.0, 0.0, 0.0, 0.0, 0.0, ...]
  LDA_00 f64 [0.500331204081, 0.799755687423, 0.217792288518, 0.0285732164707, 0.0286328101715, ...]
  LDA_01 f64 [0.378278929586, 0.0500466753998, 0.033334456999, 0.419299641782, 0.0287935517322, ...]
  LDA_02 f64 [0.0400046751006, 0.0500962518137, 0.0333514249339, 0.49465082574, 0.0285751849112, ...]
  LDA_03 f64 [0.0412626477296, 0.0501006734234, 0.0333335358046, 0.0289047184252, 0.028571675324, ...]
  LDA_04 f64 [0.0401225435029, 0.0500007119405, 0.682188293744, 0.0285715975818, 0.885426777861, ...]
  global_subjectivity f64 [0.521617145481, 0.341245791246, 0.702222222222, 0.42984968735, 0.513502122877, ...]
  global_sentiment_polarity f64 [0.0925619834711, 0.148947811448, 0.323333333333, 0.100704665705, 0.281003475691, ...]
  global_rate_positive_words f64 [0.0456621004566, 0.043137254902, 0.0568720379147, 0.0414312617702, 0.0746268656716, ...]
  global_rate_negative_words f64 [0.013698630137, 0.0156862745098, 0.00947867298578, 0.0207156308851, 0.0121268656716, ...]
  rate_positive_words f64 [0.769230769231, 0.733333333333, 0.857142857143, 0.666666666667, 0.860215053763, ...]
  rate_negative_words f64 [0.230769230769, 0.266666666667, 0.142857142857, 0.333333333333, 0.139784946237, ...]
  avg_positive_polarity f64 [0.378636363636, 0.286914600551, 0.495833333333, 0.385965171192, 0.411127435065, ...]
  min_positive_polarity f64 [0.1, 0.0333333333333, 0.1, 0.136363636364, 0.0333333333333, ...]
  max_positive_polarity f64 [0.7, 0.7, 1.0, 0.8, 1.0, ...]
  avg_negative_polarity f64 [-0.35, -0.11875, -0.466666666667, -0.369696969697, -0.220192307692, ...]
  min_negative_polarity f64 [-0.6, -0.125, -0.8, -0.6, -0.5, ...]
  max_negative_polarity f64 [-0.2, -0.1, -0.133333333333, -0.166666666667, -0.05, ...]
  title_subjectivity f64 [0.5, 0.0, 0.0, 0.0, 0.454545454545, ...]
  title_sentiment_polarity f64 [-0.1875, 0.0, 0.0, 0.0, 0.136363636364, ...]
  abs_title_subjectivity f64 [0.0, 0.5, 0.5, 0.5, 0.0454545454545, ...]
  abs_title_sentiment_polarity f64 [0.1875, 0.0, 0.0, 0.0, 0.136363636364, ...]
  shares s64 [593, 711, 1500, 1200, 505, ...]
>

Next, we split the data into two parts: the target variable (shares, or $y$) and the rest of the data ($x$). We separate the shares from $x$ to avoid introducing bias. We also remove URLs because tensors do not support non-numerical data types, and we exclude the timedelta feature, which is an auxiliary feature used in creating the dataset. Finally, we apply standard scaling to $x$.

x = df_data |> DF.discard(["url", "timedelta", "shares"]) |> Nx.stack(axis: 1)
y = Nx.stack(df_data[["shares"]], axis: 1)

x = Scholar.Preprocessing.StandardScaler.fit_transform(x)
{x, y}
{#Nx.Tensor<
   f64[39644][58]
   EXLA.Backend<host:0, 0.482371013.981598228.187092>
   [
     [-0.16777552811051025, -0.16587029232809064, -0.1678798688132422, -0.16787677252413957, -0.16787847173106735, -0.1678491604112801, -0.16786756848647255, -0.16787677252406877, -0.167885976561665, -0.16784289830350915, -0.16783995637368385, -0.167885976561665, -0.16787677252406877, -0.167885976561665, -0.167885976561665, -0.167885976561665, -0.167885976561665, -0.167885976561665, -0.167885976561665, -0.167885976561665, -0.167885976561665, -0.167885976561665, -0.167885976561665, -0.167885976561665, -0.167885976561665, -0.167885976561665, -0.16332077391393496, -0.16332077391393496, -0.16332077391393496, -0.16787677252406877, -0.167885976561665, -0.167885976561665, -0.167885976561665, -0.167885976561665, -0.167885976561665, -0.167885976561665, -0.167885976561665, -0.16788137149445206, -0.16788249486817522, -0.16788560835713134, -0.167885596778704, -0.16788560727226615, -0.16788117557784715, -0.16788512461768915, -0.16788555628597568, -0.1678858504789582, -0.16787889653274482, -0.16788385255298896, -0.16788249157833882, ...],
     ...
   ]
 >,
 #Nx.Tensor<
   s64[39644][1]
   EXLA.Backend<host:0, 0.482371013.981598227.186805>
   [
     [593],
     [711],
     [1500],
     [1200],
     [505],
     [855],
     [556],
     [891],
     [3600],
     [710],
     [2200],
     [1900],
     [823],
     [10000],
     [761],
     [1600],
     [13600],
     [3100],
     [5700],
     [17100],
     [2800],
     [598],
     [445],
     [1500],
     [852],
     [783],
     [1500],
     [1800],
     [462],
     [425],
     [3200],
     [3900],
     [480],
     [573],
     [3600],
     [343],
     [7700],
     [2600],
     [504],
     [2100],
     [1100],
     [401],
     [4200],
     [1200],
     [1500],
     [575],
     [507],
     [1200],
     ...
   ]
 >}

Now, let's move on to computing the approximated nearest neighbors. Notice that this dataset has significantly more features than the previous one (58 compared to 7).

num_neighbors = 6
{indices, _dist} = Scholar.Neighbors.LargeVis.fit(x, num_neighbors: num_neighbors)
{#Nx.Tensor<
   u32[39644][6]
   EXLA.Backend<host:0, 0.482371013.981598228.187096>
   [
     [0, 65, 21, 69, 2, 25],
     [1, 66, 28, 61, 9, 32],
     [2, 25, 69, 21, 57, 65],
     [3, 56, 33, 63, 53, 72],
     [4, 62, 43, 40, 60, 44],
     [5, 39, 35, 30, 18, 14],
     [6, 51, 74, 7, 64, 44],
     [7, 74, 64, 6, 44, 60],
     [8, ...],
     ...
   ]
 >,
 #Nx.Tensor<
   f64[39644][6]
   EXLA.Backend<host:0, 0.482371013.981598228.187097>
   [
     [0.0, 0.0018657036251906384, 0.0036749279899371094, 0.005279428978234045, 0.006727966127223153, 0.006837100863281978],
     [0.0, 6.756593401603524e-5, 1.2678390136218736e-4, 1.6633686250093157e-4, 2.2233078504659554e-4, 3.514499996064942e-4],
     [0.0, 5.173768744997381e-4, 0.001809295719831838, 0.003793044425862508, 0.006135524442190683, 0.006203436890693119],
     [0.0, 5.252092785567814e-4, 6.090372028619397e-4, 6.319163895009465e-4, 7.10570116086926e-4, 8.120833196596075e-4],
     [0.0, 5.067837852972306e-5, 9.174324571151536e-5, 3.706506657349229e-4, 4.427973160256605e-4, 4.6200572899281997e-4],
     [0.0, 0.03510952981931922, 0.03666978946624632, 0.038386790160971074, 0.05585061734567715, 0.05899400609953473],
     [0.0, 1.4980733767922953e-4, 2.6223120133330976e-4, 2.703271990114324e-4, 4.1551509313155296e-4, 5.712522600732687e-4],
     [0.0, 2.0471372823266868e-5, 1.485758818753876e-4, 2.703271990114324e-4, 3.053468283568113e-4, 3.2301368793209305e-4],
     ...
   ]
 >}

Finally, we compute the total number of shares within each neighborhood and sort the neighborhoods in descending order based on that sum.

shares_of_neighbors = Nx.take(y, indices) |> Nx.squeeze(axes: [-1]) |> Nx.sum(axes: [-1])
ord = Nx.argsort(shares_of_neighbors, direction: :desc)
Nx.take(indices, ord)
#Nx.Tensor<
  u32[39644][6]
  EXLA.Backend<host:0, 0.482371013.981598228.187107>
  [
    [8055, 5407, 9365, 6100, 8717, 8217],
    [8065, 8717, 9365, 5407, 6100, 8055],
    [5407, 9365, 8055, 8717, 8065, 6100],
    [8717, 8065, 9365, 6100, 5407, 8055],
    [6100, 8717, 9365, 5407, 8065, 8055],
    [9365, 8065, 5407, 8717, 6100, 8055],
    [16467, 14218, 11843, 26259, 16268, 16334],
    [16334, 16268, 16467, 15205, 26259, 14218],
    [16268, 16334, ...],
    ...
  ]
>

Notice that the first six groups are quite similar. The condition that a neighbor of your neighbor is also your neighbor is often satisfied, and this is exactly the property LargeVis exploits to refine the results of Random Projection Forests.