Compiled Decision Trees Benchmark
Mix.install([
{:scidata, "~> 0.1"},
{:exgboost, "~> 0.4"},
{:mockingjay, github: "acalejos/mockingjay"},
{:nx, "~> 0.5", override: true},
{:exla, "~> 0.5"},
{:scholar, "~> 0.2"},
{:benchee, "~> 1.0"}
])
Setup Dataset
# Fetch the Iris samples and labels, pair each sample with its label, and
# shuffle the pairs so the split below is random.
{x, y} = Scidata.Iris.download()
data = x |> Enum.zip(y) |> Enum.shuffle()

# The first 80% of the shuffled pairs (rounded up) become the training set;
# the remainder is held out for testing.
train_count = ceil(length(data) * 0.8)
{train, test} = Enum.split(data, train_count)

# Separate samples from labels again and turn each list into a tensor.
[x_train, y_train] = train |> Enum.unzip() |> Tuple.to_list() |> Enum.map(&Nx.tensor/1)
[x_test, y_test] = test |> Enum.unzip() |> Tuple.to_list() |> Enum.map(&Nx.tensor/1)
Gather Model / Prediction Functions
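`EXGBoost.compile/1` converts your trained `Booster` model into a set of tensor
operations, which can then be run on any of the Nx backends. The code below also
passes a `:strategy` option (`:gemm`, `:tree_traversal`, or
`:perfect_tree_traversal`) to choose the compilation strategy.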
# Get Baseline Model (XGBoost C API)
# Trained on the training split as a 3-class problem with the
# :multi_softprob objective; every compiled variant below is derived from it.
model = EXGBoost.train(x_train, y_train, num_class: 3, objective: :multi_softprob)
# Get Compiled Models w/ Binary Backend
# The defaults are switched BEFORE compiling, so the three `*_predict`
# functions are built while Nx.Defn.Evaluator / Nx.BinaryBackend are in effect.
Nx.Defn.default_options(compiler: Nx.Defn.Evaluator)
Nx.default_backend(Nx.BinaryBackend)
# One compiled predictor per strategy; each is additionally wrapped with
# EXLA.jit/1 to produce the "(JIT)" variants used in the benchmark.
gemm_predict = EXGBoost.compile(model, strategy: :gemm)
gemm_jit_exla = EXLA.jit(gemm_predict)
tree_trav_predict = EXGBoost.compile(model, strategy: :tree_traversal)
tree_trav_jit_exla = EXLA.jit(tree_trav_predict)
ptt_predict = EXGBoost.compile(model, strategy: :perfect_tree_traversal)
ptt_jit_exla = EXLA.jit(ptt_predict)
# Get Compiled Models w/ EXLA Backend
# Same three strategies, compiled again after switching the defaults to
# EXLA / EXLA.Backend. Statement order matters: moving these compile calls
# above the two lines below would build them under the binary-backend defaults.
Nx.Defn.default_options(compiler: EXLA)
Nx.default_backend(EXLA.Backend)
gemm_exla = EXGBoost.compile(model, strategy: :gemm)
tree_trav_exla = EXGBoost.compile(model, strategy: :tree_traversal)
ptt_exla = EXGBoost.compile(model, strategy: :perfect_tree_traversal)
# Benchmark label => one-argument prediction function. Every entry takes an
# input tensor and forwards it to the corresponding predictor, so the baseline
# and all compiled variants can be driven through the same interface.
funcs = %{
  # Baseline: the trained booster itself
  "Base" => &EXGBoost.predict(model, &1),

  # Compiled while the binary backend was the default
  "Compiled -- GEMM Strategy -- Binary Backend" => &gemm_predict.(&1),
  "Compiled -- Tree Traversal Strategy -- Binary Backend" => &tree_trav_predict.(&1),
  "Compiled -- Perfect Tree Traversal Strategy -- Binary Backend" => &ptt_predict.(&1),

  # Compiled while the EXLA backend was the default
  "Compiled -- GEMM Strategy -- EXLA Backend" => &gemm_exla.(&1),
  "Compiled -- Tree Traversal Strategy -- EXLA Backend" => &tree_trav_exla.(&1),
  "Compiled -- Perfect Tree Traversal Strategy -- EXLA Backend" => &ptt_exla.(&1),

  # Compiled predictors wrapped with EXLA.jit
  "Compiled -- GEMM Strategy -- EXLA Backend (JIT)" => &gemm_jit_exla.(&1),
  "Compiled -- Tree Traversal Strategy -- EXLA Backend (JIT)" => &tree_trav_jit_exla.(&1),
  "Compiled -- Perfect Tree Traversal Strategy -- EXLA Backend (JIT)" => &ptt_jit_exla.(&1)
}
Run Time Benchmarks
# Build the Benchee job map: label => zero-arity function.
#
# Each job must be a function that Benchee can invoke repeatedly. Calling
# `predict.(x_train)` directly here would run every predictor exactly once,
# up front, and hand Benchee the resulting tensors instead of something it can
# time. Wrapping the call in `fn -> ... end` defers execution to Benchee.
benches =
  Map.new(funcs, fn {name, predict} ->
    {name, fn -> predict.(x_train) end}
  end)

# time/warmup/memory_time are in seconds: 5s warmup, 10s of run-time
# measurement and 2s of memory measurement per job.
Benchee.run(benches,
  time: 10,
  memory_time: 2,
  warmup: 5
)
Compare Accuracies
# Switch the defaults back to the evaluator compiler and binary backend before
# scoring the predictions.
Nx.Defn.default_options(compiler: Nx.Defn.Evaluator)
Nx.default_backend(Nx.BinaryBackend)

# Label => accuracy on the held-out test split. For each predictor, the
# predicted class is the argmax over the last axis of its output, which is then
# compared against `y_test`.
accuracies =
  Map.new(funcs, fn {name, pred_fn} ->
    predicted_classes = Nx.argmax(pred_fn.(x_test), axis: -1)
    score = Scholar.Metrics.Classification.accuracy(y_test, predicted_classes)
    {name, Nx.to_number(score)}
  end)