# `Nx.Defn`
[🔗](https://github.com/elixir-nx/nx/blob/v0.12.0/nx/lib/nx/defn.ex#L1)

Numerical functions.

Numerical functions are written in a subset of Elixir tailored for
numerical computations. For example, the following function:

    defmodule MyModule do
      import Nx.Defn

      defn softmax(t) do
        Nx.exp(t) / Nx.sum(Nx.exp(t))
      end
    end

will work with scalars, vectors, matrices, and n-dimensional
tensors. Depending on your compiler of choice, the code can even
be JIT-compiled and run either on the CPU or GPU.

To support these features, `defn` is a subset of Elixir. It
replaces Elixir's `Kernel` with `Nx.Defn.Kernel`. `Nx.Defn.Kernel`
provides tensor-aware operators, such as `+`, `-`, etc., while
also preserving many high-level constructs known to Elixir
developers, such as the pipe operator, aliases, conditionals,
pattern-matching, the access syntax, and more.

For example, the code above can also be written as:

    defmodule MyModule do
      import Nx.Defn

      defn softmax(t) do
        t
        |> Nx.exp()
        |> then(& &1 / Nx.sum(&1))
      end
    end

Please consult `Nx.Defn.Kernel` for a complete reference.

Some of the functions in this module may also be used within
`defn`.

## Operators

`defn` attempts to keep as close to the Elixir semantics as
possible, but that's not always achievable. For example, mathematical
and bitwise operators (`+`, `-`, `&&&`, `<<<`, etc.) in Elixir
work on numbers, which means mapping them to tensors is
straightforward and they largely preserve the same semantics,
except they are now multi-dimensional.

On the other hand, the logical operators `and`, `or`, and `not`
work with booleans in Elixir (`false` and `true`), which map
to `0` and `1`, respectively, in `defn`.

Therefore, when working with logical operators inside `defn`,
`0` is considered `false` and all other numbers are considered
`true`, which is represented as the number `1`. For example, in
`defn`, `0 and 1` as well as `0 and 2` return `0`, while
`1 and 1` or `1 and -1` will return `1`.

The same semantics apply to conditional expressions inside `defn`,
such as `if`, `while`, etc.
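
As a minimal sketch of these rules, the module below applies `and` and `not` to plain numbers. `LogicalDemo` and its function names are hypothetical, and the snippet assumes Nx is available as a dependency. The results come back as tensors holding `0` or `1`, so `Nx.to_number/1` is used to read them:

```elixir
defmodule LogicalDemo do
  import Nx.Defn

  # Inside defn, `and` and `not` come from Nx.Defn.Kernel
  # and operate on numbers and tensors instead of booleans.
  defn both(a, b), do: a and b
  defn negate(a), do: not a
end

# 0 is treated as false, so this evaluates to 0
Nx.to_number(LogicalDemo.both(0, 2))

# any non-zero number is treated as true, so this evaluates to 1
Nx.to_number(LogicalDemo.both(1, -1))

# negating 0 ("false") evaluates to 1 ("true")
Nx.to_number(LogicalDemo.negate(0))
```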

## JIT compilers

The power of `Nx.Defn` is given by its compilers. The default
compiler is `Nx.Defn.Evaluator`, which evaluates the code.
You can use `jit/2` to compile a function on the fly using a
different compiler, such as `EXLA`:

    fun = Nx.Defn.jit(&MyModule.softmax/1, compiler: EXLA)
    fun.(my_tensor)

The above will return an anonymous function that optimizes,
compiles, and runs `softmax` on the fly on the CPU (or the GPU,
if available). EXLA, in particular, also exports an `EXLA.jit/2`
function for convenience.

`defn` functions are compiled when they are invoked, based on
the types and shapes of the tensors given as arguments.
Therefore, compilation may be quite time-consuming on the first
invocation. The compilation is then cached based on the tensors'
shapes and types. Calling the same function with a tensor of
different values but the same shape and type means no recompilation
is performed.

For those interested in writing custom compilers, see `Nx.Defn.Compiler`.

## Invoking custom Elixir code

Inside `defn` you can only call other `defn` functions and
the functions in the `Nx` module. However, it is possible
to use transforms, defined with either `deftransform` or
`deftransformp`, to invoke any Elixir code.

You can call code which was defined with `deftransform` from another module:

    defmodule MyRemoteModule do
      import Nx.Defn

      deftransform remote_elixir_code(value) do
        IO.inspect(value)
      end
    end

    defn add_and_mult(a, b, c) do
      res = a * b + c
      MyRemoteModule.remote_elixir_code(res)
    end

You can also define and call a private transform defined through `deftransformp`:

    defn add_and_mult(a, b, c) do
      res = a * b + c
      custom_elixir_code(res)
    end

    deftransformp custom_elixir_code(value), do: IO.inspect(value)

The only difference between using `deftransform` and `deftransformp`
is whether you want to expose and share the code with other modules,
just like `def` and `defp`.

Transforms are useful to manipulate tensor expressions or
Elixir data structures without the constraints of `defn`.
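
For instance, a transform can use `Tuple` and `Enum`, which cannot be called directly inside `defn`. The sketch below is illustrative only: `TransformDemo`, `sum_all/1`, and `reduce_tuple/1` are hypothetical names, and it assumes the tuple holds tensors with compatible shapes:

```elixir
defmodule TransformDemo do
  import Nx.Defn

  # The tuple of tensors is handed to a transform, which is
  # regular Elixir code running while the expression is built.
  defn sum_all(tensors) do
    reduce_tuple(tensors)
  end

  # Regular Elixir: convert the tuple to a list and fold it with Nx.add/2.
  deftransformp reduce_tuple(tensors) do
    tensors
    |> Tuple.to_list()
    |> Enum.reduce(&Nx.add/2)
  end
end

total = TransformDemo.sum_all({Nx.tensor(1), Nx.tensor(2), Nx.tensor(3)})
Nx.to_number(total)
```

The transform only decides how the expression is assembled; the additions themselves are still executed by the configured compiler.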

## Inputs and outputs types

`Nx` and `defn` expect the arguments to be numbers, tensors,
or a composite data type that implements `Nx.LazyContainer`.
Tuples and maps implement `Nx.LazyContainer` by default.
As previously described, `defn` compilations are cached based on
the shape, type, and names of the input tensors, but not their values.

`defn` also accepts two special arguments: functions (or tuples
of functions) and lists (most commonly as keyword lists). Those
values are passed as is to numerical definitions and cached as
a whole. For this reason, you must never capture tensors in
functions or pass tensors in keyword lists.

When numbers are given as arguments, they are always immediately
converted to tensors on invocation. If you want to keep numbers
as is or if you want to pass any other value to numerical definitions,
they must be given as keyword lists.

### Default arguments

`defn` functions support default arguments. They are typically used
as options. For example, imagine you want to create a function named
`zeros`, which returns a tensor of zeroes with a given type and shape.
It could be implemented like this:

    defn zeros(opts \\ []) do
      opts = keyword!(opts, type: {:f, 32}, shape: {})
      Nx.broadcast(Nx.tensor(0, type: opts[:type]), opts[:shape])
    end

The function above accepts `opts` which are then validated and given
default values via the `keyword!/2` function. Note that while it is
possible to access options via the `Access` syntax, such as `opts[:shape]`,
it is not possible to directly call functions in the `Keyword` module
inside `defn`. To freely manipulate any Elixir value inside `defn`,
you have to use transforms, as described in the "Invoking custom Elixir
code" section.
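
As a small usage sketch of the same pattern, the hypothetical `ScaleDemo.scale/2` below takes a tensor plus an options list. The `:factor` option is an assumption made for illustration; it stays a plain Elixir number inside `defn` because it is passed through a keyword list:

```elixir
defmodule ScaleDemo do
  import Nx.Defn

  defn scale(t, opts \\ []) do
    # Validate the options and fill in the default for :factor
    opts = keyword!(opts, factor: 2)
    t * opts[:factor]
  end
end

# Uses the default factor of 2
ScaleDemo.scale(Nx.tensor([1, 2, 3])) |> Nx.to_flat_list()

# Overrides the factor; under JIT compilation, each distinct set of
# options is expected to trigger its own compilation
ScaleDemo.scale(Nx.tensor([1, 2, 3]), factor: 3) |> Nx.to_flat_list()
```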

> **Important!** When it comes to JIT compilation, each different set of
> options (as well as anonymous functions) will lead to a different
> compilation of the numerical function.
>
> Furthermore, if tensors are given through keyword lists, they won't
> be cached effectively. Tensors in `defn` are cached based on their shape
> and type, not their value, but this is not true if the tensor is given
> via a default argument or captured by an anonymous function. For this
> reason, it is **extremely discouraged to pass tensors through anonymous
> functions and default arguments**.

### Working with maps and structs

While `Nx` supports maps in `defn`, you must be careful if your numerical
definitions are receiving maps and returning maps. For example, imagine
this code:

    defn update_a(map) do
      %{map | a: Nx.add(map.a, 1)}
    end

The code above increments the value under the key `:a`
by 1. However, because the function receives the whole map and
returns the whole map, if the map has 120 keys, the
whole map will be copied to the CPU/GPU and then brought back.

However, if you do this instead:

    defn update_a(map) do
      Nx.add(map.a, 1)
    end

And then update the map in Elixir, outside of `defn`:

    %{map | a: update_a(map)}

`Nx` will only send the parts of the map that matter.
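
Putting the two pieces together, a minimal end-to-end sketch could look as follows. `MapDemo` and the contents of `state` are hypothetical and chosen only to keep the example small:

```elixir
defmodule MapDemo do
  import Nx.Defn

  # Only reads the :a key and returns a single tensor
  defn update_a(map) do
    Nx.add(map.a, 1)
  end
end

state = %{a: Nx.tensor(1), b: Nx.tensor([1.0, 2.0, 3.0])}

# The map itself is updated in regular Elixir, outside of defn
state = %{state | a: MapDemo.update_a(state)}

Nx.to_number(state.a)
```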

## Recursion and loops

Given that numerical definitions first build a representation of
your code, it is not possible to write recursive (or tail-recursive)
code inside `defn`. Instead, one must use
`Nx.Defn.Kernel.while/4`.
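
As an illustration, a factorial that would normally be written recursively can be expressed with `while`. This is a sketch with a hypothetical `LoopDemo` module; the accumulator starts as `1.0`, so the result is a float tensor:

```elixir
defmodule LoopDemo do
  import Nx.Defn

  defn factorial(x) do
    # The loop state is the tuple {factorial, x}; the body must
    # return a tuple with the same shapes and types on each iteration.
    {factorial, _} =
      while {factorial = 1.0, x}, Nx.greater(x, 1) do
        {factorial * x, x - 1}
      end

    factorial
  end
end

Nx.to_number(LoopDemo.factorial(5))
```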

# `compile`

Compiles the given anonymous function with the given tensor shapes.

While `jit/2` compiles a function just in time based on the
input shapes, this function precompiles the given anonymous
function based on the input shapes. This can be beneficial for
large numerical definitions, where the cache mechanism in `jit/2`
may take milliseconds.

For example, take the following definition:

    defn softmax(t), do: Nx.exp(t) / Nx.sum(Nx.exp(t))

You can compile and then apply it as:

    fun = Nx.Defn.compile(&softmax/1, [Nx.template({3}, {:s, 32})], compiler: EXLA)
    fun.(Nx.tensor([1, 2, 3]))

You can also pass a mixture of templates and options when
compiling a function. In such cases, you must only pass
the inputs when invoking the compiled function, as the options
will already be embedded in its compiled value:

    fun = Nx.Defn.compile(&Nx.sum/2, [Nx.template({2, 2}, {:s, 32}), [axes: [1]]])
    fun.(Nx.iota({2, 2}))

If the input tensors do not match the shape of the tensors
given on compilation, it will raise.

## Options

  * `:compiler` - the compiler for the JIT compilation

  * `:hooks` - a map of hooks to execute. See `Nx.Defn.Kernel.hook/3`

# `debug_expr`

Wraps an anonymous function to return its underlying defn expression.

> #### Warning {: .warning}
>
> This function must be invoked for debugging purposes only.

## Options

  * `:hooks` - a map of hooks to execute. See `Nx.Defn.Kernel.hook/3`

# `debug_expr_apply`

Invokes the anonymous function to return its underlying defn expression.

> #### Warning {: .warning}
>
> This function must be invoked for debugging purposes only.

It accepts the same options as `debug_expr/2`.

# `default_options`

Gets the default options for the current process.

# `default_options`

Sets the default options for `defn` in the current process.

The options defined here apply to all future invocations of
`defn` done by the current process. They also apply to calls
to `jit/2` in this module.

The default options are stored only in the process dictionary
and override any global options. This means if you start a
separate process, such as `Task`, the default options must be
set on the new process too.

The function returns the values that were previously set as default
options.

This function must be used only for scripting and testing.

## Examples

    iex> Nx.Defn.default_options(compiler: EXLA, client: :cuda)
    iex> Nx.Defn.default_options()
    [compiler: EXLA, client: :cuda]

# `defn`
*macro* 

Defines a public numerical function.

# `defnp`
*macro* 

Defines a private numerical function.

Private numerical functions are always inlined by
their callers at compilation time. This happens to
all local function calls within `defn`.
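
A minimal sketch of a private helper, using a hypothetical `StatsDemo` module: `center/1` is only callable from within the module, while `variance/1` is the public entry point:

```elixir
defmodule StatsDemo do
  import Nx.Defn

  # Public numerical function
  defn variance(t) do
    centered = center(t)
    Nx.mean(centered * centered)
  end

  # Private numerical function, not callable from other modules
  defnp center(t), do: t - Nx.mean(t)
end

StatsDemo.variance(Nx.tensor([1.0, 2.0, 3.0, 4.0])) |> Nx.to_number()
```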

# `deftransform`
*macro* 

Can be used to define bodiless clauses for multi-clause transforms.

See also: `deftransform/2`

## Examples

    deftransform foo(bar, baz \\ 1)
    deftransform foo(bar, 1), do: bar
    deftransform foo(bar, baz), do: bar + baz

# `deftransform`
*macro* 

Defines a transform that executes the given `fun` with `arg`
when building `defn` expressions.

## Example

Take the following defn expression:

    defn tanh_power(a, b) do
      Nx.tanh(a) + Nx.pow(b, 2)
    end

Let's see a trivial example, which is to use `IO.inspect/1` to
print a tensor expression at definition time:

    defn tanh_power(a, b) do
      Nx.tanh(a) + Nx.pow(b, 2) |> my_inspect()
    end

    deftransformp my_inspect(expr), do: IO.inspect(expr)

Or:

    defn tanh_power(a, b) do
      res = Nx.tanh(a) + Nx.pow(b, 2)
      my_inspect(res)
      res
    end

In both cases, invoking the function prints the expression being built
by `defn`:

    #Nx.Defn.Expr<
      parameter a
      parameter c
      b = tanh [ a ] ()
      d = pow [ c, 2 ] ()
      e = add [ b, d ] ()
    >

Although, for convenience, you might use `print_expr/2` instead.

# `deftransformp`
*macro* 

Private function version for `deftransform/1`

# `deftransformp`
*macro* 

Private function version for `deftransform/2`

# `global_default_options`

Sets the default options globally.

The options defined here apply to all future invocations of
`defn`. They also apply to calls to `jit/2` in this module.

You should avoid calling this function at runtime; it is mostly
meant for testing purposes. You may also set it in your test
environment using configuration:

    config :nx, :default_defn_options, [compiler: EXLA, client: :cuda]

The function returns the values that were previously set as global
default options.

# `grad`

Receives an anonymous function and returns a new anonymous function
that returns the gradient of the input function when invoked.

## Examples

    iex> fun = Nx.Defn.grad(fn x -> Nx.sin(x) end)
    iex> fun.(Nx.tensor(0))
    #Nx.Tensor<
      f32
      1.0
    >

# `grad`

Computes the gradient of the given `var` on `fun`.

The result of the `grad` function must be a scalar tensor.
If a non-scalar tensor is given, it is assumed the additional
dimensions are batch dimensions.

## Examples

    defn tanh_grad(t) do
      grad(t, &Nx.tanh/1)
    end

To differentiate on multiple vars, pass a tuple as first argument:

    defn tanh_power_grad(a, b) do
      grad({a, b}, fn {a, b} -> Nx.tanh(a) + Nx.pow(b, 2) end)
    end

`var_or_vars` can be any `Nx.Container` with one or multiple
tensors.
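
To make the multi-variable case concrete, the sketch below wraps the definition above in a hypothetical `GradDemo` module and evaluates it at `a = 0.0` and `b = 3.0`. The expected values follow from the derivatives `1 - tanh(a)^2` and `2b`:

```elixir
defmodule GradDemo do
  import Nx.Defn

  defn tanh_power_grad(a, b) do
    grad({a, b}, fn {a, b} -> Nx.tanh(a) + Nx.pow(b, 2) end)
  end
end

# The result mirrors the shape of the first argument: a tuple of gradients
{da, db} = GradDemo.tanh_power_grad(Nx.tensor(0.0), Nx.tensor(3.0))

# da should be close to 1.0 and db close to 6.0
{Nx.to_number(da), Nx.to_number(db)}
```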

# `jit`

Wraps an anonymous function with just-in-time compilation.

Once invoked, the wrapped anonymous function will perform just
in time compilation with the configured compiler. For example,
take the following definition:

    defn softmax(t), do: Nx.exp(t) / Nx.sum(Nx.exp(t))

You can jit and then apply it as:

    fun = Nx.Defn.jit(&softmax/1, compiler: EXLA)
    fun.(Nx.tensor([1, 2, 3]))

## Options

  * `:compiler` - the compiler for the JIT compilation

  * `:hooks` - a map of hooks to execute. See `Nx.Defn.Kernel.hook/3`

  * `:on_conflict` - what to do if a JIT compilation is already in place.
    It may be `:raise` (the default), `:force` (forces a new JIT compilation),
    or `:reuse` (reuses the existing JIT compilation). It is not recommended
    to set the `:compiler` option when reusing.

# `jit_apply`

Invokes the anonymous function with just-in-time compilation.

This function is equivalent to calling `jit/2` and then applying
the given arguments to the anonymous function.

For example, take the following definition:

    defn softmax(t), do: Nx.exp(t) / Nx.sum(Nx.exp(t))

You can `jit_apply/3` it as:

    Nx.Defn.jit_apply(&softmax/1, [Nx.tensor([1, 2, 3])], compiler: EXLA)

It accepts the same options as `jit/2`.

# `shard_jit`

# `shard_jit_apply`

# `to_backend`

Returns a backend corresponding to the compiler options.

The backend matches the backend used for outputs from computations
defined by the given compiler.

# `value_and_grad`

Receives an anonymous function and returns a new anonymous function
that returns the value and gradient of the input function when invoked.

## Examples

    iex> fun = Nx.Defn.value_and_grad(fn x -> Nx.sin(x) end)
    iex> {value, grad} = fun.(Nx.tensor(0))
    iex> value
    #Nx.Tensor<
      f32
      0.0
    >
    iex> grad
    #Nx.Tensor<
      f32
      1.0
    >

# `value_and_grad`

Computes the value and gradient of the given `var` on `fun`
with an optional data transformation.

It returns a tuple with the value and the gradient.

## Examples

    defn tanh_grad(t) do
      value_and_grad(t, &Nx.tanh/1)
    end

To differentiate on multiple vars, pass a tuple as first argument:

    defn tanh_power_grad(a, b) do
      value_and_grad({a, b}, fn {a, b} -> Nx.tanh(a) + Nx.pow(b, 2) end)
    end

`var_or_vars` can be any `Nx.Container` with one or multiple
tensors.

`transform` allows you to transform the expression before the gradient is
calculated. This enables optimizations that reuse parts of expressions. As
an example, consider the following objective function:

    defn objective(predict_fn, loss_fn, params, inputs, targets) do
      preds = predict_fn.(params, inputs)
      loss = loss_fn.(preds, targets)
      {preds, loss}
    end

You can compute the gradient with respect to just the loss function by applying
a transform:

    {{preds, loss}, gradient} = value_and_grad(params, &objective(predict_fn, loss_fn, &1, inputs, targets), &elem(&1, 1))

`preds` can be re-used to compute other metrics such as accuracy, absolute error,
etc. without having to do another forward pass.

---

*Consult [api-reference.md](api-reference.md) for complete listing*
