Numerical Definitions (defn)
The defn macro and its siblings simplify the expression of mathematical formulas
containing tensors. Numerical definitions have two primary benefits
over classic Elixir functions.
- They are tensor-aware. Nx replaces operators like Kernel.-/2 with the Defn counterpart, imported from Nx.Defn.Kernel, which uses Nx functions optimized for tensors when the operands are tensors, so the formulas we express can use tensors out of the box.
- defn definitions allow for building computation graphs that combine the individual operations, which can then be used by just-in-time (JIT) compilers to emit highly specialized native code for the desired computation unit.
We don't have to do anything special to get tensor awareness
beyond importing Nx.Defn and writing
our code within a defn block.
To use Nx in a Mix project or a notebook, we need to include
the :nx dependency and import the Nx.Defn module,
like this:
Mix.install([
  {:nx, "~> 0.9"}
])

import Nx.Defn

Just as the Elixir language supports def, defmacro, and defp,
Nx supports defn. There are a few restrictions. It allows only
numerical values, in the form of primitives or tensors, as arguments
or return values, and supports only a subset of the language.
The subset of Elixir allowed within defn is quite broad, though. We can
use macros, pipes, and even conditionals, so we're not giving up
much when we're declaring mathematical functions.
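As a rough illustration of that subset, the sketch below combines a pipe and a conditional inside one defn. The module and function names (SubsetDemo, positive_sum/2) are hypothetical and chosen only for this example; it assumes the :nx dependency and the default evaluator.

```elixir
Mix.install([
  {:nx, "~> 0.9"}
])

defmodule SubsetDemo do
  import Nx.Defn

  # Hypothetical helper: sums a + b and returns the total,
  # or 0 when the total is not positive.
  defn positive_sum(a, b) do
    total =
      a
      |> Nx.add(b)
      |> Nx.sum()

    # Inside defn, `if` works on a scalar tensor predicate.
    if total > 0 do
      total
    else
      0
    end
  end
end

positive = SubsetDemo.positive_sum(Nx.tensor([1, 2]), Nx.tensor([3, 4]))
clamped = SubsetDemo.positive_sum(Nx.tensor([1, -5]), Nx.tensor([1, 1]))

IO.inspect(Nx.to_number(positive))
IO.inspect(Nx.to_number(clamped))
```

Both branches of the conditional return scalars, which keeps the result shape the same regardless of which branch is taken.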
Additionally, despite these small concessions, defn provides huge benefits.
Code inside a defn block uses tensor-aware operators and types, so the math
beneath your functions has a better chance to shine through. Numerical
definitions can also run on accelerated numerical processors like GPUs and
TPUs. Here's an example numerical definition:
defmodule TensorMath do
  import Nx.Defn
  defn subtract(a, b) do
    a - b
  end
end

This module has a numerical definition that will be compiled.
If we wanted to specify a compiler for this function, we could pass
the :compiler option when compiling it. One such compiler
is the EXLA compiler.
You would add the Mix dependency for EXLA and do this:

Nx.Defn.compile(&TensorMath.subtract/2, [Nx.template({3}, :f32), Nx.template({3}, :s32)], compiler: EXLA)

For a global approach, you can also set config :nx, default_defn_options: [compiler: EXLA]
in your application environment and call the TensorMath.subtract/2 function directly.
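For readers without EXLA installed, here is a minimal sketch of calling a numerical definition with only the :nx dependency. It assumes the default Nx.Defn.Evaluator compiler and uses a hypothetical module name, TensorMathDemo, so it does not clash with TensorMath above.

```elixir
Mix.install([
  {:nx, "~> 0.9"}
])

defmodule TensorMathDemo do
  import Nx.Defn

  # Same body as TensorMath.subtract/2; the module name is illustrative only.
  defn subtract(a, b) do
    a - b
  end
end

a = Nx.tensor([3.0, 2.0, 1.0])
b = Nx.tensor([1, 1, 1])

# Direct call: uses whatever default defn options are configured.
direct = TensorMathDemo.subtract(a, b)

# Explicit wrapper: Nx.Defn.Evaluator is the pure-Elixir compiler shipped with Nx.
jitted = Nx.Defn.jit(&TensorMathDemo.subtract/2, compiler: Nx.Defn.Evaluator)
via_jit = jitted.(a, b)

IO.inspect(Nx.to_list(direct))
IO.inspect(Nx.to_list(via_jit))
```

Swapping Nx.Defn.Evaluator for another compiler such as EXLA should only require changing the :compiler option, provided that dependency is available.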
As an exercise, you can try adding a defn to TensorMath
that accepts two tensors representing the lengths of sides of a
right triangle and uses the Pythagorean theorem to return the
length of the hypotenuse.
Add your function directly to the previous Code cell.
deftransform
The defn macro in Nx allows you to define functions that compile to efficient
numerical computations, but it comes with certain limitations — such as restrictions
on argument types, return values, and the subset of Elixir that it supports.
To overcome many of these limitations, Nx offers the deftransform macro.
deftransform lets you perform computations or execute code that isn't directly
supported by defn, and then incorporate those results back into your numerical
function. This separation lets you use standard Elixir features where necessary
while keeping your core numerical logic optimized.
It is important to highlight that all code inside the deftransform call is being invoked
during the Nx compilation step, and thus can't depend on runtime values of the tensor inputs.
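To make the compile-time point concrete, the following sketch (with hypothetical names ShapeDemo, divide_by_size/1, and element_count/1) calls a transform that reads only the tensor's shape. The shape is known while the defn is being traced, even though the element values are not.

```elixir
Mix.install([
  {:nx, "~> 0.9"}
])

defmodule ShapeDemo do
  import Nx.Defn

  # Divides every element by the number of elements in the tensor.
  defn divide_by_size(tensor) do
    tensor / element_count(tensor)
  end

  # Runs while the defn is being traced: element values are not available
  # here, but the shape is, so Nx.size/1 returns a plain Elixir integer.
  deftransformp element_count(tensor) do
    Nx.size(tensor)
  end
end

result = ShapeDemo.divide_by_size(Nx.tensor([2.0, 4.0]))
IO.inspect(Nx.to_list(result))
```

Branching on an actual element value inside the transform would not work the same way, because those values only exist once the compiled computation runs.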
In the following example, we define a deftransform function called
compute_from_2d_list/1 that receives a list of lists, which is not allowed
inside defn. Inside this function, we validate the list, convert it to a tensor
using Nx.tensor/1, add a leading axis, and then pass it to a defn function called scale_tensor/1,
which performs the actual numerical computation.
defmodule MyMath do
  import Nx.Defn
  # Numerical function that just multiplies the tensor by a scalar
  defn scale_tensor(tensor) do
    Nx.multiply(tensor, 10)
  end
  # This transform receives a 2D list, validates it, converts it to a tensor,
  # adds a new axis, and then passes it to a numerical function.
  deftransform compute_from_2d_list(list_2d) do
    # Validate that it's a proper matrix (all rows same length)
    lengths = Enum.map(list_2d, &length/1)
    if Enum.uniq(lengths) != [hd(lengths)] do
      raise ArgumentError, "All inner lists must have the same length"
    end
    # Convert to tensor (e.g., shape {2, 3})
    tensor = Nx.tensor(list_2d)
    # Add a new axis at the beginning: {2, 3} -> {1, 2, 3}
    reshaped = Nx.new_axis(tensor, 0)
    # Pass to defn function
    scale_tensor(reshaped)
  end
end

matrix = [
  [1, 2, 3],
  [4, 5, 6]
]
result = MyMath.compute_from_2d_list(matrix)

This setup allows us to keep our defn code clean and focused only on tensor
operations, while using deftransform to handle Elixir-native types and
preprocessing.
deftransforms can also be called from within defn functions, which
can be very useful for shape manipulation and validation:
defmodule MyOtherMath do
  import Nx.Defn
  # Numerical function that optionally appends new axes and then multiplies the tensor by a scalar
  defn reshape_and_scale_tensor(tensor, opts \\ []) do
    tensor
    |> validate_and_reshape(opts[:new_axes_count])
    |> Nx.multiply(10)
  end
  deftransformp validate_and_reshape(tensor, new_axes_count) do
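    cond do
      is_nil(new_axes_count) ->
        tensor
      is_integer(new_axes_count) ->
        # Nx.shape/1 returns a tuple, so convert it to a list before appending
        shape = List.to_tuple(Tuple.to_list(Nx.shape(tensor)) ++ List.duplicate(1, new_axes_count))
        Nx.reshape(tensor, shape)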
      true ->
        raise ArgumentError, "expected :new_axes_count to be an integer or nil, got: #{inspect(new_axes_count)}"
    end
  end
end