Scholar.Linear.IsotonicRegression (Scholar v0.3.0)
Isotonic regression is a method of fitting a free-form line to a set of observations by solving a convex optimization problem, subject to the constraint that the fitted function is monotonic (non-decreasing or non-increasing). It is a form of regression analysis that can be used as an alternative to polynomial regression to fit nonlinear data.
Time complexity of isotonic regression is $O(N^2)$ where $N$ is the number of points.
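A common way to solve this optimization is the pool-adjacent-violators algorithm (PAVA): walk the targets in order of x and, whenever a value falls below the block before it, pool the two into one block whose value is their weighted mean. The following is a minimal, hypothetical pure-Elixir sketch of that technique on plain lists, assuming the targets are already sorted by x and the fit is non-decreasing. The module name PavaSketch is invented for illustration, and this is not Scholar's implementation.

```elixir
defmodule PavaSketch do
  # Hypothetical sketch of the pool-adjacent-violators algorithm (PAVA) for a
  # non-decreasing fit on plain lists; assumes `ys` is already sorted by x.
  # Not Scholar's implementation.

  def fit(ys, weights \\ nil) do
    # Equal weights when none are given, mirroring the :sample_weights default.
    weights = weights || List.duplicate(1.0, length(ys))

    ys
    |> Enum.zip(weights)
    # Each block is {weighted_mean, total_weight, count}; the list head is the
    # most recently added block.
    |> Enum.reduce([], fn {y, w}, blocks -> push(blocks, {y * 1.0, w * 1.0, 1}) end)
    |> Enum.reverse()
    # Expand every block back into `count` copies of its mean.
    |> Enum.flat_map(fn {mean, _weight, count} -> List.duplicate(mean, count) end)
  end

  # A new block violates the constraint if its mean is below the previous
  # block's mean: pool the two and re-check against the block before that.
  defp push([{prev_mean, prev_w, prev_c} | rest], {mean, w, c}) when prev_mean > mean do
    total = prev_w + w
    pooled = {(prev_mean * prev_w + mean * w) / total, total, prev_c + c}
    push(rest, pooled)
  end

  defp push(blocks, block), do: [block | blocks]
end
```

For example, PavaSketch.fit([1, 3, 2, 4]) pools the violating pair 3, 2 into their mean and returns [1.0, 2.5, 2.5, 4.0]. A non-increasing fit can be obtained by negating the targets, fitting, and negating the result.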
Summary
Functions
fit(x, y, opts \\ [])
Fits an isotonic regression model for sample inputs x and sample targets y.
predict(model, x)
Makes predictions with the given model on input x, using the interpolating function stored in the model.
preprocess(model)
Preprocesses the model for prediction.
Types
@type t() :: %Scholar.Linear.IsotonicRegression{ cutoff_index: Nx.Tensor.t(), increasing: Nx.Tensor.t(), preprocess: tuple() | Scholar.Interpolation.Linear.t(), x_max: Nx.Tensor.t(), x_min: Nx.Tensor.t(), x_thresholds: Nx.Tensor.t(), y_thresholds: Nx.Tensor.t() }
Functions
fit(x, y, opts \\ [])
Fits an isotonic regression model for sample inputs x and sample targets y.
Options
:y_min (float/0) - Lower bound on the lowest predicted value. If not provided, the lower bound is set to Nx.Constant.neg_infinity().
:y_max (float/0) - Upper bound on the highest predicted value. If not provided, the upper bound is set to Nx.Constant.infinity().
:increasing - Whether the isotonic regression should be fit with the constraint that the function is monotonically increasing. If false, the constraint is that the function is monotonically decreasing. If :auto, the constraint is determined automatically based on the data. The default value is :auto.
:out_of_bounds - How to handle out-of-bounds points (inputs outside the training range [x_min, x_max]). If :clip, out-of-bounds points are mapped to the nearest valid value. If :nan, out-of-bounds points are replaced with Nx.Constant.nan(). The default value is :nan.
:sample_weights - The weights for each observation. If not provided, all observations are assigned equal weight.
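To make the :y_min, :y_max and :out_of_bounds semantics concrete, here is a hypothetical pure-Elixir sketch that works on plain floats rather than Nx tensors. The module and function names are invented for illustration, nil stands in for the unbounded defaults, and because Elixir floats cannot represent NaN, the atom :nan stands in for Nx.Constant.nan(). It is not Scholar's code.

```elixir
defmodule BoundsSketch do
  # Hypothetical helpers illustrating the :y_min/:y_max and :out_of_bounds
  # options on plain floats. Not Scholar's code: nil stands in for the
  # unbounded defaults, and the atom :nan stands in for Nx.Constant.nan()
  # because Elixir floats cannot hold NaN.

  # Clamp a fitted value into [y_min, y_max]; a nil bound means "unbounded".
  def clamp_y(y, y_min \\ nil, y_max \\ nil) do
    y = if y_min != nil and y < y_min, do: y_min, else: y
    if y_max != nil and y > y_max, do: y_max, else: y
  end

  # :clip maps an out-of-range query onto the nearest end of [x_min, x_max].
  def handle_x(x, x_min, x_max, :clip), do: x |> max(x_min) |> min(x_max)

  # :nan marks an out-of-range query as undefined and leaves in-range ones alone.
  def handle_x(x, x_min, x_max, :nan) do
    if x < x_min or x > x_max, do: :nan, else: x
  end
end
```

With the training range [1.0, 11.0], BoundsSketch.handle_x(12.0, 1.0, 11.0, :clip) returns 11.0, while the same call with :nan returns :nan.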
Return Values
The function returns a struct with the following fields:
:x_min - Minimum value of input tensor x.
:x_max - Maximum value of input tensor x.
:x_thresholds - Thresholds used for predictions.
:y_thresholds - Predicted values associated with each threshold.
:increasing - Whether the isotonic regression is increasing.
:cutoff_index - The index of the last valid threshold. The remaining elements are placeholders that preserve the shape of the tensor.
:preprocess - Interpolation function to be applied on input tensor x. Before preprocess/1 is applied, it is set to {}.
Examples
iex> x = Nx.tensor([1, 4, 7, 9, 10, 11])
iex> y = Nx.tensor([1, 3, 6, 8, 9, 10])
iex> Scholar.Linear.IsotonicRegression.fit(x, y)
%Scholar.Linear.IsotonicRegression{
  x_min: Nx.tensor(1.0),
  x_max: Nx.tensor(11.0),
  x_thresholds: Nx.tensor([1.0, 4.0, 7.0, 9.0, 10.0, 11.0]),
  y_thresholds: Nx.tensor([1.0, 3.0, 6.0, 8.0, 9.0, 10.0]),
  increasing: Nx.u8(1),
  cutoff_index: Nx.tensor(5),
  preprocess: {}
}
predict(model, x)
Makes predictions with the given model on input x, using the interpolating function stored in the model.
Examples
iex> x = Nx.tensor([1, 4, 7, 9, 10, 11])
iex> y = Nx.tensor([1, 3, 6, 8, 9, 10])
iex> model = Scholar.Linear.IsotonicRegression.fit(x, y)
iex> model = Scholar.Linear.IsotonicRegression.preprocess(model)
iex> to_predict = Nx.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
iex> Scholar.Linear.IsotonicRegression.predict(model, to_predict)
#Nx.Tensor<
  f32[10]
  [1.0, 1.6666667461395264, 2.3333332538604736, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
>
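For inputs inside [x_min, x_max], the values in this example are consistent with piecewise-linear interpolation between consecutive (x_thresholds, y_thresholds) pairs: 2 lies one third of the way from 1 to 4, so its prediction is 1 + (3 - 1) / 3 ≈ 1.6667. The following hypothetical pure-Elixir sketch reproduces that arithmetic on plain lists, in 64-bit floats rather than f32. The module name is invented, and instead of applying :out_of_bounds handling it simply returns :out_of_range.

```elixir
defmodule InterpSketch do
  # Hypothetical piecewise-linear interpolation over sorted thresholds on plain
  # lists; not Scholar's predict/2, and no :out_of_bounds handling.
  def predict(xs, ys, x) do
    points = Enum.zip(xs, ys)

    # Pair each threshold with the next one to form segments, then find the
    # first segment whose x-range contains the query.
    points
    |> Enum.zip(tl(points))
    |> Enum.find(fn {{x0, _y0}, {x1, _y1}} -> x >= x0 and x <= x1 end)
    |> case do
      {{x0, y0}, {x1, y1}} -> y0 + (x - x0) * (y1 - y0) / (x1 - x0)
      nil -> :out_of_range
    end
  end
end
```

Mapping it over 1..10 with the thresholds from the example yields values matching the tensor above up to f32 rounding.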
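preprocess(model)
Preprocesses the model for prediction.
Returns an updated model whose :preprocess field holds a Scholar.Interpolation.Linear struct built from the fitted thresholds.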
Examples
iex> x = Nx.tensor([1, 4, 7, 9, 10, 11])
iex> y = Nx.tensor([1, 3, 6, 8, 9, 10])
iex> model = Scholar.Linear.IsotonicRegression.fit(x, y)
iex> Scholar.Linear.IsotonicRegression.preprocess(model)
%Scholar.Linear.IsotonicRegression{
  x_min: Nx.tensor(1.0),
  x_max: Nx.tensor(11.0),
  x_thresholds: Nx.tensor([1.0, 4.0, 7.0, 9.0, 10.0, 11.0]),
  y_thresholds: Nx.tensor([1.0, 3.0, 6.0, 8.0, 9.0, 10.0]),
  increasing: Nx.u8(1),
  cutoff_index: Nx.tensor(5),
  preprocess: %Scholar.Interpolation.Linear{
    coefficients: Nx.tensor(
      [
        [0.6666666865348816, 0.3333333134651184],
        [1.0, -1.0],
        [1.0, -1.0],
        [1.0, -1.0],
        [1.0, -1.0]
      ]
    ),
    x: Nx.tensor([1.0, 4.0, 7.0, 9.0, 10.0, 11.0])
  }
}
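The coefficients printed above appear to be one [slope, intercept] pair per segment between consecutive thresholds: the first segment runs from (1, 1) to (4, 3), giving slope 2/3 and intercept 1 - 2/3 = 1/3, and each remaining segment has slope 1 and intercept -1. A hypothetical pure-Elixir sketch of that computation on plain lists (module name invented; not Scholar.Interpolation.Linear itself):

```elixir
defmodule CoefSketch do
  # Hypothetical: per-segment [slope, intercept] pairs for piecewise-linear
  # interpolation, so that on segment i, y = slope * x + intercept.
  def coefficients(xs, ys) do
    points = Enum.zip(xs, ys)

    points
    |> Enum.zip(tl(points))
    |> Enum.map(fn {{x0, y0}, {x1, y1}} ->
      slope = (y1 - y0) / (x1 - x0)
      # Intercept chosen so the line passes through the segment's left point.
      [slope, y0 - slope * x0]
    end)
  end
end
```

With those coefficients, a query in a given segment evaluates to slope * x + intercept; for x = 2 that is 2/3 * 2 + 1/3 ≈ 1.6667, matching the predict/2 example.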