Scholar.Neighbors.BruteKNN (Scholar v0.3.0)
Brute-Force k-Nearest Neighbor Search Algorithm.
To find the k nearest neighbors, the algorithm computes the distance between each query point and every data sample. Its time complexity is therefore $O(MN)$ for $N$ samples and $M$ query points. It uses $O(BN)$ memory for batch size $B$. Larger batch sizes lead to faster predictions but consume more memory.
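The search described above can be sketched in plain Elixir. This is a hypothetical illustration on nested lists with Euclidean distance, not Scholar's Nx-based implementation; the module `BruteKNNSketch` and its `predict/4` function are made up for this example. Each batch of at most `batch_size` queries is compared against all $N$ samples, so the batch's distance matrix holds $B \times N$ entries, which mirrors the $O(BN)$ memory bound.

```elixir
defmodule BruteKNNSketch do
  # Hypothetical sketch on plain lists; not Scholar's Nx implementation.

  # Euclidean distance between two points given as lists of numbers.
  def euclidean(a, b) do
    a
    |> Enum.zip(b)
    |> Enum.map(fn {x, y} -> (x - y) * (x - y) end)
    |> Enum.sum()
    |> :math.sqrt()
  end

  # For every query, returns {indices, distances} of its k nearest samples.
  # Queries are processed in chunks of batch_size, so each chunk builds a
  # batch_size x length(data) distance matrix (the O(BN) memory term).
  def predict(data, queries, k, batch_size) do
    indexed = Enum.with_index(data)

    queries
    |> Enum.chunk_every(batch_size)
    |> Enum.flat_map(fn batch ->
      # B x N matrix of {distance, sample_index} pairs for this batch.
      dist_matrix =
        for q <- batch do
          for {x, i} <- indexed, do: {euclidean(q, x), i}
        end

      Enum.map(dist_matrix, fn row ->
        row
        # Sorting {distance, index} tuples breaks ties by the lower index.
        |> Enum.sort()
        |> Enum.take(k)
        |> Enum.map(fn {d, i} -> {i, d} end)
        |> Enum.unzip()
      end)
    end)
  end
end
```

A full sort per query is used for simplicity; a top-k selection would avoid the $O(N \log N)$ sort. The batch size only changes how many distances are held at once, not the result.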
Functions
fit(data, opts)
Fits a brute-force k-NN model.
Options
- :num_neighbors (pos_integer/0) - Required. The number of nearest neighbors.
- :metric - The function that measures the pairwise distance between two points. Possible values:
  - {:minkowski, p} - Minkowski metric. By changing the value of the p parameter (a positive number or :infinity), we can set Manhattan (1), Euclidean (2), Chebyshev (:infinity), or any arbitrary $L_p$ metric.
  - :cosine - Cosine metric.
  - Anonymous function of arity 2 that takes two rank-2 tensors.

  The default value is &Scholar.Metrics.Distance.pairwise_minkowski/2.
- :batch_size (pos_integer/0) - The number of samples in a batch.
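The metric choices listed above can be written out for two points given as plain lists. This is a hypothetical sketch: `DistanceSketch` is not part of Scholar, whose metrics operate on Nx tensors, and the cosine metric is assumed here to mean 1 minus the cosine similarity.

```elixir
defmodule DistanceSketch do
  # Hypothetical helpers on plain lists; Scholar's metrics work on Nx tensors.

  # Chebyshev distance, the p = :infinity case: largest absolute difference.
  def minkowski(a, b, :infinity) do
    a
    |> Enum.zip(b)
    |> Enum.map(fn {x, y} -> abs(x - y) end)
    |> Enum.max()
  end

  # General L_p distance: p = 1 is Manhattan, p = 2 is Euclidean.
  def minkowski(a, b, p) when is_number(p) and p > 0 do
    a
    |> Enum.zip(b)
    |> Enum.map(fn {x, y} -> :math.pow(abs(x - y), p) end)
    |> Enum.sum()
    |> :math.pow(1 / p)
  end

  # Cosine distance, assumed here to be 1 - cosine similarity.
  def cosine(a, b) do
    dot = a |> Enum.zip(b) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()
    norm = fn v -> v |> Enum.map(&(&1 * &1)) |> Enum.sum() |> :math.sqrt() end
    1 - dot / (norm.(a) * norm.(b))
  end
end
```

For the points [1, 2] and [4, 6], the coordinate differences are 3 and 4, so p = 1 gives 7, p = 2 gives 5, and :infinity gives 4.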
Examples
iex> data = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
iex> model = Scholar.Neighbors.BruteKNN.fit(data, num_neighbors: 2)
iex> model.num_neighbors
2
iex> model.data
#Nx.Tensor<
s64[5][2]
[
[1, 2],
[2, 3],
[3, 4],
[4, 5],
[5, 6]
]
>
predict(model, query)
Computes the nearest neighbors of a query tensor using brute-force search. Returns the indices of the neighbors and their distances from the query points.
Examples
iex> data = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
iex> model = Scholar.Neighbors.BruteKNN.fit(data, num_neighbors: 2)
iex> query = Nx.tensor([[1, 3], [4, 2], [3, 6]])
iex> {neighbors, distances} = Scholar.Neighbors.BruteKNN.predict(model, query)
iex> neighbors
#Nx.Tensor<
u64[3][2]
[
[0, 1],
[1, 2],
[3, 2]
]
>
iex> distances
#Nx.Tensor<
f32[3][2]
[
[1.0, 1.0],
[2.2360680103302, 2.2360680103302],
[1.4142135381698608, 2.0]
]
>
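As a check on the last row of the example, the Euclidean distances (the default Minkowski metric with p = 2) from the query [3, 6] to each data sample are:

```latex
$$
\begin{aligned}
d([3,6],[1,2]) &= \sqrt{2^2 + 4^2} = \sqrt{20} \approx 4.472 \\
d([3,6],[2,3]) &= \sqrt{1^2 + 3^2} = \sqrt{10} \approx 3.162 \\
d([3,6],[3,4]) &= \sqrt{0^2 + 2^2} = 2 \\
d([3,6],[4,5]) &= \sqrt{1^2 + 1^2} = \sqrt{2} \approx 1.414 \\
d([3,6],[5,6]) &= \sqrt{2^2 + 0^2} = 2
\end{aligned}
$$
```

Sample 3 is the closest, and samples 2 and 4 tie at distance 2; the example output returns index 2 as the second neighbor.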