# Numerical Definitions (defn)

## Section

The `defn` macro and its siblings simplify the expression of mathematical formulas
containing tensors. Numerical definitions have two primary benefits
over classic Elixir functions.

- They are _tensor-aware_. Inside `defn`, Nx replaces operators like `Kernel.-/2`
  with their counterparts from `Nx.Defn.Kernel`, which use `Nx` functions
  optimized for tensors when the operands are tensors, so the formulas we express
  can use tensors out of the box.

- `defn` definitions allow for building computation graphs that combine the
  individual operations, which can then be used by just-in-time (JIT) compilers
  to emit highly specialized native code for the desired computation unit.

We don't have to do anything special to get tensor awareness
beyond importing `Nx.Defn` and writing our code within a `defn` block.

To use Nx in a Mix project or a notebook, we need to include
the `:nx` dependency and import the `Nx.Defn` module,
like this:

```elixir
Mix.install([
  {:nx, "~> 0.9"}
])
```

```elixir
import Nx.Defn
```
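With Nx installed, the tensor awareness described above can be checked directly. The sketch below is illustrative only: `OperatorSketch` and `diff/2` are hypothetical names, and it relies on `Kernel.-/2` raising an `ArithmeticError` when given two tensors outside `defn`, since tensors are structs rather than numbers.

```elixir
defmodule OperatorSketch do
  import Nx.Defn

  # Inside defn, `-` comes from Nx.Defn.Kernel and works on tensors.
  defn diff(a, b), do: a - b
end

a = Nx.tensor([5, 7, 9])
b = Nx.tensor([1, 2, 3])

# Outside defn, `-` is Kernel.-/2, which only understands plain numbers.
outside =
  try do
    a - b
  rescue
    e in ArithmeticError -> {:error, e}
  end

# The same expression inside defn performs an element-wise subtraction.
inside = OperatorSketch.diff(a, b)
```

The first attempt ends up as an `{:error, _}` tuple, while the `defn` version returns a tensor with the element-wise difference.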

Just as the Elixir language supports `def`, `defmacro`, and `defp`,
Nx supports `defn`. There are a few restrictions: arguments and return values
must be numerical, in the form of primitives or tensors, and only a subset
of the language is supported.

The subset of Elixir allowed within `defn` is quite broad, though. We can
use macros, pipes, and even conditionals, so we're not giving up
much when declaring mathematical functions.
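As a rough illustration of that subset, the sketch below uses a pipe and a conditional inside `defn`. The module and function names are hypothetical. Note that `if` inside `defn` expects a scalar predicate, which is why the second function reduces the tensor with `Nx.sum/1` first.

```elixir
defmodule SubsetSketch do
  import Nx.Defn

  # Pipes work as in regular Elixir code.
  defn scale_and_shift(x) do
    x
    |> Nx.multiply(2)
    |> Nx.add(1)
  end

  # `if` works on a scalar tensor predicate; both branches
  # must return tensors with compatible shapes and types.
  defn abs_of_sum(x) do
    total = Nx.sum(x)

    if total < 0 do
      -total
    else
      total
    end
  end
end

shifted = SubsetSketch.scale_and_shift(Nx.tensor([1, 2, 3]))
magnitude = SubsetSketch.abs_of_sum(Nx.tensor([1, -5, 2]))
```

Here `shifted` holds each element doubled and incremented by one, and `magnitude` is the absolute value of the sum.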

In return for these small concessions, `defn` provides huge benefits.
Code inside a `defn` block uses tensor-aware operators and types, so the math
beneath your functions has a better chance to shine through. Numerical
definitions can also run on accelerated numerical processors like GPUs and
TPUs. Here's an example numerical definition:

```elixir
defmodule TensorMath do
  import Nx.Defn

  defn subtract(a, b) do
    a - b
  end
end
```

This module has a numerical definition that will be compiled.
If we wanted to specify a compiler for it, we could pass the
`:compiler` option when compiling the function. One such compiler
is [the EXLA compiler](https://github.com/elixir-nx/nx/tree/main/exla).
You would add the `mix` dependency for EXLA and do this:

<!-- livebook:{"force_markdown":true} -->

```elixir
Nx.Defn.compile(&TensorMath.subtract/2, [Nx.template({3}, :f32), Nx.template({3}, :s32)], compiler: EXLA)
```
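`Nx.Defn.compile/3` returns an anonymous function that has to be called with inputs matching the given templates. The sketch below shows one way this might look. It swaps EXLA for `Nx.Defn.Evaluator`, the compiler that ships with Nx, so that it runs without extra dependencies, and `CompileSketch` is a hypothetical stand-in for `TensorMath`.

<!-- livebook:{"force_markdown":true} -->

```elixir
defmodule CompileSketch do
  import Nx.Defn

  defn subtract(a, b), do: a - b
end

# Templates carry only a shape and a type; no data is attached.
templates = [Nx.template({3}, :f32), Nx.template({3}, :s32)]

# Assumption: Nx.Defn.Evaluator is used here in place of EXLA.
compiled =
  Nx.Defn.compile(&CompileSketch.subtract/2, templates, compiler: Nx.Defn.Evaluator)

# The inputs must have the same shapes and types as the templates.
a = Nx.tensor([1.0, 2.0, 3.0], type: :f32)
b = Nx.tensor([1, 1, 1], type: :s32)

result = compiled.(a, b)
```

Compiling ahead of time like this fixes the input shapes and types, so the returned function can be reused for any data that matches the templates.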

For a global approach, you can also set `config :nx, default_defn_options: [compiler: EXLA]`
in your application environment and call the `TensorMath.subtract/2` function directly.

As an exercise, you can try adding a `defn` to `TensorMath`
that accepts two tensors representing the lengths of the sides of a
right triangle and uses the Pythagorean theorem to return the
[length of the hypotenuse](https://www.mathsisfun.com/pythagoras.html).
Add your function directly to the Code cell above that defines `TensorMath`.

## deftransform

The `defn` macro in Nx allows you to define functions that compile to efficient
numerical computations, but it comes with certain limitations — such as restrictions
on argument types, return values, and the subset of Elixir that it supports.
To overcome many of these limitations, Nx offers the `deftransform` macro.

`deftransform` lets you perform computations or execute code that isn't directly
supported by `defn`, and then incorporate those results back into your numerical
function. This separation lets you use standard Elixir features where necessary
while keeping your core numerical logic optimized.

It is important to highlight that when a `deftransform` is called from `defn`, all of its
code runs during the Nx compilation step, and thus can't depend on runtime values of the tensor inputs.
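To make that compile-time behaviour concrete, here is a small sketch with hypothetical names (`TraceTimeSketch`, `divide_by_size/1`, `element_count/1`). The transform never sees the actual numbers in the tensor, only static metadata such as its shape, and it returns a plain integer that is baked into the computation.

```elixir
defmodule TraceTimeSketch do
  import Nx.Defn

  # Divides every element by the number of elements in the tensor.
  defn divide_by_size(t) do
    t / element_count(t)
  end

  # Called from defn, so `t` is a symbolic expression here, not concrete data.
  # Static metadata such as the shape (and therefore the size) is still known.
  deftransformp element_count(t) do
    Nx.size(t)
  end
end

result = TraceTimeSketch.divide_by_size(Nx.tensor([2, 4, 6, 8]))
```

Anything that needs the runtime values themselves, such as branching on the largest element, has to stay in `defn` and be expressed with tensor operations.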

In the following example, we define a `deftransform` function called
`compute_from_2d_list/1` that receives a nested list, which is not allowed
inside `defn`. Inside this function, we validate the list, convert it to a tensor
using `Nx.tensor/1`, add a leading axis, and then pass it to a `defn` function called
`scale_tensor/1`, which performs the actual numerical computation.

```elixir
defmodule MyMath do
  import Nx.Defn

  # Numerical function that just multiplies the tensor by a scalar
  defn scale_tensor(tensor) do
    Nx.multiply(tensor, 10)
  end

  # This transform receives a 2D list, validates it, converts it to a tensor,
  # adds a new axis, and then passes it to a numerical function.
  deftransform compute_from_2d_list(list_2d) do
    # Validate that it's a proper matrix (all rows same length)
    lengths = Enum.map(list_2d, &length/1)
    if Enum.uniq(lengths) != [hd(lengths)] do
      raise ArgumentError, "All inner lists must have the same length"
    end

    # Convert to tensor (e.g., shape {2, 3})
    tensor = Nx.tensor(list_2d)

    # Add a new axis at the beginning: {2, 3} -> {1, 2, 3}
    reshaped = Nx.new_axis(tensor, 0)

    # Pass to defn function
    scale_tensor(reshaped)
  end
end
```

```elixir
matrix = [
  [1, 2, 3],
  [4, 5, 6]
]

result = MyMath.compute_from_2d_list(matrix)
```

This setup allows us to keep our `defn` code clean and focused only on tensor
operations, while using `deftransform` to handle Elixir-native types and
preprocessing.

`deftransform`s can also be called from within `defn` functions, which
can be very useful for shape manipulation and validation:

```elixir
defmodule MyOtherMath do
  import Nx.Defn

  # Numerical function that optionally appends new axes and then scales the tensor
  defn reshape_and_scale_tensor(tensor, opts \\ []) do
    tensor
    |> validate_and_reshape(opts[:new_axes_count])
    |> Nx.multiply(10)
  end

  deftransformp validate_and_reshape(tensor, new_axes_count) do
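    cond do
      is_nil(new_axes_count) ->
        tensor

      is_integer(new_axes_count) ->
        # Nx.shape/1 returns a tuple, so convert it to a list before appending
        shape = List.to_tuple(Tuple.to_list(Nx.shape(tensor)) ++ List.duplicate(1, new_axes_count))
        Nx.reshape(tensor, shape)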

      true ->
        raise ArgumentError, "expected :new_axes_count to be an integer or nil, got: #{inspect(new_axes_count)}"
    end
  end
end
```
