Nx.Defn (Nx v0.4.0)

Numerical functions.

Numerical functions are written in a subset of Elixir tailored for numerical computations. For example, the following function:

defmodule MyModule do
  import Nx.Defn

  defn softmax(t) do
    Nx.exp(t) / Nx.sum(Nx.exp(t))
  end
end

will work with scalars, vectors, matrices, and n-dimensional tensors. Depending on your compiler of choice, the code can even be JIT-compiled and run either on the CPU or GPU.
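
A minimal sketch of calling the softmax definition above on tensors of different ranks. It assumes a standalone script that installs Nx v0.4.0 with Mix.install/1 and relies on the default compiler, Nx.Defn.Evaluator, so EXLA is not required.

```elixir
# Assumption: standalone script; Nx v0.4.0 is fetched with Mix.install/1.
Mix.install([{:nx, "~> 0.4.0"}])

defmodule MyModule do
  import Nx.Defn

  defn softmax(t) do
    Nx.exp(t) / Nx.sum(Nx.exp(t))
  end
end

# The same definition accepts a vector...
vector = MyModule.softmax(Nx.tensor([1.0, 2.0, 3.0]))
# ...and a matrix; Nx.sum/1 sums over all elements, so the whole matrix sums to 1.
matrix = MyModule.softmax(Nx.tensor([[1.0, 2.0], [3.0, 4.0]]))

IO.inspect(Nx.shape(vector))  # {3}
IO.inspect(Nx.shape(matrix))  # {2, 2}
```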

To support these features, defn is a subset of Elixir. It replaces Elixir's Kernel with Nx.Defn.Kernel. Nx.Defn.Kernel provides tensor-aware operators, such as +, -, etc., while also preserving many high-level constructs known to Elixir developers, such as the pipe operator, aliases, conditionals, pattern matching, the access syntax, and more.

For example, the code above can also be written as:

defmodule MyModule do
  import Nx.Defn

  defn softmax(t) do
    t
    |> Nx.exp()
    |> then(& &1 / Nx.sum(&1))
  end
end

Please consult Nx.Defn.Kernel for a complete reference.

Operators

defn attempts to stay as close to Elixir semantics as possible, but that's not always achievable. For example, mathematical and bitwise operators (+, -, &&&, <<<, etc.) in Elixir work on numbers, which means mapping them to tensors is straightforward and they largely preserve the same semantics, except they are now multi-dimensional.

On the other hand, the logical operators and, or, and not work with booleans in Elixir (true and false), which map to 0 and 1 in defn.

Therefore, when working with logical operators inside defn, 0 is considered false and all other numbers are considered true, which is represented as the number 1. For example, in defn, 0 and 1 as well as 0 and 2 return 0, while 1 and 1 or 1 and -1 will return 1.

The same semantics apply to conditional expressions inside defn, such as if, while, etc.
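
The sketch below illustrates these rules with a hypothetical LogicDemo module. It assumes Nx v0.4.0 installed via Mix.install/1 and the default Evaluator. Numbers passed as arguments become tensors, so each result is converted back with Nx.to_number/1 for printing.

```elixir
# Assumption: standalone script; Nx v0.4.0 is fetched with Mix.install/1.
Mix.install([{:nx, "~> 0.4.0"}])

# Hypothetical module used only for illustration.
defmodule LogicDemo do
  import Nx.Defn

  # Inside defn, `and`/`or` work on tensors: 0 is false, any other number is true.
  defn both(a, b), do: a and b
  defn either(a, b), do: a or b

  # `if` follows the same rule: a non-zero predicate takes the `do` branch.
  defn pick(pred) do
    if pred do
      10
    else
      20
    end
  end
end

IO.inspect(Nx.to_number(LogicDemo.both(0, 2)))    # 0
IO.inspect(Nx.to_number(LogicDemo.both(1, -1)))   # 1
IO.inspect(Nx.to_number(LogicDemo.either(0, 0)))  # 0
IO.inspect(Nx.to_number(LogicDemo.pick(-3)))     # 10
```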

JIT compilers

The power of Nx.Defn is given by its compilers. The default compiler is Nx.Defn.Evaluator, which evaluates the code. You can use jit/2 to compile a function on the fly using a different compiler, such as EXLA:

fun = Nx.Defn.jit(&MyModule.softmax/1, compiler: EXLA)
fun.(my_tensor)

The above will return an anonymous function that optimizes, compiles, and runs softmax on the fly on the CPU (or the GPU, if available).
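
A sketch of the same jit/2 call that runs without EXLA: it passes Nx.Defn.Evaluator explicitly as the compiler, an assumption made only so the example needs nothing beyond Nx v0.4.0 (installed here with Mix.install/1).

```elixir
# Assumption: standalone script; Nx v0.4.0 is fetched with Mix.install/1.
Mix.install([{:nx, "~> 0.4.0"}])

defmodule MyModule do
  import Nx.Defn

  defn softmax(t), do: Nx.exp(t) / Nx.sum(Nx.exp(t))
end

# jit/2 returns an anonymous function; nothing is computed yet.
# With the :exla dependency installed, compiler: EXLA could be used instead.
fun = Nx.Defn.jit(&MyModule.softmax/1, compiler: Nx.Defn.Evaluator)

# Both calls use f32 tensors of shape {3}.
a = fun.(Nx.tensor([1.0, 2.0, 3.0]))
b = fun.(Nx.tensor([4.0, 5.0, 6.0]))
IO.inspect(a)
```

Because softmax is unchanged when the same constant is added to every input, a and b should be numerically equal, which gives an easy sanity check.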

You can also change the default compiler for all numerical definitions (defn) by setting the default options. This can be done in your config/*.exs files as follows:

config :nx, :default_defn_options, compiler: EXLA

Now calling MyModule.softmax(my_tensor) will use EXLA even without wrapping it in jit/2.

However, note that compilation may be quite time-consuming on the first invocation, which is why it is often preferable to pass the compiler: EXLA option when calling the functions in this module instead. EXLA, in particular, also exports an EXLA.jit/2 function for convenience.

defn functions are compiled when they are invoked, based on the types and shapes of the tensors given as arguments. The compilation is then cached based on those shapes and types. Calling the same function with tensors that have different values but the same shape and type does not trigger recompilation.

For those interested in writing custom compilers, see Nx.Defn.Compiler.

Invoking custom Elixir code

Inside defn you can only call other defn functions and the functions in the Nx module. However, it is possible to use transforms, defined with either deftransform or deftransformp, to invoke any Elixir code.

You can call code which was defined with deftransform from another module:

defmodule MyRemoteModule do
  import Nx.Defn

  deftransform remote_elixir_code(value) do
    IO.inspect(value)
  end
end

defn add_and_mult(a, b, c) do
  res = a * b + c
  MyRemoteModule.remote_elixir_code(res)
end

You can also define and call a private transform defined through deftransformp:

defn add_and_mult(a, b, c) do
  res = a * b + c
  custom_elixir_code(res)
end

deftransformp custom_elixir_code(value), do: IO.inspect(value)

In both snippets, IO.inspect/1, which is not a defn function, is invoked with the value of res. This is useful as it allows developers to transform defn code to optimize it, add new properties, and so on.

The only difference between deftransform and deftransformp is whether you want to expose and share the code with other modules, just like def and defp.
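
Putting the two snippets above into one runnable script might look like the sketch below. It assumes Nx v0.4.0 via Mix.install/1; the labels passed to IO.inspect/2 are illustrative. Because transforms run while the defn expression is being built, the printed values are defn expressions rather than numbers.

```elixir
# Assumption: standalone script; Nx v0.4.0 is fetched with Mix.install/1.
Mix.install([{:nx, "~> 0.4.0"}])

defmodule MyRemoteModule do
  import Nx.Defn

  # Public transform: plain Elixir code, callable from other modules' defn.
  deftransform remote_elixir_code(value) do
    IO.inspect(value, label: "remote")
  end
end

defmodule MyModule do
  import Nx.Defn

  defn add_and_mult(a, b, c) do
    res = a * b + c
    # Both transforms return their argument unchanged after printing it.
    res = MyRemoteModule.remote_elixir_code(res)
    custom_elixir_code(res)
  end

  # Private transform, only visible inside MyModule.
  deftransformp custom_elixir_code(value), do: IO.inspect(value, label: "local")
end

result = MyModule.add_and_mult(2, 3, 4)
IO.inspect(Nx.to_number(result))  # 10
```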

Transforms can also be used to manipulate Elixir data structures, such as options. defn expects all inputs to be tensors, with the exception of a default argument (declared with \\) which will be treated as options.

For example, imagine you want to support options where the :axis key is required. While you can't call functions from the Keyword module directly, you can do it via a transform:

defn sum_axis(t, opts \\ []) do
  opts = keyword!(opts, [:axis])
  axis = get_axis(opts)
  Nx.sum(t, axes: [axis])
end

deftransformp get_axis(opts), do: Keyword.fetch!(opts, :axis)
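
A runnable version of sum_axis, assuming Nx v0.4.0 installed with Mix.install/1. Nx.iota({2, 3}) produces [[0, 1, 2], [3, 4, 5]], so the sums over each axis are easy to verify by hand.

```elixir
# Assumption: standalone script; Nx v0.4.0 is fetched with Mix.install/1.
Mix.install([{:nx, "~> 0.4.0"}])

defmodule MyModule do
  import Nx.Defn

  defn sum_axis(t, opts \\ []) do
    # keyword!/2 only allows the :axis key, which has no default here.
    opts = keyword!(opts, [:axis])
    axis = get_axis(opts)
    Nx.sum(t, axes: [axis])
  end

  # Keyword.fetch!/2 is regular Elixir, so it is called from a transform.
  deftransformp get_axis(opts), do: Keyword.fetch!(opts, :axis)
end

t = Nx.iota({2, 3})
IO.inspect(Nx.to_flat_list(MyModule.sum_axis(t, axis: 1)))  # [3, 12]
IO.inspect(Nx.to_flat_list(MyModule.sum_axis(t, axis: 0)))  # [3, 5, 7]
```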

Inputs and outputs types

Nx and defn expect the arguments to be numbers, tensors, or composite data types that implement Nx.LazyContainer. Tuples and maps implement Nx.LazyContainer by default. As previously described, defn functions are cached based on the shape, type, and names of the input tensors, but not their values.

defn also accepts two special kinds of arguments: functions (or tuples of functions) and lists (most commonly keyword lists). Those values are passed as is to numerical definitions and cached as a whole. For this reason, you must never capture tensors in functions or pass tensors in keyword lists.

When numbers are given as arguments, they are always immediately converted to tensors on invocation. If you want to keep numbers as is or if you want to pass any other value to numerical definitions, they must be given as keyword lists.
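
A small sketch of these argument kinds: a tuple and a map as containers of tensors, and a keyword list whose values stay plain Elixir terms. ContainerDemo and its functions are hypothetical, and Nx v0.4.0 is assumed to be installed with Mix.install/1.

```elixir
# Assumption: standalone script; Nx v0.4.0 is fetched with Mix.install/1.
Mix.install([{:nx, "~> 0.4.0"}])

# Hypothetical module used only for illustration.
defmodule ContainerDemo do
  import Nx.Defn

  # A tuple of tensors is a container: it can be destructured and returned.
  defn swap(pair) do
    {a, b} = pair
    {b, a}
  end

  # Maps are containers too; fields are read with the dot syntax.
  defn total(map), do: map.x + map.y

  # Options arrive as a keyword list; :factor stays an Elixir integer.
  defn scale(t, opts \\ []) do
    opts = keyword!(opts, factor: 1)
    t * opts[:factor]
  end
end

{b, a} = ContainerDemo.swap({Nx.tensor(1), Nx.tensor(2)})
sum = ContainerDemo.total(%{x: Nx.tensor([1, 2]), y: Nx.tensor([10, 20])})
scaled = ContainerDemo.scale(Nx.tensor([1, 2, 3]), factor: 3)
```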

Default arguments

defn functions support default arguments. They are typically used as options. For example, imagine you want to create a function named zeros, which returns a tensor of zeroes with a given type and shape. It could be implemented like this:

defn zeros(opts \\ []) do
  opts = keyword!(opts, type: {:f, 32}, shape: {})
  Nx.broadcast(Nx.tensor(0, type: opts[:type]), opts[:shape])
end
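
Calling zeros/1 with and without options might look like the sketch below, assuming Nx v0.4.0 installed via Mix.install/1.

```elixir
# Assumption: standalone script; Nx v0.4.0 is fetched with Mix.install/1.
Mix.install([{:nx, "~> 0.4.0"}])

defmodule MyModule do
  import Nx.Defn

  defn zeros(opts \\ []) do
    opts = keyword!(opts, type: {:f, 32}, shape: {})
    Nx.broadcast(Nx.tensor(0, type: opts[:type]), opts[:shape])
  end
end

# No options: the defaults give a scalar f32 zero.
scalar = MyModule.zeros()
# Options override the defaults validated by keyword!/2.
matrix = MyModule.zeros(shape: {2, 3}, type: {:s, 64})
IO.inspect(matrix)
```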

The function above accepts opts which are then validated and given default values via the keyword!/2 function. Note that while it is possible to access options via the Access syntax, such as opts[:shape], it is not possible to directly call functions in the Keyword module inside defn. To freely manipulate any Elixir value inside defn, you have to use transforms, as described in the "Invoking custom Elixir code" section.

Important! When it comes to JIT compilation, each different set of options (as well as anonymous functions) will lead to a different compilation of the numerical function.

Furthermore, if tensors are given through keyword lists, they won't be cached effectively. Tensors in defn are cached based on their shape and type, not their value, but this is not true if the tensor is given via a default argument or captured by an anonymous function. For this reason, passing tensors through anonymous functions and default arguments is strongly discouraged.

Working with maps and structs

While Nx supports maps in defn, you must be careful if your numerical definitions are receiving maps and returning maps. For example, imagine this code:

defn update_a(map) do
  %{map | a: Nx.add(map.a, 1)}
end

The code above increments the value under the key :a by 1. However, because the function receives the whole map and returns the whole map, if the map has 120 keys, the whole map will be copied to the CPU/GPU and then brought back.

However, if you do this instead:

defn update_a(map) do
  Nx.add(map.a, 1)
end

And then update the map on Elixir, outside of defn:

%{map | a: update_a(map)}

Nx will only send the parts of the map that matter.
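
A sketch of this pattern, assuming Nx v0.4.0 installed via Mix.install/1; the extra :b and :c fields are hypothetical stand-ins for the rest of a large map.

```elixir
# Assumption: standalone script; Nx v0.4.0 is fetched with Mix.install/1.
Mix.install([{:nx, "~> 0.4.0"}])

defmodule MyModule do
  import Nx.Defn

  # Returns only the new value for :a instead of the whole map.
  defn update_a(map) do
    Nx.add(map.a, 1)
  end
end

# Hypothetical map: :b and :c stand in for many other fields.
map = %{a: Nx.tensor([1, 2]), b: Nx.iota({1000}), c: Nx.iota({1000})}

# The map itself is updated in plain Elixir, outside of defn.
map = %{map | a: MyModule.update_a(map)}
IO.inspect(map.a)
```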

Summary

Functions

Compiles the given anonymous function with the given tensor shapes.

Wraps an anonymous function to return its underlying defn expression.

Invokes the anonymous function to return its underlying defn expression.

Gets the default options for the current process.

Sets the default options for defn in the current process.

Defines a public numerical function.

Defines a private numerical function.

Can be used to define bodiless clauses for multi-clause transforms.

Defines a transform that executes the given fun with arg when building defn expressions.

Private function version for deftransform/1

Private function version for deftransform/2

Sets the default options globally.

Receives an anonymous function and returns a new anonymous function that returns the gradient of the input function when invoked.

Computes the gradient of the given var on fun.

Wraps an anonymous function with just-in-time compilation.

Invokes the anonymous function with just-in-time compilation.

Starts streaming the given anonymous function with just-in-time compilation.

Receives an anonymous function and returns a new anonymous function that returns the value and gradient of the input function when invoked.

Computes the value and gradient of the given var on fun with an optional data transformation.

Functions

compile(fun, template_args, opts \\ [])

Compiles the given anonymous function with the given tensor shapes.

While jit/2 compiles a function just-in-time based on the input shapes, this function precompiles the given anonymous function based on the given input shapes. This can be beneficial for large numerical definitions, where the cache mechanism in jit/2 may take milliseconds.

For example, take the following definition:

defn softmax(t), do: Nx.exp(t) / Nx.sum(Nx.exp(t))

You can compile and then apply it as:

fun = Nx.Defn.compile(&softmax/1, [Nx.template({3}, {:s, 64})], compiler: EXLA)
fun.(Nx.tensor([1, 2, 3]))

You can also pass a mixture of templates and options when compiling a function. In such cases, you must only pass the inputs when invoking the compiled function, as the options will already be embedded in its compiled value:

fun = Nx.Defn.compile(&Nx.sum/2, [Nx.template({2, 2}, {:s, 64}), [axes: [1]]])
fun.(Nx.iota({2, 2}))

If the input tensors do not match the shape of the tensors given on compilation, it will raise.
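
The sketch below combines both compile/3 examples into one script. To avoid depending on EXLA, it passes Nx.Defn.Evaluator as the compiler for softmax and leaves the default compiler for Nx.sum/2; Nx v0.4.0 is assumed to be installed with Mix.install/1.

```elixir
# Assumption: standalone script; Nx v0.4.0 is fetched with Mix.install/1.
Mix.install([{:nx, "~> 0.4.0"}])

defmodule MyModule do
  import Nx.Defn

  defn softmax(t), do: Nx.exp(t) / Nx.sum(Nx.exp(t))
end

# Precompile for f32 vectors of three elements.
template = Nx.template({3}, {:f, 32})
fun = Nx.Defn.compile(&MyModule.softmax/1, [template], compiler: Nx.Defn.Evaluator)

# Nx.tensor/1 infers f32 for floats, matching the template.
probs = fun.(Nx.tensor([1.0, 2.0, 3.0]))

# The [axes: [1]] option is embedded at compile time; only the tensor is passed.
row_sum = Nx.Defn.compile(&Nx.sum/2, [Nx.template({2, 2}, {:s, 64}), [axes: [1]]])
sums = row_sum.(Nx.iota({2, 2}))
IO.inspect(Nx.to_flat_list(sums))  # [1, 5]
```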

Options

  • :compiler - the compiler for the JIT compilation

  • :hooks - a map of hooks to execute. See Nx.Defn.Kernel.hook/3

debug_expr(fun, opts \\ [])

Wraps an anonymous function to return its underlying defn expression.

Warning

This function must be invoked for debugging purposes only.
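
As a debugging aid, debug_expr/2 might be used as below to inspect the expression built for a small anonymous function, and debug_expr_apply/3 does the same in one call. Nx v0.4.0 via Mix.install/1 is assumed. Because the anonymous function is defined outside defn, it uses Nx.add/2 rather than the + operator.

```elixir
# Assumption: standalone script; Nx v0.4.0 is fetched with Mix.install/1.
Mix.install([{:nx, "~> 0.4.0"}])

# Outside defn, + is Elixir's own operator, so Nx.add/2 is used instead.
tanh_plus = fn a, b -> Nx.add(Nx.tanh(a), b) end

# debug_expr/2 wraps the function; calling it builds the expression but does not run it.
expr = Nx.Defn.debug_expr(tanh_plus).(Nx.tensor(1.0), Nx.tensor(2.0))

# debug_expr_apply/3 builds the same expression in a single call.
expr_applied = Nx.Defn.debug_expr_apply(tanh_plus, [Nx.tensor(1.0), Nx.tensor(2.0)])

IO.inspect(expr)
```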

Options

debug_expr_apply(fun, args, opts \\ [])

Invokes the anonymous function to return its underlying defn expression.

Warning

This function must be invoked for debugging purposes only.

It accepts the same options as debug_expr/2.

default_options()

Gets the default options for the current process.

default_options(options)

Sets the default options for defn in the current process.

The options defined here apply to all future invocations of defn done by the current process. They also apply to calls to the jit/2 and stream/3 functions in this module.

The default options are stored only in the process dictionary and override any global options. This means if you start a separate process, such as a Task, the default options must be set on the new process too.

This function is mostly used for scripting and testing. In your applications, you typically set the default options in your config files:

  config :nx, :default_defn_options, [compiler: EXLA, client: :cuda]
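
A sketch of the process-local behaviour described above, assuming Nx v0.4.0 via Mix.install/1. Nx.Defn.Evaluator is used as the compiler only so the script needs no EXLA; the Task shows that another process does not see the new defaults.

```elixir
# Assumption: standalone script; Nx v0.4.0 is fetched with Mix.install/1.
Mix.install([{:nx, "~> 0.4.0"}])

# Set defaults for the current process only.
Nx.Defn.default_options(compiler: Nx.Defn.Evaluator)

# The current process sees the new defaults...
IO.inspect(Nx.Defn.default_options(), label: "this process")

# ...but a new process does not, because they live in the process dictionary.
task_defaults = Task.async(fn -> Nx.Defn.default_options() end) |> Task.await()
IO.inspect(task_defaults, label: "task")
```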

defn(call, list) (macro)

Defines a public numerical function.

defnp(call, list) (macro)

Defines a private numerical function.

Private numerical functions are always inlined by their callers at compilation time. This happens to all local function calls within defn.
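
A sketch of a public defn delegating to a private helper, assuming Nx v0.4.0 via Mix.install/1; the normalizer/1 helper is hypothetical. Since defnp defines a private function, it is not exported from the module.

```elixir
# Assumption: standalone script; Nx v0.4.0 is fetched with Mix.install/1.
Mix.install([{:nx, "~> 0.4.0"}])

defmodule MyModule do
  import Nx.Defn

  defn softmax(t) do
    exp = Nx.exp(t)
    exp / normalizer(exp)
  end

  # Hypothetical private helper; inlined into softmax/1 at compilation time.
  defnp normalizer(exp), do: Nx.sum(exp)
end

probs = MyModule.softmax(Nx.tensor([1.0, 2.0, 3.0]))
IO.inspect(function_exported?(MyModule, :normalizer, 1))  # false
```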

deftransform(call) (macro)

Can be used to define bodiless clauses for multi-clause transforms.

See also: deftransform/2

Examples

deftransform foo(bar, baz \\ 1)
deftransform foo(bar, 1), do: bar
deftransform foo(bar, baz), do: bar + baz
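
The clauses above could be exercised as in the sketch below (TransformDemo is a hypothetical module; Nx v0.4.0 via Mix.install/1 is assumed). It calls foo directly from Elixir rather than from defn, on the assumption that deftransform produces an ordinary function; inside a deftransform body, + is Elixir's own operator, so the second clause only works on plain numbers.

```elixir
# Assumption: standalone script; Nx v0.4.0 is fetched with Mix.install/1.
Mix.install([{:nx, "~> 0.4.0"}])

# Hypothetical module used only for illustration.
defmodule TransformDemo do
  import Nx.Defn

  # Bodiless head declaring the default for baz once for all clauses.
  deftransform foo(bar, baz \\ 1)
  deftransform foo(bar, 1), do: bar
  deftransform foo(bar, baz), do: bar + baz
end

# Assumption: transforms are plain functions, so they can be called outside defn.
IO.inspect(TransformDemo.foo(2))     # 2 (baz defaults to 1, first clause)
IO.inspect(TransformDemo.foo(2, 3))  # 5
```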

deftransform(call, list) (macro)

Defines a transform that executes the given fun with arg when building defn expressions.

Example

Take the following defn expression:

defn tanh_power(a, b) do
  Nx.tanh(a) + Nx.power(b, 2)
end

Let's see a trivial example, which is to use IO.inspect/1 to print a tensor expression at definition time:

defn tanh_power(a, b) do
  Nx.tanh(a) + Nx.power(b, 2) |> my_inspect()
end

deftransformp my_inspect(expr), do: IO.inspect(expr)

Or:

defn tanh_power(a, b) do
  res = Nx.tanh(a) + Nx.power(b, 2)
  my_inspect(res)
  res
end

When invoked, both versions print the expression being built by defn:

#Nx.Defn.Expr<
  parameter a
  parameter c
  b = tanh [ a ] ()
  d = power [ c, 2 ] ()
  e = add [ b, d ] ()
>

Alternatively, for convenience, you can use print_expr/2.

deftransformp(call) (macro)

Private function version for deftransform/1

deftransformp(call, list) (macro)

Private function version for deftransform/2

global_default_options(options)

Sets the default options globally.

The options defined here apply to all future invocations of defn. They also apply to calls to the jit/2 and stream/3 functions in this module.

You must avoid calling this function at runtime. It is mostly useful in scripts or code notebooks to set a default. If you need to configure global default options in your applications, you can do so in your config/*.exs files:

config :nx, :default_defn_options, [compiler: EXLA, client: :cuda]

grad(fun)

Receives an anonymous function and returns a new anonymous function that returns the gradient of the input function when invoked.

Examples

iex> fun = Nx.Defn.grad(fn x -> Nx.sin(x) end)
iex> fun.(Nx.tensor(0))
#Nx.Tensor<
  f32
  1.0
>

grad(var_or_vars, fun)

Computes the gradient of the given var on fun.

The result of the grad function must be a scalar tensor. If a non-scalar tensor is given, it is assumed the additional dimensions are batch dimensions.

Examples

defn tanh_grad(t) do
  grad(t, &Nx.tanh/1)
end

To differentiate on multiple vars, pass a tuple as first argument:

defn tanh_power_grad(a, b) do
  grad({a, b}, fn {a, b} -> Nx.tanh(a) + Nx.power(b, 2) end)
end

var_or_vars can be any Nx.Container with one or multiple tensors.
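
A runnable sketch of both examples, assuming Nx v0.4.0 via Mix.install/1 (GradDemo is a hypothetical module name). At t = 0 the derivative of tanh is 1, and the derivative of b^2 at b = 3 is 6, which makes the results easy to check.

```elixir
# Assumption: standalone script; Nx v0.4.0 is fetched with Mix.install/1.
Mix.install([{:nx, "~> 0.4.0"}])

# Hypothetical module used only for illustration.
defmodule GradDemo do
  import Nx.Defn

  # d/dt tanh(t) = 1 - tanh(t)^2
  defn tanh_grad(t) do
    grad(t, &Nx.tanh/1)
  end

  # Gradients with respect to both a and b, returned as a tuple.
  defn tanh_power_grad(a, b) do
    grad({a, b}, fn {a, b} -> Nx.tanh(a) + Nx.power(b, 2) end)
  end
end

g = GradDemo.tanh_grad(Nx.tensor(0.0))
{ga, gb} = GradDemo.tanh_power_grad(Nx.tensor(0.0), Nx.tensor(3.0))
IO.inspect({Nx.to_number(g), Nx.to_number(ga), Nx.to_number(gb)})
```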

jit(fun, opts \\ [])

Wraps an anonymous function with just-in-time compilation.

Once invoked, the wrapped anonymous function will perform just-in-time compilation with the configured compiler. For example, take the following definition:

defn softmax(t), do: Nx.exp(t) / Nx.sum(Nx.exp(t))

You can jit and then apply it as:

fun = Nx.Defn.jit(&softmax/1, compiler: EXLA)
fun.(Nx.tensor([1, 2, 3]))

Options

  • :compiler - the compiler for the JIT compilation

  • :hooks - a map of hooks to execute. See Nx.Defn.Kernel.hook/3

  • :on_conflict - what to do if a JIT compilation is already in place. It may be :raise (the default), :force (forces a new JIT compilation), or :reuse (reuses the existing JIT compilation). It is not recommended to set the :compiler option when reusing.

The three-argument jit/3 is deprecated. Use jit/2 instead.

jit_apply(fun, args, opts \\ [])

Invokes the anonymous function with just-in-time compilation.

This function is equivalent to calling jit/2 and then applying the given arguments to the anonymous function.

For example, take the following definition:

defn softmax(t), do: Nx.exp(t) / Nx.sum(Nx.exp(t))

You can jit_apply/3 it as:

Nx.Defn.jit_apply(&softmax/1, [Nx.tensor([1, 2, 3])], compiler: EXLA)

It accepts the same options as jit/2.

jit_or_apply(fun, args, opts \\ [])

This function is deprecated. Use jit/2 or jit_apply/3 with the :on_conflict option.

stream(fun, args, opts \\ [])

Starts streaming the given anonymous function with just-in-time compilation.

At least two arguments are expected:

  1. The first argument is a tensor template of the data to be streamed in

  2. The second argument is a tensor with the stream initial state

The streaming function must return a two-element tuple: the first element is the data to be sent and the second is the accumulator.

For each streamed chunk, you must call Nx.Stream.send/2 and Nx.Stream.recv/1. You don't need to call recv immediately after send, but doing so can be a useful mechanism to provide backpressure. Once all chunks are sent, you must use Nx.Stream.done/1 to receive the accumulated result. Let's see an example:

defmodule Streamed do
  import Nx.Defn

  defn sum(tensor, acc) do
    {acc, tensor + acc}
  end
end

Now let's invoke it:

stream = Nx.Defn.stream(&Streamed.sum/2, [Nx.template({}, {:s, 64}), 0])

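for i <- 1..5 do
  Nx.Stream.send(stream, i)
  # recv/1 returns a tensor; convert it to a number for readable output
  IO.inspect({:chunk, Nx.to_number(Nx.Stream.recv(stream))})
end

IO.inspect({:result, Nx.to_number(Nx.Stream.done(stream))})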

It will print:

{:chunk, 0}
{:chunk, 1}
{:chunk, 3}
{:chunk, 6}
{:chunk, 10}
{:result, 15}

Options

value_and_grad(fun)

Receives an anonymous function and returns a new anonymous function that returns the value and gradient of the input function when invoked.

Examples

iex> fun = Nx.Defn.value_and_grad(fn x -> Nx.sin(x) end)
iex> {value, grad} = fun.(Nx.tensor(0))
iex> value
#Nx.Tensor<
  f32
  0.0
>
iex> grad
#Nx.Tensor<
  f32
  1.0
>

value_and_grad(var_or_vars, fun, transform \\ & &1)

Computes the value and gradient of the given var on fun with an optional data transformation.

It returns a tuple with the value and the gradient.

Examples

defn tanh_grad(t) do
  value_and_grad(t, &Nx.tanh/1)
end

To differentiate on multiple vars, pass a tuple as first argument:

defn tanh_power_grad(a, b) do
  value_and_grad({a, b}, fn {a, b} -> Nx.tanh(a) + Nx.power(b, 2) end)
end

var_or_vars can be any Nx.Container with one or multiple tensors.

transform allows you to transform the expression before the gradient is calculated. This enables optimizations that reuse parts of expressions. As an example, consider the following objective function:

defn objective(predict_fn, loss_fn, params, inputs, targets) do
  preds = predict_fn.(params, inputs)
  loss = loss_fn.(preds, targets)
  {preds, loss}
end

You can compute the gradient with respect to just the loss function by applying a transform:

{{preds, loss}, gradient} = value_and_grad(params, &objective(predict_fn, loss_fn, &1, inputs, targets), &elem(&1, 1))

preds can be reused to compute other metrics, such as accuracy or absolute error, without having to do another forward pass.
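
To make the transform concrete, the sketch below uses a hypothetical one-parameter linear model, preds = w * inputs, with a squared-error loss, and takes the gradient of the loss only. It assumes Nx v0.4.0 via Mix.install/1. With w = 2.0, inputs [1.0, 2.0], and targets [1.0, 1.0], the predictions are [2.0, 4.0], the loss is 1 + 9 = 10, and the gradient 2 * sum((w * x - y) * x) is 2 * (1 * 1 + 3 * 2) = 14.

```elixir
# Assumption: standalone script; Nx v0.4.0 is fetched with Mix.install/1.
Mix.install([{:nx, "~> 0.4.0"}])

# Hypothetical model and loss, used only for illustration.
defmodule LinearFit do
  import Nx.Defn

  defn objective(w, inputs, targets) do
    preds = w * inputs
    loss = Nx.sum((preds - targets) * (preds - targets))
    {preds, loss}
  end

  # The transform picks the loss out of {preds, loss} before differentiating;
  # the returned value is still the full {preds, loss} tuple.
  defn step(w, inputs, targets) do
    value_and_grad(w, &objective(&1, inputs, targets), fn {_preds, loss} -> loss end)
  end
end

{{preds, loss}, grad} =
  LinearFit.step(Nx.tensor(2.0), Nx.tensor([1.0, 2.0]), Nx.tensor([1.0, 1.0]))

IO.inspect({Nx.to_flat_list(preds), Nx.to_number(loss), Nx.to_number(grad)})
```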