Accelerating Axon

Mix.install([
  {:axon, ">= 0.5.0"},
  {:exla, ">= 0.5.0"},
  {:torchx, ">= 0.5.0"},
  {:benchee, "~> 1.1"},
  {:kino, ">= 0.9.0", override: true}
])
:ok

Using Nx Backends in Axon

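Nx provides two mechanisms for accelerating your neural networks: backends and compilers. Before learning how to use them effectively, let's create a simple model for benchmarking purposes. Its predictions are not meaningful (a softmax over a single output unit always returns 1.0), but that does not matter for benchmarking: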

model =
  Axon.input("data")
  |> Axon.dense(32)
  |> Axon.relu()
  |> Axon.dense(1)
  |> Axon.softmax()
#Axon<
  inputs: %{"data" => nil}
  outputs: "softmax_0"
  nodes: 5
>

Backends are where your tensors (your neural network inputs and parameters) are located. By default, Nx and Axon run all computations using Nx.BinaryBackend, a pure Elixir implementation of various numerical routines. Nx.BinaryBackend is guaranteed to run wherever Elixir runs; however, it is very slow. Because neural networks are computationally expensive, you should almost always use one of the accelerated backend libraries instead. At the time of writing, Nx officially supports two of them:

  1. EXLA - Acceleration via Google's XLA project
  2. TorchX - Bindings to LibTorch
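Besides the process and global defaults, Nx lets you choose a backend for an individual tensor through the :backend option. The sketch below assumes the Mix.install cell at the top of this guide has run, so that EXLA is available; the variable names are illustrative.

```elixir
# Assumes Nx and EXLA were installed by the Mix.install cell at the top.
# The :backend option allocates a single tensor on a specific backend,
# regardless of the configured default.
binary_tensor = Nx.tensor([1.0, 2.0, 3.0], backend: Nx.BinaryBackend)
exla_tensor = Nx.tensor([1.0, 2.0, 3.0], backend: EXLA.Backend)

# Each operation runs on the backend that holds its input
binary_sum = Nx.sum(binary_tensor)
exla_sum = Nx.sum(exla_tensor)

# Nx.to_number/1 copies a scalar result back into an Elixir number
IO.inspect({Nx.to_number(binary_sum), Nx.to_number(exla_sum)})
```

Both sums are 6.0; only where the computation ran differs.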

Axon respects both the global and the process-level Nx default backend. (Compilers are covered in more depth in the second half of this guide.) You can set the default backend with the following APIs:

# Pick one option from each pair; if you run every line, the last call wins.

# Sets the global default backend (for all Elixir processes)
Nx.global_default_backend(Torchx.Backend)
# OR
Nx.global_default_backend(EXLA.Backend)

# Sets the process-level default backend (current process only)
Nx.default_backend(Torchx.Backend)
# OR
Nx.default_backend(EXLA.Backend)
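In a Mix project, the default backend is often set in configuration instead, and Nx.with_default_backend/2 switches the backend only for the duration of a single function call. A minimal sketch, assuming Nx is installed as in the setup cell:

```elixir
# In a Mix project, the default backend can be set in config/config.exs:
#
#     config :nx, default_backend: EXLA.Backend

# Temporarily switch the default backend for the duration of the function;
# tensors created inside the function are allocated on Nx.BinaryBackend.
doubled =
  Nx.with_default_backend(Nx.BinaryBackend, fn ->
    Nx.iota({3}) |> Nx.multiply(2)
  end)

IO.inspect(Nx.to_flat_list(doubled))

# Afterwards the previous default is back in effect; Nx.default_backend/0
# returns it as a {module, options} tuple.
IO.inspect(Nx.default_backend())
```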

Now all tensors and operations on them will run on the configured backend:

{inputs, _next_key} =
  Nx.Random.key(9999)
  |> Nx.Random.uniform(shape: {2, 128})

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
#Nx.Tensor<
  f32[2][1]
  EXLA.Backend<cuda:0, 0.3278685746.4275961901.179470>
  [
    [1.0],
    [1.0]
  ]
>

As you swap backends above, the resulting tensors will be allocated on different backends. Be careful when using multiple backends in the same project: Nx requires you to convert tensors between backends explicitly, so mixing them can lead to errors or hard-to-diagnose performance problems.
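To combine tensors that live on different backends, move them explicitly. The sketch below uses Nx.backend_transfer/2, which moves a tensor to another backend (and may free the original), and Nx.backend_copy/2, which copies it and leaves the source intact. It assumes EXLA is installed as in the setup cell; the variable names are illustrative.

```elixir
# A tensor deliberately allocated on the pure-Elixir backend
binary_tensor = Nx.tensor([[1.0, 2.0], [3.0, 4.0]], backend: Nx.BinaryBackend)

# Move it to EXLA before combining it with other EXLA tensors
exla_tensor = Nx.backend_transfer(binary_tensor, EXLA.Backend)
ones = Nx.tensor([[1.0, 1.0], [1.0, 1.0]], backend: EXLA.Backend)
sum = Nx.add(exla_tensor, ones)

# Copy the result back to Nx.BinaryBackend; `sum` stays valid on EXLA
result = Nx.backend_copy(sum, Nx.BinaryBackend)
IO.inspect(Nx.to_list(result))
```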

For most larger models, using a compiler on top of an accelerated backend brings further performance gains.

Using Nx Compilers in Axon

Axon is built entirely on top of Nx's numerical definitions (defn). Functions declared with defn tell Nx to use just-in-time (JIT) compilation to compile and execute the given numerical definition with an available Nx compiler. Through pluggable compilers, numerical definitions can be accelerated on CPUs, GPUs, and TPUs. At the time of writing, EXLA is the only library that provides a compiler in addition to its backend.
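To see what a numerical definition looks like outside of Axon, here is a small defn written from scratch. The module name MyActivations is hypothetical, and the numerically stable softmax is only an illustration (Axon ships its own activations). The same function is called once directly and once through Nx.Defn.jit/2 with the EXLA compiler, assuming EXLA is installed as in the setup cell:

```elixir
defmodule MyActivations do
  # Hypothetical module used only for this illustration
  import Nx.Defn

  # Numerically stable softmax over the last axis: subtracting the
  # row-wise maximum before exponentiating avoids overflow.
  defn stable_softmax(t) do
    shifted = t - Nx.reduce_max(t, axes: [-1], keep_axes: true)
    exps = Nx.exp(shifted)
    exps / Nx.sum(exps, axes: [-1], keep_axes: true)
  end
end

input = Nx.tensor([[1.0, 2.0, 3.0]])

# Called directly, defn uses the default compiler
# (Nx.Defn.Evaluator unless configured otherwise)
evaluated = MyActivations.stable_softmax(input)

# The same definition, JIT compiled with EXLA
compiled_softmax = Nx.Defn.jit(&MyActivations.stable_softmax/1, compiler: EXLA)
compiled = compiled_softmax.(input)
```

Both calls should return the same probabilities (up to floating-point tolerance); only the execution strategy differs.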

When you call Axon.build/2, Axon can automatically JIT compile your initialization and forward functions. First, let's make sure we are using the EXLA backend:

Nx.default_backend(EXLA.Backend)

And now let's build another model, this time passing the EXLA compiler as an option:

{inputs, _next_key} =
  Nx.Random.key(9999)
  |> Nx.Random.uniform(shape: {2, 128})

{init_fn, predict_fn} = Axon.build(model, compiler: EXLA)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)

15:39:26.463 [info] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero

15:39:26.473 [info] XLA service 0x7f3488329030 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:

15:39:26.473 [info]   StreamExecutor device (0): NVIDIA GeForce RTX 3050 Ti Laptop GPU, Compute Capability 8.6

15:39:26.473 [info] Using BFC allocator.

15:39:26.473 [info] XLA backend allocating 3605004288 bytes on device 0 for BFCAllocator.

15:39:28.272 [info] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
#Nx.Tensor<
  f32[2][1]
  EXLA.Backend<cuda:0, 0.3278685746.4275699756.253533>
  [
    [1.0],
    [1.0]
  ]
>

You can also JIT compile functions explicitly with Nx.Defn.jit/2 or with compiler-specific JIT APIs such as EXLA.jit/2. This is useful when benchmarking different backends against each other:

{init_fn, predict_fn} = Axon.build(model)

# These will both JIT compile with EXLA
exla_init_fn = Nx.Defn.jit(init_fn, compiler: EXLA)
exla_predict_fn = EXLA.jit(predict_fn)
#Function<135.109794929/2 in Nx.Defn.Compiler.fun/2>
Benchee.run(
  %{
    "elixir init" => fn -> init_fn.(inputs, %{}) end,
    "exla init" => fn -> exla_init_fn.(inputs, %{}) end
  },
  time: 10,
  memory_time: 5,
  warmup: 2
)
Warning: the benchmark elixir init is using an evaluated function.
  Evaluated functions perform slower than compiled functions.
  You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
  Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs

Warning: the benchmark exla init is using an evaluated function.
  Evaluated functions perform slower than compiled functions.
  You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
  Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs

Operating System: Linux
CPU Information: Intel(R) Core(TM) i7-7600U CPU @ 2.80GHz
Number of Available Cores: 4
Available memory: 24.95 GB
Elixir 1.13.4
Erlang 25.0.4

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 10 s
memory time: 5 s
reduction time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 34 s

Benchmarking elixir init ...
Benchmarking exla init ...

Name                  ips        average  deviation         median         99th %
exla init          3.79 K        0.26 ms   ±100.40%        0.24 ms        0.97 ms
elixir init        0.52 K        1.91 ms    ±35.03%        1.72 ms        3.72 ms

Comparison:
exla init          3.79 K
elixir init        0.52 K - 7.25x slower +1.65 ms

Memory usage statistics:

Name           Memory usage
exla init           9.80 KB
elixir init       644.63 KB - 65.80x memory usage +634.83 KB

**All measurements for memory usage were the same**
Benchee.run(
  %{
    "elixir predict" => fn -> predict_fn.(params, inputs) end,
    "exla predict" => fn -> exla_predict_fn.(params, inputs) end
  },
  time: 10,
  memory_time: 5,
  warmup: 2
)
Warning: the benchmark elixir predict is using an evaluated function.
  Evaluated functions perform slower than compiled functions.
  You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
  Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs

Warning: the benchmark exla predict is using an evaluated function.
  Evaluated functions perform slower than compiled functions.
  You can move the Benchee caller to a function in a module and invoke `Mod.fun()` instead.
  Alternatively, you can move the benchmark into a benchmark.exs file and run mix run benchmark.exs

Operating System: Linux
CPU Information: Intel(R) Core(TM) i7-7600U CPU @ 2.80GHz
Number of Available Cores: 4
Available memory: 24.95 GB
Elixir 1.13.4
Erlang 25.0.4

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 10 s
memory time: 5 s
reduction time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 34 s

Benchmarking elixir predict ...
Benchmarking exla predict ...

Name                     ips        average  deviation         median         99th %
exla predict          2.32 K        0.43 ms   ±147.05%        0.34 ms        1.61 ms
elixir predict        0.28 K        3.53 ms    ±42.21%        3.11 ms        7.26 ms

Comparison:
exla predict          2.32 K
elixir predict        0.28 K - 8.20x slower +3.10 ms

Memory usage statistics:

Name              Memory usage
exla predict          10.95 KB
elixir predict        91.09 KB - 8.32x memory usage +80.14 KB

**All measurements for memory usage were the same**

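Notice how the EXLA variants are significantly faster: in these runs, initialization is about 7x faster and prediction about 8x faster, with lower memory usage as measured by Benchee. Here, the "elixir" variants run without JIT compilation, evaluating each operation one at a time on the current default backend. Speedups like these typically grow with larger models and more complex workflows.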
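One caveat when timing JIT-compiled functions yourself: compilation happens on the first call for a given input shape and type, and later calls with the same shapes reuse the compiled executable. Benchee's warmup phase hides this cost in the runs above. The sketch below makes it visible with Erlang's :timer.tc/1; it assumes Axon and EXLA are installed as in the setup cell, and the model and variable names are illustrative (the softmax is dropped so outputs vary).

```elixir
# Assumes Axon and EXLA were installed by the Mix.install cell at the top.
timing_model =
  Axon.input("data")
  |> Axon.dense(32)
  |> Axon.relu()
  |> Axon.dense(1)

# Deterministic input of shape {2, 128}
timing_inputs = Nx.iota({2, 128}, type: {:f, 32}) |> Nx.divide(256)

{timing_init_fn, timing_predict_fn} = Axon.build(timing_model, compiler: EXLA)
timing_params = timing_init_fn.(timing_inputs, %{})

# First call with this input shape/type: includes XLA compilation time
{first_us, _} = :timer.tc(fn -> timing_predict_fn.(timing_params, timing_inputs) end)

# Second call with the same shape/type: reuses the compiled executable
{second_us, output} = :timer.tc(fn -> timing_predict_fn.(timing_params, timing_inputs) end)

IO.puts("first call: #{first_us} microseconds, second call: #{second_us} microseconds")
```

Exact timings depend on your hardware, but the first call is usually much slower than the second.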

Note that a library can only be used as an Nx compiler if it implements the Nx.Defn.Compiler behaviour. For example, you cannot use Torchx as an Nx compiler because it does not support JIT compilation at this time.
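Backends without a compiler can still run numerical definitions through Nx.Defn.Evaluator, the compiler that ships with Nx and evaluates each operation eagerly instead of compiling the whole graph. A sketch, assuming Torchx (and its LibTorch download) is available from the setup cell; add_and_double is a hypothetical function:

```elixir
# Hypothetical function used only for illustration
add_and_double = fn a, b -> Nx.multiply(Nx.add(a, b), 2) end

# Nx.Defn.Evaluator needs no JIT support from the backend, so it works
# with Torchx; the default backend is switched only inside this function.
result =
  Nx.with_default_backend(Torchx.Backend, fn ->
    evaluated_fn = Nx.Defn.jit(add_and_double, compiler: Nx.Defn.Evaluator)
    evaluated_fn.(Nx.tensor([1.0, 2.0]), Nx.tensor([3.0, 4.0]))
  end)

IO.inspect(Nx.to_flat_list(result))
```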

A Note on CPUs/GPUs/TPUs

While Nx standardizes most behavior across compilers and backends, some behaviors are backend-specific. For example, the API for choosing an acceleration platform (e.g. CUDA, ROCm, or TPU) differs between backends. Refer to your chosen compiler's or backend's documentation for information on targeting specific accelerators. Typically, you only need to change a few configuration options, and your code will run as-is on the chosen accelerator.
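With EXLA and Torchx, the target device is typically passed as a backend or compiler option. A minimal sketch, assuming EXLA's :host client (the CPU) and Torchx's :cpu device so that it runs without a GPU; on an NVIDIA machine you would swap in EXLA's :cuda client or Torchx's :cuda device, provided those libraries were installed and configured with CUDA support:

```elixir
# EXLA: choose the client when allocating a tensor
cpu_tensor = Nx.tensor([1.0, 2.0, 3.0], backend: {EXLA.Backend, client: :host})

# EXLA: choose the client when JIT compiling
# (Nx.default_backend({EXLA.Backend, client: :host}) would instead make
# this client the process-level default)
sum_on_host = Nx.Defn.jit(&Nx.sum/1, compiler: EXLA, client: :host)
total = sum_on_host.(cpu_tensor)

# Torchx: choose the device when allocating a tensor
torch_tensor = Nx.tensor([1.0, 2.0], backend: {Torchx.Backend, device: :cpu})

IO.inspect({Nx.to_number(total), Nx.to_flat_list(torch_tensor)})
```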