Accelerating Axon

Mix.install([
  {:axon, github: "elixir-nx/axon"},
  {:exla, "~> 0.3.0", github: "elixir-nx/nx", sparse: "exla"},
  {:torchx, github: "elixir-nx/nx", sparse: "torchx"},
  {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true},
  {:benchee, github: "akoutmos/benchee", branch: :adding_table_support},
  {:kino_benchee, github: "livebook-dev/kino_benchee"},
  {:kino, "~> 0.7.0", override: true}
])
:ok


Using Nx Compilers in Axon

Axon is built entirely on top of Nx's numerical definitions, defn. Functions declared with defn tell Nx to use just-in-time compilation to compile and execute the given numerical definition with an available Nx compiler. Numerical definitions enable acceleration on CPUs, GPUs, and TPUs via pluggable compilers. At the time of this writing, Nx has two officially supported compilers/backends on top of the default BinaryBackend:

  1. EXLA - Acceleration via Google's XLA project
  2. TorchX - Bindings to LibTorch
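
To make this concrete, here is a minimal, hypothetical defn module (the module and function names are illustrative only, not part of Axon or Nx). When a compiler such as EXLA is configured, invoking the function compiles and runs the whole definition as one fused program; with no compiler configured, it falls back to evaluating operation-by-operation on the default backend:

defmodule MyDefn do
  import Nx.Defn

  # Compiled as a single numerical program by the configured Nx compiler
  defn scale_and_sum(x, y) do
    x
    |> Nx.multiply(2)
    |> Nx.add(y)
    |> Nx.sum()
  end
end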

By default, Nx and Axon run all computations using the BinaryBackend which is a pure Elixir implementation of various numerical routines. The BinaryBackend is guaranteed to run wherever an Elixir installation runs; however, it is very slow. Due to the computational expense of neural networks, you should basically never use the BinaryBackend and instead opt for one of the available accelerated libraries.

There are several ways to make use of Nx compilers from within Axon. First, create a simple model for benchmarking purposes:

model =
  Axon.input("data")
  |> Axon.dense(32)
  |> Axon.relu()
  |> Axon.dense(1)
  |> Axon.softmax()
#Axon<
  inputs: %{"data" => nil}
  outputs: "softmax_0"
  nodes: 5
>

By default, Axon will respect the default defn compilation options. You can set compilation options globally or per-process:

# Sets the global compilation options
Nx.Defn.global_default_options(compiler: EXLA)
# Sets the process-level compilation options
Nx.Defn.default_options(compiler: EXLA)
[compiler: EXLA]

When you call Axon.build/2, Axon automatically marks your initialization and forward functions as JIT compiled functions. When you invoke them, they will compile a specialized version of the function using your default compiler options:

inputs = Nx.random_uniform({2, 128})
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)

10:34:02.503 [info]  XLA service 0x7fbd5468c170 initialized for platform Host (this does not guarantee that XLA will be used). Devices:

10:34:02.785 [info]    StreamExecutor device (0): Host, Default Version
#Nx.Tensor<
  f32[2][1]
  EXLA.Backend<host:0, 0.184501844.1168769032.259095>
  [
    [1.0],
    [1.0]
  ]
>

Notice that the inspected tensor indicates the computation has been dispatched to EXLA and the tensor's data points to an EXLA buffer.

If you feel like setting the global or process-level compilation options is too intrusive, you can opt for more explicit behavior in a few ways. First, you can specify the JIT compiler when you build the model:

# Set back to defaults
Nx.Defn.global_default_options([])
Nx.Defn.default_options([])
[compiler: EXLA]
{init_fn, predict_fn} = Axon.build(model, compiler: EXLA)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
#Nx.Tensor<
  f32[2][1]
  EXLA.Backend<host:0, 0.184501844.1168769032.259101>
  [
    [1.0],
    [1.0]
  ]
>

You can also JIT compile functions explicitly via Nx.Defn.jit or a compiler-specific JIT API such as EXLA.jit. This is useful when running benchmarks against various backends:

{init_fn, predict_fn} = Axon.build(model)

# These will both JIT compile with EXLA
exla_init_fn = Nx.Defn.jit(init_fn, compiler: EXLA)
exla_predict_fn = EXLA.jit(predict_fn)
#Function<136.40088443/2 in Nx.Defn.wrap_arity/2>
Benchee.run(
  %{
    "elixir init" => fn -> init_fn.(inputs, %{}) end,
    "exla init" => fn -> exla_init_fn.(inputs, %{}) end
  },
  time: 10,
  memory_time: 5,
  warmup: 2
)
Warning: the benchmark elixir init is using an evaluated function.
  Evaluated functions perform slower than compiled functions.
  You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
  Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs

Warning: the benchmark exla init is using an evaluated function.
  Evaluated functions perform slower than compiled functions.
  You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
  Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs

Operating System: Linux
CPU Information: Intel(R) Core(TM) i7-7600U CPU @ 2.80GHz
Number of Available Cores: 4
Available memory: 24.95 GB
Elixir 1.13.4
Erlang 25.0.4

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 10 s
memory time: 5 s
reduction time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 34 s

Benchmarking elixir init ...
Benchmarking exla init ...

Name                  ips        average  deviation         median         99th %
exla init          3.79 K        0.26 ms   ±100.40%        0.24 ms        0.97 ms
elixir init        0.52 K        1.91 ms    ±35.03%        1.72 ms        3.72 ms

Comparison:
exla init          3.79 K
elixir init        0.52 K - 7.25x slower +1.65 ms

Memory usage statistics:

Name           Memory usage
exla init           9.80 KB
elixir init       644.63 KB - 65.80x memory usage +634.83 KB

**All measurements for memory usage were the same**
Benchee.run(
  %{
    "elixir predict" => fn -> predict_fn.(params, inputs) end,
    "exla predict" => fn -> exla_predict_fn.(params, inputs) end
  },
  time: 10,
  memory_time: 5,
  warmup: 2
)
Warning: the benchmark elixir predict is using an evaluated function.
  Evaluated functions perform slower than compiled functions.
  You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
  Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs

Warning: the benchmark exla predict is using an evaluated function.
  Evaluated functions perform slower than compiled functions.
  You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
  Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs

Operating System: Linux
CPU Information: Intel(R) Core(TM) i7-7600U CPU @ 2.80GHz
Number of Available Cores: 4
Available memory: 24.95 GB
Elixir 1.13.4
Erlang 25.0.4

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 10 s
memory time: 5 s
reduction time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 34 s

Benchmarking elixir predict ...
Benchmarking exla predict ...

Name                     ips        average  deviation         median         99th %
exla predict          2.32 K        0.43 ms   ±147.05%        0.34 ms        1.61 ms
elixir predict        0.28 K        3.53 ms    ±42.21%        3.11 ms        7.26 ms

Comparison:
exla predict          2.32 K
elixir predict        0.28 K - 8.20x slower +3.10 ms

Memory usage statistics:

Name              Memory usage
exla predict          10.95 KB
elixir predict        91.09 KB - 8.32x memory usage +80.14 KB

**All measurements for memory usage were the same**

Notice how calls to EXLA variants are significantly faster than their Elixir counterparts. These speedups become more pronounced with more complex models and workflows.

It's important to note that in order to use a given library as an Nx compiler, it must implement the Nx compilation behaviour. For example, you cannot invoke Torchx as an Nx compiler because it does not support JIT compilation at this time.


Using Nx Backends in Axon

In addition to JIT compilation, Axon also supports the usage of Nx backends. Nx backends differ from Nx compilers in that they do not fuse calls within numerical definitions. Backends are more eager, sacrificing a bit of performance for convenience. Torchx and EXLA both support running via backends.
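
For instance, with a backend each Nx call below is dispatched and executed as soon as it is invoked, rather than being fused into a single compiled program. This is a minimal sketch; it assumes the Torchx backend is available in your setup:

# The tensor is created directly on the Torchx backend
x = Nx.tensor([1.0, 2.0, 3.0], backend: Torchx.Backend)

# Each operation runs eagerly on that backend
x |> Nx.multiply(2) |> Nx.sum()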

Again, Axon will respect the global and process-level Nx backend configuration options. You can set the default backend using:

# Global default backend
Nx.global_default_backend(Torchx.Backend)
# Process default backend
Nx.default_backend(Torchx.Backend)
{Nx.BinaryBackend, []}

Now when you invoke model functions, they will run with the given backend:

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
#Nx.Tensor<
  f32[2][1]
  Torchx.Backend(cpu)
  [
    [1.0],
    [1.0]
  ]
>
# Global default backend
Nx.global_default_backend(EXLA.Backend)
# Process default backend
Nx.default_backend(EXLA.Backend)
{Torchx.Backend, []}
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
#Nx.Tensor<
  f32[2][1]
  EXLA.Backend<host:0, 0.184501844.1169293320.110725>
  [
    [1.0],
    [1.0]
  ]
>

Unlike with JIT compilation, you must set the backend at the top level before invoking your model functions. You should also be careful when using multiple backends in the same project, as attempting to mix tensors between backends may result in strange performance bugs or errors.
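
If you do need to move data between backends, it is safer to transfer tensors explicitly than to mix them implicitly. A minimal sketch using Nx.backend_transfer/2 (the tensor here is illustrative):

# Explicitly move a tensor to the binary backend before combining it
# with tensors created there
t = Nx.tensor([1.0, 2.0, 3.0])
t_binary = Nx.backend_transfer(t, Nx.BinaryBackend)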

With most larger models, using a JIT compiler will be more performant than using a backend.


A Note on CPUs/GPUs/TPUs

While Nx mostly tries to standardize behavior across compilers and backends, some behaviors are backend-specific. For example, the API for choosing an acceleration platform (e.g. CUDA/ROCm/TPU) is backend-specific. You should refer to your chosen compiler or backend's documentation for information on targeting various accelerators. Typically, you only need to change a few configuration options and your code will run as-is on a chosen accelerator.
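
For example, with EXLA you would typically build the library against a CUDA target and then point the compiler or backend at the CUDA client. The option names below are an assumption based on EXLA at the time of writing; check the EXLA documentation for your version:

# Build EXLA with CUDA support (set before the dependency compiles), e.g.:
# export XLA_TARGET=cuda118

# Then target the CUDA client when compiling or creating tensors
Nx.Defn.global_default_options(compiler: EXLA, client: :cuda)
Nx.global_default_backend({EXLA.Backend, client: :cuda})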