Scholar.Impute.KNNImputter (Scholar v0.4.0)
Imputer for completing missing values using k-Nearest Neighbors.
Each sample's missing values are imputed using the mean value from the num_neighbors nearest neighbors found in the training set. Two samples are close if the features that neither of them is missing are close.
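As a rough illustration of this rule, the sketch below computes the fill value for one missing entry in plain Elixir (no Nx), assuming rows are lists with missing entries written as :nan. The module KNNImputeSketch and its functions are hypothetical, and the distance is an assumption: unscaled squared Euclidean distance over the features both rows have present. Scholar's internal distance computation may differ.

```elixir
defmodule KNNImputeSketch do
  # Hypothetical illustration; not Scholar's implementation.
  # Rows are plain lists and missing entries are the atom :nan,
  # mirroring how missing values are written in Nx tensor literals.

  # Squared Euclidean distance over the features that neither row is missing.
  # Returns nil when the two rows have no present feature in common.
  def distance(a, b) do
    pairs =
      a
      |> Enum.zip(b)
      |> Enum.reject(fn {x, y} -> x == :nan or y == :nan end)

    case pairs do
      [] -> nil
      _ -> pairs |> Enum.map(fn {x, y} -> (x - y) * (x - y) end) |> Enum.sum()
    end
  end

  # Fill value for column `col` of row `i`: the mean of that column over the
  # `k` nearest other rows that have the column present and share at least
  # one present feature with row `i`. Returns :nan if no such row exists.
  def fill_value(rows, i, col, k) do
    target = Enum.at(rows, i)

    neighbors =
      rows
      |> Enum.with_index()
      |> Enum.reject(fn {row, j} -> j == i or Enum.at(row, col) == :nan end)
      |> Enum.map(fn {row, _j} -> {distance(target, row), Enum.at(row, col)} end)
      |> Enum.reject(fn {d, _value} -> is_nil(d) end)
      |> Enum.sort_by(fn {d, _value} -> d end)
      |> Enum.take(k)

    case neighbors do
      [] -> :nan
      _ -> Enum.sum(Enum.map(neighbors, fn {_d, value} -> value end)) / length(neighbors)
    end
  end
end

rows = [[40.0, 2.0], [4.0, 5.0], [7.0, :nan], [:nan, 8.0], [11.0, 11.0]]
KNNImputeSketch.fill_value(rows, 2, 1, 2)
# => 8.0 (mean of 5.0 and 11.0 from the rows [4.0, 5.0] and [11.0, 11.0])
```

On the data used in the examples further down, this sketch picks the same neighbors and gives the same fill values (8.0 and 7.5) as the documented output.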
Summary
Functions
fit/2
Imputer for completing missing values using k-Nearest Neighbors.
transform/2
Imputes all missing values in x using the fitted imputer.
Functions
fit/2
Imputer for completing missing values using k-Nearest Neighbors.
Preconditions:
- The number of neighbors must be less than the number of valid rows minus 1.
- A valid row is a row with more than one non-NaN value. If the data cannot meet these conditions, a simpler imputer is a better choice.
- When :missing_values is set to a value other than :nan, the input tensor must not contain any NaNs.
Options
:missing_values - The placeholder for the missing values. All occurrences of :missing_values will be imputed. Any value other than :nan requires that the input tensor contain no NaNs. The default value is :nan.
:num_neighbors (pos_integer/0) - The number of nearest neighbors. The default value is 2.
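A minimal sketch of how these two options and their documented defaults could be normalized, using the standard library's Keyword.validate!/2 (Elixir 1.13+) rather than Scholar's own option handling. The module KNNOptionsSketch and its functions are hypothetical; missing?/2 shows the rule that an entry counts as missing when it equals the placeholder.

```elixir
defmodule KNNOptionsSketch do
  # Hypothetical option handling mirroring the documented defaults.
  # Keyword.validate!/2 fills in defaults and raises ArgumentError on
  # unknown keys; Scholar's own validation may differ.
  def normalize(opts) do
    opts = Keyword.validate!(opts, missing_values: :nan, num_neighbors: 2)
    k = opts[:num_neighbors]

    unless is_integer(k) and k > 0 do
      raise ArgumentError,
            "expected :num_neighbors to be a positive integer, got: #{inspect(k)}"
    end

    opts
  end

  # An entry counts as missing when it equals the placeholder
  # (:nan by default, or e.g. -1.0 when set through :missing_values).
  def missing?(value, placeholder), do: value == placeholder
end

opts = KNNOptionsSketch.normalize([])
opts[:num_neighbors]
# => 2
KNNOptionsSketch.missing?(-1.0, -1.0)
# => true
```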
Return Values
The function returns a struct with the following fields:
:missing_values - The same value as in the :missing_values option.
:statistics - The imputation fill values. The tensor has the same shape as the input: each missing entry holds its fill value, and every other entry is NaN.
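To make the :statistics layout concrete, this hypothetical sketch builds a matrix with the same shape as the input that holds a fill value at each :nan position and :nan everywhere else. The fill function is passed in as an argument (an assumption standing in for the k-NN computation), and rows are plain lists rather than Nx tensors.

```elixir
defmodule KNNStatisticsSketch do
  # Hypothetical sketch of the statistics layout; not Scholar's code.
  # `fill_fun` is a caller-supplied function (row_index, col_index) -> value.
  def build(rows, fill_fun) do
    rows
    |> Enum.with_index()
    |> Enum.map(fn {row, i} ->
      row
      |> Enum.with_index()
      |> Enum.map(fn
        # Missing entry: store the fill value computed for this position.
        {:nan, j} -> fill_fun.(i, j)
        # Present entry: nothing to impute, so the statistic is :nan.
        {_value, _j} -> :nan
      end)
    end)
  end
end

rows = [[40.0, 2.0], [4.0, 5.0], [7.0, :nan], [:nan, 8.0], [11.0, 11.0]]
fills = %{{2, 1} => 8.0, {3, 0} => 7.5}
KNNStatisticsSketch.build(rows, fn i, j -> Map.fetch!(fills, {i, j}) end)
# => [[:nan, :nan], [:nan, :nan], [:nan, 8.0], [7.5, :nan], [:nan, :nan]]
```

With the two fill values from the example that follows, the result matches the statistics tensor shown there.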
Examples
iex> x = Nx.tensor([[40.0, 2.0], [4.0, 5.0], [7.0, :nan], [:nan, 8.0], [11.0, 11.0]])
iex> Scholar.Impute.KNNImputter.fit(x, num_neighbors: 2)
%Scholar.Impute.KNNImputter{
  statistics: Nx.tensor(
    [
      [:nan, :nan],
      [:nan, :nan],
      [:nan, 8.0],
      [7.5, :nan],
      [:nan, :nan]
    ]
  ),
  missing_values: :nan
}
transform/2
Imputes all missing values in x using the fitted imputer.
Return Values
The function returns the input tensor with each missing value (NaN by default) replaced by the value stored at the same position in the fitted imputer's statistics.
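A minimal plain-Elixir sketch of this replacement step, assuming x has the same shape as the data the statistics were computed from and that missing entries are written :nan by default. KNNTransformSketch.transform/3 is a hypothetical name, not Scholar's API.

```elixir
defmodule KNNTransformSketch do
  # Hypothetical sketch; not Scholar's implementation.
  # Every entry equal to the placeholder is replaced by the value at the same
  # position in `statistics`; all other entries are kept unchanged.
  def transform(statistics, x, missing_value \\ :nan) do
    Enum.zip_with(statistics, x, fn stat_row, x_row ->
      Enum.zip_with(stat_row, x_row, fn stat, value ->
        if value == missing_value, do: stat, else: value
      end)
    end)
  end
end

x = [[40.0, 2.0], [4.0, 5.0], [7.0, :nan], [:nan, 8.0], [11.0, 11.0]]
statistics = [[:nan, :nan], [:nan, :nan], [:nan, 8.0], [7.5, :nan], [:nan, :nan]]
KNNTransformSketch.transform(statistics, x)
# => [[40.0, 2.0], [4.0, 5.0], [7.0, 8.0], [7.5, 8.0], [11.0, 11.0]]
```

With Nx tensors and the default :nan placeholder, the same element-wise choice could be expressed as Nx.select(Nx.is_nan(x), statistics, x).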
Examples
iex> x = Nx.tensor([[40.0, 2.0], [4.0, 5.0], [7.0, :nan], [:nan, 8.0], [11.0, 11.0]])
iex> imputer = Scholar.Impute.KNNImputter.fit(x, num_neighbors: 2)
iex> Scholar.Impute.KNNImputter.transform(imputer, x)
Nx.tensor(
  [
    [40.0, 2.0],
    [4.0, 5.0],
    [7.0, 8.0],
    [7.5, 8.0],
    [11.0, 11.0]
  ]
)