Numerical Definitions (defn)
The `defn` macro and its siblings simplify the expression of mathematical formulas containing tensors. Numerical definitions have two primary benefits over classic Elixir functions:

- They are tensor-aware. Nx replaces operators like `Kernel.-/2` with their `defn` counterparts, imported from `Nx.Defn.Kernel`, which use `Nx` functions optimized for tensors, so the formulas we express work with tensors out of the box.
- `defn` definitions build computation graphs that combine the individual operations. Just-in-time (JIT) compilers can then use these graphs to emit highly specialized native code for the desired computation unit.
We don't have to do anything special to get tensor awareness beyond importing `Nx.Defn` and writing our code within a `defn` block.
In a Mix project, add `{:nx, "~> 0.9"}` to the `deps` in `mix.exs`. In a notebook or a script, we install the `:nx` dependency with `Mix.install/1` and import the `Nx.Defn` module, like this:
```elixir
Mix.install([
  {:nx, "~> 0.9"}
])

import Nx.Defn
```
Just as the Elixir language supports `def`, `defmacro`, and `defp`, Nx supports `defn`. It comes with a few restrictions: arguments and return values must be numbers, tensors, or containers of them such as tuples and maps (keyword lists can also be passed as options), and only a subset of the language is supported. The subset of Elixir allowed within `defn` is quite broad, though. We can use macros, pipes, and even conditionals, so we're not giving up much when declaring mathematical functions.
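To illustrate that subset, here is a small sketch using a pipe and an `if` inside `defn`. The module and function names (`SubsetDemo`, `relu_then_scale/2`, `abs_of_sum/1`) are hypothetical and only serve this example. Because the `if` condition is a tensor, both branches are compiled and the branch is chosen from the tensor's value when the function runs.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule SubsetDemo do
  import Nx.Defn

  # Pipes work as in regular Elixir: clamp negatives to 0, then scale
  defn relu_then_scale(t, factor) do
    t
    |> Nx.max(0)
    |> Nx.multiply(factor)
  end

  # `total < 0` is a scalar tensor, so this `if` is compiled into the graph
  # and the branch is selected at runtime from the tensor's value
  defn abs_of_sum(t) do
    total = Nx.sum(t)

    if total < 0 do
      -total
    else
      total
    end
  end
end

SubsetDemo.relu_then_scale(Nx.tensor([-1, 2]), 3)
SubsetDemo.abs_of_sum(Nx.tensor([-3, 1]))
```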
Despite these small concessions, `defn` provides big benefits. Code inside a `defn` block uses tensor-aware operators and types, so the math beneath your functions has a better chance to shine through. Numerical definitions can also run on accelerated numerical processors like GPUs and TPUs, given a compiler that targets them. Here's an example numerical definition:
```elixir
defmodule TensorMath do
  import Nx.Defn

  defn subtract(a, b) do
    a - b
  end
end
```
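Calling a numerical definition works like calling any other function. The sketch below uses a hypothetical module, `TensorMathDemo`, with the same body as `TensorMath.subtract/2` so it can run on its own. It shows that `-` accepts tensors as well as plain numbers, which are treated as scalar tensors and broadcast.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule TensorMathDemo do
  import Nx.Defn

  # Inside defn, `-` is Nx.Defn.Kernel.-/2, which subtracts tensors element-wise
  defn subtract(a, b), do: a - b
end

a = Nx.tensor([5, 6, 7])
b = Nx.tensor([1, 2, 3])

# Element-wise subtraction of two tensors with the same shape
TensorMathDemo.subtract(a, b)

# A plain number becomes a scalar tensor and is broadcast against `a`
TensorMathDemo.subtract(a, 1)
```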
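This module has a numerical definition that will be compiled. By default, calls to it are evaluated by `Nx.Defn.Evaluator`, which ships with Nx. To choose a different compiler, we pass it as an option when compiling the function. One such compiler is EXLA: add the `:exla` dependency and then compile `TensorMath.subtract/2` ahead of time for a given set of input templates:

```elixir
# Compile subtract/2 for a {3}-shaped f32 tensor and a {3}-shaped s32 tensor
compiled =
  Nx.Defn.compile(
    &TensorMath.subtract/2,
    [Nx.template({3}, :f32), Nx.template({3}, :s32)],
    compiler: EXLA
  )

# The compiled function must be called with inputs matching those templates
compiled.(Nx.tensor([1.0, 2.0, 3.0], type: :f32), Nx.tensor([1, 1, 1], type: :s32))
```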
For a global approach, you can also set `config :nx, default_defn_options: [compiler: EXLA]` in your application environment and then call `TensorMath.subtract/2` directly; it will be compiled with EXLA.
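If EXLA isn't installed yet, the same mechanics can be tried with the default compiler. The sketch below uses `Nx.Defn.jit/2`, which returns a function that is compiled when it is called rather than ahead of time. `JitDemo` is a hypothetical module for this example; in a notebook, `Nx.Defn.global_default_options(compiler: EXLA)` is the runtime counterpart of the config setting above.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule JitDemo do
  import Nx.Defn

  defn subtract(a, b), do: a - b
end

# Nx.Defn.jit/2 wraps the function instead of compiling it right away.
# Swap Nx.Defn.Evaluator for EXLA once the :exla dependency is installed.
jitted = Nx.Defn.jit(&JitDemo.subtract/2, compiler: Nx.Defn.Evaluator)

# Compilation is specialized to the input shapes and types, so these
# two calls (a {3} pair and a {2, 2} tensor with a scalar) compile separately
jitted.(Nx.tensor([1.0, 2.0, 3.0]), Nx.tensor([1, 1, 1]))
jitted.(Nx.tensor([[10, 20], [30, 40]]), 5)
```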
As an exercise, try adding a `defn` to `TensorMath` that accepts two tensors representing the lengths of the legs of a right triangle and uses the Pythagorean theorem to return the length of the hypotenuse. Add your function directly to the previous code cell.
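If you want to check your answer, here is one possible solution. It is written as a separate, hypothetical `TensorMathSolution` module so it runs on its own; in the exercise you would add `hypotenuse/2` to `TensorMath` instead.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule TensorMathSolution do
  import Nx.Defn

  # Inside defn, `*` and `+` are the tensor-aware versions from Nx.Defn.Kernel,
  # so this works element-wise on tensors of broadcastable shapes
  defn hypotenuse(a, b) do
    Nx.sqrt(a * a + b * b)
  end
end

# Two triangles at once: legs (3, 4) and (5, 12)
TensorMathSolution.hypotenuse(Nx.tensor([3, 5]), Nx.tensor([4, 12]))
```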
deftransform
The `defn` macro lets you define functions that compile to efficient numerical computations, but it comes with certain limitations, such as restrictions on argument types and return values, and support for only a subset of Elixir. To overcome many of these limitations, Nx offers the `deftransform` macro.

`deftransform` lets you perform computations or execute code that `defn` doesn't directly support, and then incorporate those results back into your numerical function. This separation lets you use standard Elixir features where necessary while keeping your core numerical logic optimized.

It is important to highlight that when a `deftransform` is called from within `defn`, all of its code is invoked during the Nx compilation step, so it can't depend on the runtime values of the tensor inputs. It can, however, inspect compile-time information such as their shapes and types.
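A small sketch of what is available at compilation time: inside the hypothetical `count_elements/1` transform below, the argument is a symbolic tensor whose shape is known, so `Nx.size/1` returns a plain integer that becomes a constant in the compiled computation. Trying to read the tensor's actual values there (for example with `Nx.to_number/1`) would raise an error, since those values only exist when the function runs.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule TransformTiming do
  import Nx.Defn

  # Divides the sum by an element count computed during compilation
  defn mean_of_all(t) do
    Nx.sum(t) / count_elements(t)
  end

  # Runs while mean_of_all/1 is being compiled: `t` is symbolic here,
  # so its shape is known but its values are not
  deftransformp count_elements(t) do
    Nx.size(t)
  end
end

TransformTiming.mean_of_all(Nx.tensor([[1, 2], [3, 4]]))
```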
In the following example, we define a `deftransform` function called `compute_from_2d_list/1` that receives a nested list, which isn't allowed as a `defn` argument. Inside this function, we check that all rows have the same length, convert the list to a tensor using `Nx.tensor/1`, add a leading axis with `Nx.new_axis`, and then pass the result to a `defn` function called `scale_tensor/1`, which performs the actual numerical computation.
```elixir
defmodule MyMath do
  import Nx.Defn

  # Numerical function that just multiplies the tensor by a scalar
  defn scale_tensor(tensor) do
    Nx.multiply(tensor, 10)
  end

  # This transform receives a 2D list, validates it, reshapes it,
  # adds a new axis, and then passes it to a numerical function.
  deftransform compute_from_2d_list(list_2d) do
    # Validate that it's a proper matrix (all rows same length)
    lengths = Enum.map(list_2d, &length/1)

    if Enum.uniq(lengths) != [hd(lengths)] do
      raise ArgumentError, "All inner lists must have the same length"
    end

    # Convert to tensor (e.g., shape {2, 3})
    tensor = Nx.tensor(list_2d)

    # Add a new axis at the beginning: {2, 3} -> {1, 2, 3}
    reshaped = Nx.new_axis(tensor, 0)

    # Pass to defn function
    scale_tensor(reshaped)
  end
end

matrix = [
  [1, 2, 3],
  [4, 5, 6]
]

result = MyMath.compute_from_2d_list(matrix)
```
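To see what `result` holds, the sketch below runs the same steps as `compute_from_2d_list/1` one at a time with plain `Nx` calls, outside of any module.

```elixir
Mix.install([{:nx, "~> 0.9"}])

matrix = [[1, 2, 3], [4, 5, 6]]

# Step 1: the nested list becomes a tensor of shape {2, 3}
tensor = Nx.tensor(matrix)

# Step 2: a leading axis of size 1 gives shape {1, 2, 3}
reshaped = Nx.new_axis(tensor, 0)

# Step 3: scale_tensor/1 multiplies every element by 10
scaled = Nx.multiply(reshaped, 10)

Nx.shape(scaled)
Nx.to_flat_list(scaled)
```

Passing a ragged list such as `[[1, 2], [3]]` to `MyMath.compute_from_2d_list/1` instead raises the `ArgumentError` from its validation step, before any tensor is built.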
This setup lets us keep our `defn` code clean and focused only on tensor operations, while using `deftransform` to handle Elixir-native types and preprocessing.
Transforms can also be called from within `defn` functions, which is very useful for shape manipulation and validation. The example below uses `deftransformp`, the private counterpart of `deftransform`:
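```elixir
defmodule MyOtherMath do
  import Nx.Defn

  # Numerical function that optionally appends size-1 axes to the tensor,
  # then multiplies it by a scalar
  defn reshape_and_scale_tensor(tensor, opts \\ []) do
    opts = keyword!(opts, new_axes_count: nil)

    tensor
    |> validate_and_reshape(opts[:new_axes_count])
    |> Nx.multiply(10)
  end

  # Runs at compilation time, where the tensor's shape is already known
  deftransformp validate_and_reshape(tensor, new_axes_count) do
    cond do
      is_nil(new_axes_count) ->
        tensor

      is_integer(new_axes_count) and new_axes_count >= 0 ->
        # Nx.shape/1 returns a tuple, so convert it to a list before appending
        shape =
          List.to_tuple(Tuple.to_list(Nx.shape(tensor)) ++ List.duplicate(1, new_axes_count))

        Nx.reshape(tensor, shape)

      true ->
        raise ArgumentError,
              "expected :new_axes_count to be a non-negative integer or nil, got: #{inspect(new_axes_count)}"
    end
  end
end
```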
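For reference, the shape arithmetic that `validate_and_reshape/2` is meant to perform can be tried on a concrete tensor outside `defn`. Note that `Nx.shape/1` returns a tuple, so it has to be turned into a list before `++` can append the size-1 axes. This sketch assumes `new_axes_count` is 2 and uses `Nx.iota/1` only to get some sample values.

```elixir
Mix.install([{:nx, "~> 0.9"}])

# Sample input: [[0, 1, 2], [3, 4, 5]] with shape {2, 3}
tensor = Nx.iota({2, 3})
new_axes_count = 2

# Append `new_axes_count` trailing axes of size 1: {2, 3} -> {2, 3, 1, 1}
shape = List.to_tuple(Tuple.to_list(Nx.shape(tensor)) ++ List.duplicate(1, new_axes_count))

# Reshaping only changes the shape; the element order is unchanged
scaled = tensor |> Nx.reshape(shape) |> Nx.multiply(10)

Nx.shape(scaled)
```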