Numerical Definitions (defn)

Section

The defn macro and its siblings simplify the expression of mathematical formulas containing tensors. Numerical definitions have two primary benefits over classic Elixir functions.

  • They are tensor-aware. Inside defn, Nx replaces operators like Kernel.-/2 with their counterparts from Nx.Defn.Kernel. When the operands are tensors, these counterparts call Nx functions optimized for tensors, so the formulas we express work with tensors out of the box.

  • defn definitions allow for building computation graphs that combine the individual operations, which can then be used by just-in-time (JIT) compilers to emit highly specialized native code for the desired computation unit.

We don't have to do anything special to get tensor awareness beyond importing Nx.Defn and writing our code within a defn block.

To use Nx, we need to include the :nx dependency and import the Nx.Defn module. In a notebook or script, that looks like this (in a Mix project, the dependency goes in mix.exs instead):

Mix.install([
  {:nx, "~> 0.9"}
])
import Nx.Defn

Just as the Elixir language supports def, defmacro, and defp, Nx supports defn. There are a few restrictions. Arguments and return values must be numerical, in the form of primitives or tensors, and only a subset of the language is supported.

The subset of Elixir allowed within defn is quite broad, though. We can use macros, pipes, and even conditionals, so we're not giving up much when we declare mathematical functions.
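To make the supported subset concrete, here is a minimal sketch that uses a pipe and a conditional inside defn. The module name DefnSubsetSketch and both function names are hypothetical, chosen only for illustration, and the conditional assumes a scalar predicate.

```elixir
# Same setup as earlier in this guide, repeated so the cell runs on its own.
Mix.install([{:nx, "~> 0.9"}])

defmodule DefnSubsetSketch do
  import Nx.Defn

  # Pipes work as usual inside defn.
  defn mean_abs(t) do
    t |> Nx.abs() |> Nx.mean()
  end

  # `if` accepts a scalar tensor predicate; `<` here is the tensor-aware operator.
  defn clamp_negative(x) do
    if x < 0, do: 0, else: x
  end
end

mean_abs = DefnSubsetSketch.mean_abs(Nx.tensor([-1, 2, -3]))
clamped = DefnSubsetSketch.clamp_negative(Nx.tensor(-3))

IO.inspect(Nx.to_number(mean_abs))
IO.inspect(Nx.to_number(clamped))
```

Both calls return tensors, so Nx.to_number/1 is used to read the scalar results back into plain Elixir numbers.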

Additionally, despite these small concessions, defn provides huge benefits. Code inside a defn block uses tensor-aware operators and types, so the math beneath your functions has a better chance to shine through. Numerical definitions can also run on accelerated numerical processors like GPUs and TPUs. Here's an example numerical definition:

defmodule TensorMath do
  import Nx.Defn

  defn subtract(a, b) do
    a - b
  end
end

This module has a numerical definition that will be compiled. To choose a compiler explicitly, we can pass the :compiler option to Nx.Defn.compile/3, which takes the function, templates describing its arguments, and a list of options. One such compiler is EXLA. After adding the :exla dependency to your project, you could write:

Nx.Defn.compile(&TensorMath.subtract/2, [Nx.template({3}, :f32), Nx.template({3}, :s32)], compiler: EXLA)

For a global approach, you can also set config :nx, default_defn_options: [compiler: EXLA] in your application environment and call the TensorMath.subtract/2 function directly.
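The two calling styles can be sketched side by side. The example below assumes only the :nx dependency, so Nx.Defn.Evaluator (the compiler Nx uses by default) stands in for EXLA, and TensorMathSketch is a hypothetical copy of the TensorMath module so the cell runs on its own.

```elixir
Mix.install([{:nx, "~> 0.9"}])

# Hypothetical stand-alone copy of TensorMath from this guide.
defmodule TensorMathSketch do
  import Nx.Defn

  defn subtract(a, b) do
    a - b
  end
end

# Inputs whose shapes and types match the templates passed to compile/3 below.
a = Nx.tensor([10.0, 20.0, 30.0], type: :f32)
b = Nx.tensor([1, 2, 3], type: :s32)

# Direct call: uses whichever compiler is configured globally.
direct = TensorMathSketch.subtract(a, b)

# Explicit compilation: Nx.Defn.Evaluator is a stand-in for EXLA here.
compiled =
  Nx.Defn.compile(
    &TensorMathSketch.subtract/2,
    [Nx.template({3}, :f32), Nx.template({3}, :s32)],
    compiler: Nx.Defn.Evaluator
  )

via_compile = compiled.(a, b)

IO.inspect(Nx.to_list(direct))
IO.inspect(Nx.to_list(via_compile))
```

The compiled function only accepts arguments that match the templates, which is why the tensors above are created with explicit types.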

As an exercise, you can try adding a defn to TensorMath that accepts two tensors representing the lengths of the legs of a right triangle and uses the Pythagorean theorem to return the length of the hypotenuse. Add your function directly to the TensorMath module defined above.
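If you want to compare notes afterwards, one possible solution is sketched below. It lives in a hypothetical module named HypotenuseSketch rather than in TensorMath, and it is only one of several reasonable ways to write the formula.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule HypotenuseSketch do
  import Nx.Defn

  # c = sqrt(a^2 + b^2), applied element-wise when a and b are tensors.
  defn hypotenuse(a, b) do
    Nx.sqrt(a * a + b * b)
  end
end

# Two triangles at once: legs (3, 4) and (5, 12).
lengths = HypotenuseSketch.hypotenuse(Nx.tensor([3, 5]), Nx.tensor([4, 12]))

IO.inspect(Nx.to_list(lengths))
```

Because the operators are tensor-aware, the same definition handles a single triangle or a whole batch of them.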

deftransform

The defn macro in Nx allows you to define functions that compile to efficient numerical computations, but it comes with certain limitations, such as restrictions on argument types, return values, and the subset of Elixir that it supports. To overcome many of these limitations, Nx offers the deftransform macro.

deftransform lets you perform computations or execute code that isn't directly supported by defn, and then incorporate those results back into your numerical function. This separation lets you use standard Elixir features where necessary while keeping your core numerical logic optimized.

It is important to highlight that all code inside a deftransform is invoked during the Nx compilation step. It can inspect static properties such as shapes and types, but it can't depend on the runtime values of the tensor inputs.
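A small sketch of what compile-time code can do: the transform below checks the rank of its input, which is known while the defn is being traced, and never looks at the tensor's values. The module CompileTimeSketch and its function names are hypothetical.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule CompileTimeSketch do
  import Nx.Defn

  # Divides every row by its sum, after a compile-time rank check.
  defn normalize_rows(t) do
    t = check_rank!(t)
    t / Nx.sum(t, axes: [-1], keep_axes: true)
  end

  # Plain Elixir that runs while the defn is traced. The shape is available
  # at that point, but the values inside the tensor are not.
  deftransformp check_rank!(t) do
    if Nx.rank(t) != 2 do
      raise ArgumentError, "expected a rank-2 tensor, got shape: #{inspect(Nx.shape(t))}"
    end

    t
  end
end

normalized = CompileTimeSketch.normalize_rows(Nx.tensor([[1, 3], [2, 2]]))
IO.inspect(Nx.to_list(normalized))

# A rank-1 input is rejected by the transform before any computation runs.
error_message =
  try do
    CompileTimeSketch.normalize_rows(Nx.tensor([1, 2, 3]))
    nil
  rescue
    e in ArgumentError -> Exception.message(e)
  end

IO.inspect(error_message)
```

Checks like this work because rank and shape are part of the tensor's static description. A check on the values themselves would have to be expressed with Nx operations inside defn instead.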

In the following example, we define a deftransform function called compute_from_2d_list/1 that receives a list of lists, which is not allowed inside defn. Inside this function, we check that all rows have the same length, convert the list to a tensor using Nx.tensor/1, add a leading axis with Nx.new_axis/2, and then pass the result to a defn function called scale_tensor/1, which performs the actual numerical computation.

defmodule MyMath do
  import Nx.Defn

  # Numerical function that just multiplies the tensor by a scalar
  defn scale_tensor(tensor) do
    Nx.multiply(tensor, 10)
  end

  # This transform receives a 2D list, validates it, converts it to a tensor,
  # adds a new axis, and then passes it to a numerical function.
  deftransform compute_from_2d_list(list_2d) do
    # Validate that it's a proper matrix (all rows same length)
    lengths = Enum.map(list_2d, &length/1)
    if Enum.uniq(lengths) != [hd(lengths)] do
      raise ArgumentError, "All inner lists must have the same length"
    end

    # Convert to tensor (e.g., shape {2, 3})
    tensor = Nx.tensor(list_2d)

    # Add a new axis at the beginning: {2, 3} -> {1, 2, 3}
    reshaped = Nx.new_axis(tensor, 0)

    # Pass to defn function
    scale_tensor(reshaped)
  end
end

matrix = [
  [1, 2, 3],
  [4, 5, 6]
]

result = MyMath.compute_from_2d_list(matrix)

This setup allows us to keep our defn code clean and focused only on tensor operations, while using deftransform to handle Elixir-native types and preprocessing.

deftransforms can also be called from within defn functions, which can be very useful for shape manipulation and validation:

defmodule MyOtherMath do
  import Nx.Defn

  # Numerical function that optionally appends new axes to the tensor
  # and then multiplies it by a scalar
  defn reshape_and_scale_tensor(tensor, opts \\ []) do
    tensor
    |> validate_and_reshape(opts[:new_axes_count])
    |> Nx.multiply(10)
  end

  deftransformp validate_and_reshape(tensor, new_axes_count) do
    cond do
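      is_nil(new_axes_count) ->
        tensor

      is_integer(new_axes_count) ->
        # Nx.shape/1 returns a tuple, so convert it to a list before appending axes
        dims = Tuple.to_list(Nx.shape(tensor)) ++ List.duplicate(1, new_axes_count)
        Nx.reshape(tensor, List.to_tuple(dims))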

      true ->
        raise ArgumentError, "expected :new_axes_count to be an integer or nil, got: #{inspect(new_axes_count)}"
    end
  end
end
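To see how options flow from the caller through defn into a transform, here is a trimmed, stand-alone variant of the same idea. The module AxesSketch and its function names are hypothetical. The option is read with opts[:new_axes_count], and the transform appends that many axes of size 1.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule AxesSketch do
  import Nx.Defn

  defn scale(tensor, opts \\ []) do
    tensor
    |> append_axes(opts[:new_axes_count])
    |> Nx.multiply(10)
  end

  # Runs at compile time, so `count` is a plain Elixir value here.
  deftransformp append_axes(tensor, count) do
    cond do
      is_nil(count) ->
        tensor

      is_integer(count) ->
        dims = Tuple.to_list(Nx.shape(tensor)) ++ List.duplicate(1, count)
        Nx.reshape(tensor, List.to_tuple(dims))

      true ->
        raise ArgumentError, "expected an integer or nil, got: #{inspect(count)}"
    end
  end
end

t = Nx.tensor([1, 2, 3])

plain = AxesSketch.scale(t)
expanded = AxesSketch.scale(t, new_axes_count: 2)

IO.inspect(Nx.shape(plain))
IO.inspect(Nx.shape(expanded))
IO.inspect(Nx.to_list(plain))
```

Without the option the shape is left alone, and with new_axes_count: 2 the shape {3} becomes {3, 1, 1}. The values are multiplied by 10 in both cases.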