# Nx.Defn.Kernel (Nx v0.2.0)

All imported functionality available inside `defn` blocks.

# Summary

## Functions

Element-wise inequality operation.

Element-wise bitwise AND operation.

Element-wise power operator.

Element-wise multiplication operator.

Element-wise unary plus operator.

Element-wise addition operator.

Element-wise unary minus operator.

Element-wise subtraction operator.

Builds a range.

Builds a range with step.

Element-wise division operator.

Element-wise less than operation.

Element-wise left shift operation.

Element-wise less-equal operation.

Element-wise equality operation.

Element-wise greater than operation.

Element-wise greater-equal operation.

Element-wise right shift operation.

Reads a module attribute at compilation time.

Defines an alias, as in `Kernel.SpecialForms.alias/2`.

Element-wise logical AND operation.

Asserts the `tensor` has a certain `shape`.

Asserts the `tensor` has a certain `shape` pattern.

Attaches a token to an expression. See `hook/3`.

Evaluates the expression corresponding to the first clause that evaluates to a truthy value.

Creates a token for hooks. See `hook/3`.

Defines a custom gradient for the given expression.

Gets the element at the zero-based index in tuple.

Shortcut for `hook/3`.

Defines a hook.

Shortcut for `hook_token/4`.

Defines a hook with an existing token. See `hook/3`.

Provides if/else expressions.

Imports functions and macros into the current scope,
as in `Kernel.SpecialForms.import/2`.

Inspects the given expression to the terminal.

Inspects the value at runtime to the terminal.

Ensures the first argument is a `keyword` with the given
keys and default values.

Element-wise maximum operation.

Element-wise minimum operation.

Element-wise logical NOT operation.

Element-wise logical OR operation.

Element-wise remainder operation.

Requires a module in order to use its macros, as in `Kernel.SpecialForms.require/2`.

Rewrites the types of `expr` recursively according to `opts`.

Stops computing the gradient for the given expression.

Pipes `value` to the given `fun` and returns the `value` itself.

Pipes `value` into the given `fun`.

Defines a transform that executes the given `fun` with `arg`
when building `defn` expressions.

Defines a `while` loop.

Pipes the argument on the left to the function call on the right.

Element-wise bitwise OR operation.

Element-wise bitwise not operation.

# Functions

Element-wise inequality operation.

It delegates to `Nx.not_equal/2`.

## Examples

```
defn check_inequality(a, b) do
  a != b
end
```

Element-wise bitwise AND operation.

Only integer tensors are supported.
It delegates to `Nx.bitwise_and/2` (supports broadcasting).

## Examples

```
defn and_or(a, b) do
  {a &&& b, a ||| b}
end
```
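Below is a minimal runnable sketch of the `and_or/2` example. It assumes Nx `~> 0.2.0` is fetched with `Mix.install/1`, and the module name `BitwiseExample` is hypothetical. The scalar `10` is broadcast against the vector. In binary, `12 &&& 10` is `1100 &&& 1010 = 1000` (8) and `12 ||| 10` is `1110` (14):

```elixir
# Assumes Nx ~> 0.2.0; Mix.install/1 requires Elixir 1.12 or later.
Mix.install([{:nx, "~> 0.2.0"}])

defmodule BitwiseExample do
  import Nx.Defn

  # Bitwise AND and OR of integer tensors, as in the example above.
  defn and_or(a, b) do
    {a &&& b, a ||| b}
  end
end

# The scalar 10 is broadcast to the shape {2} of the first argument.
{and_result, or_result} = BitwiseExample.and_or(Nx.tensor([12, 5]), 10)

IO.inspect(Nx.to_flat_list(and_result), label: "and")
IO.inspect(Nx.to_flat_list(or_result), label: "or")
```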

Element-wise power operator.

It delegates to `Nx.power/2` (supports broadcasting).

## Examples

```
defn power(a, b) do
  a ** b
end
```

Element-wise multiplication operator.

It delegates to `Nx.multiply/2` (supports broadcasting).

## Examples

```
defn multiply(a, b) do
  a * b
end
```

Element-wise unary plus operator.

Simply returns the given argument.

## Examples

```
defn plus_and_minus(a) do
  {+a, -a}
end
```

Element-wise addition operator.

It delegates to `Nx.add/2` (supports broadcasting).

## Examples

```
defn add(a, b) do
  a + b
end
```
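Here is a minimal runnable sketch of broadcasting with `+`. It assumes Nx `~> 0.2.0` is installed via `Mix.install/1`, and the module name `AddExample` is hypothetical. The `{2}`-shaped vector is added to each row of the `{2, 2}` matrix:

```elixir
# Assumes Nx ~> 0.2.0; Mix.install/1 requires Elixir 1.12 or later.
Mix.install([{:nx, "~> 0.2.0"}])

defmodule AddExample do
  import Nx.Defn

  defn add(a, b) do
    a + b
  end
end

matrix = Nx.tensor([[1, 2], [3, 4]])
row = Nx.tensor([10, 20])

# `row` (shape {2}) is broadcast across both rows of `matrix` (shape {2, 2}).
sum = AddExample.add(matrix, row)

IO.inspect(sum)
```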

Element-wise unary minus operator.

It delegates to `Nx.negate/1`.

## Examples

```
defn plus_and_minus(a) do
  {+a, -a}
end
```

Element-wise subtraction operator.

It delegates to `Nx.subtract/2` (supports broadcasting).

## Examples

```
defn subtract(a, b) do
  a - b
end
```

Builds a range.

Ranges are inclusive and both sides must be integers.

The step of the range is computed based on the first and last values of the range.

## Examples

```
iex> t = Nx.tensor([1, 2, 3])
iex> t[1..2]
#Nx.Tensor<
  s64[2]
  [2, 3]
>
```

Builds a range with step.

Ranges are inclusive and both sides must be integers.

## Examples

```
iex> t = Nx.tensor([1, 2, 3])
iex> t[1..2//1]
#Nx.Tensor<
  s64[2]
  [2, 3]
>
```

Element-wise division operator.

It delegates to `Nx.divide/2` (supports broadcasting).

## Examples

```
defn divide(a, b) do
  a / b
end
```
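`Nx.divide/2` always returns a float tensor, so dividing integer tensors yields floats. This is a minimal sketch that assumes Nx `~> 0.2.0` is installed via `Mix.install/1`. The module `DivideExample` is hypothetical:

```elixir
# Assumes Nx ~> 0.2.0; Mix.install/1 requires Elixir 1.12 or later.
Mix.install([{:nx, "~> 0.2.0"}])

defmodule DivideExample do
  import Nx.Defn

  defn divide(a, b) do
    a / b
  end
end

# Both inputs are integers, but the quotient is a float tensor.
quotient = DivideExample.divide(Nx.tensor([1, 2, 3]), 2)

IO.inspect(Nx.type(quotient))
IO.inspect(Nx.to_flat_list(quotient))
```

With integer inputs, the result is expected to use the default float type, `{:f, 32}`.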

Element-wise less than operation.

It delegates to `Nx.less/2`.

## Examples

```
defn check_less_than(a, b) do
  a < b
end
```
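Comparison operators return a mask of `0`s and `1`s rather than Elixir booleans. Here is a minimal runnable sketch. It assumes Nx `~> 0.2.0` is installed via `Mix.install/1`, and the module `CompareExample` is hypothetical:

```elixir
# Assumes Nx ~> 0.2.0; Mix.install/1 requires Elixir 1.12 or later.
Mix.install([{:nx, "~> 0.2.0"}])

defmodule CompareExample do
  import Nx.Defn

  defn check_less_than(a, b) do
    a < b
  end
end

# 1 where the element is less than 3, 0 elsewhere; the mask is an unsigned 8-bit tensor.
mask = CompareExample.check_less_than(Nx.tensor([1, 5, 3]), 3)

IO.inspect(mask)
```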

Element-wise left shift operation.

Only integer tensors are supported.
It delegates to `Nx.left_shift/2` (supports broadcasting).

## Examples

```
defn shift_left_and_right(a, b) do
  {a <<< b, a >>> b}
end
```
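This is a minimal runnable sketch of shifting by `2`, which multiplies or divides each element by `4` (rounding down for non-negative values). It assumes Nx `~> 0.2.0` is installed via `Mix.install/1`, and the module `ShiftExample` is hypothetical:

```elixir
# Assumes Nx ~> 0.2.0; Mix.install/1 requires Elixir 1.12 or later.
Mix.install([{:nx, "~> 0.2.0"}])

defmodule ShiftExample do
  import Nx.Defn

  defn shift_left_and_right(a, b) do
    {a <<< b, a >>> b}
  end
end

# The scalar shift amount 2 is broadcast to every element.
{left, right} = ShiftExample.shift_left_and_right(Nx.tensor([1, 16, 40]), 2)

IO.inspect(Nx.to_flat_list(left), label: "left")
IO.inspect(Nx.to_flat_list(right), label: "right")
```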

Element-wise less-equal operation.

It delegates to `Nx.less_equal/2`.

## Examples

```
defn check_less_equal(a, b) do
  a <= b
end
```

Element-wise equality operation.

It delegates to `Nx.equal/2`.

## Examples

```
defn check_equality(a, b) do
  a == b
end
```

Element-wise greater than operation.

It delegates to `Nx.greater/2`.

## Examples

```
defn check_greater_than(a, b) do
  a > b
end
```

Element-wise greater-equal operation.

It delegates to `Nx.greater_equal/2`.

## Examples

```
defn check_greater_equal(a, b) do
  a >= b
end
```

Element-wise right shift operation.

Only integer tensors are supported.
It delegates to `Nx.right_shift/2` (supports broadcasting).

## Examples

```
defn shift_left_and_right(a, b) do
  {a <<< b, a >>> b}
end
```

Reads a module attribute at compilation time.

It is useful to inject code constants into `defn`.
It delegates to `Kernel.@/1`.

## Examples

```
@two_per_two Nx.tensor([[1, 2], [3, 4]])
defn add_2x2_attribute(t), do: t + @two_per_two
```
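The attribute example can be run as a full module. This minimal sketch assumes Nx `~> 0.2.0` is installed via `Mix.install/1`, and the module name `AttributeExample` is hypothetical. The tensor stored in `@two_per_two` is built once at compile time and embedded into the `defn`:

```elixir
# Assumes Nx ~> 0.2.0; Mix.install/1 requires Elixir 1.12 or later.
Mix.install([{:nx, "~> 0.2.0"}])

defmodule AttributeExample do
  import Nx.Defn

  # Evaluated when the module is compiled, then read inside defn via @.
  @two_per_two Nx.tensor([[1, 2], [3, 4]])

  defn add_2x2_attribute(t), do: t + @two_per_two
end

result = AttributeExample.add_2x2_attribute(Nx.tensor([[10, 20], [30, 40]]))

IO.inspect(result)
```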

Defines an alias, as in `Kernel.SpecialForms.alias/2`.

An alias allows you to refer to a module using its aliased name. For example:

```
defn some_fun(t) do
  alias Math.Helpers, as: MH
  MH.fft(t)
end
```

If the `:as` option is not given, the alias defaults to
the last part of the given alias. For example,

`alias Math.Helpers`

is equivalent to:

`alias Math.Helpers, as: Helpers`

Finally, note that aliases defined outside of a function also apply to the function, as they have lexical scope:

```
alias Math.Helpers, as: MH

defn some_fun(t) do
  MH.fft(t)
end
```

Element-wise logical AND operation.

Zero is considered false, all other numbers are considered true.

It delegates to `Nx.logical_and/2` (supports broadcasting).

## Examples

```
defn and_or(a, b) do
  {a and b, a or b}
end
```
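This minimal runnable sketch shows that any non-zero number counts as true. It assumes Nx `~> 0.2.0` is installed via `Mix.install/1`, and the module `LogicalExample` is hypothetical:

```elixir
# Assumes Nx ~> 0.2.0; Mix.install/1 requires Elixir 1.12 or later.
Mix.install([{:nx, "~> 0.2.0"}])

defmodule LogicalExample do
  import Nx.Defn

  defn and_or(a, b) do
    {a and b, a or b}
  end
end

# 0 is false; 1, 2 and 3 are all true. Results are 0/1 masks.
{both, either} = LogicalExample.and_or(Nx.tensor([0, 1, 2]), Nx.tensor([0, 0, 3]))

IO.inspect(Nx.to_flat_list(both), label: "and")
IO.inspect(Nx.to_flat_list(either), label: "or")
```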

Asserts the `tensor` has a certain `shape`.

If it succeeds, it returns the given tensor. Raises an error otherwise.

## Examples

To assert the tensor is a scalar, you can pass the empty tuple `shape`:

```
iex> assert_shape Nx.tensor(13), {}
#Nx.Tensor<
  s64
  13
>
```

If the shapes do not match, an error is raised:

```
iex> assert_shape Nx.tensor([1, 2, 3]), {}
** (ArgumentError) expected tensor to be a scalar, got tensor with shape {3}
iex> assert_shape Nx.tensor([1, 2, 3]), {4}
** (ArgumentError) expected tensor to have shape {4}, got tensor with shape {3}
```

If you want to assert on the rank or shape patterns, use
`assert_shape_pattern/2` instead.

Asserts the `tensor` has a certain `shape` pattern.

If it succeeds, it returns the given tensor. Raises an error otherwise.

## Examples

Unlike `assert_shape/2`, where the given shape is a value,
`assert_shape_pattern` allows the shape to be any Elixir pattern.
We can use this to match on ranks:

```
iex> assert_shape_pattern Nx.tensor([[1, 2], [3, 4]]), {_, _}
#Nx.Tensor<
  s64[2][2]
  [
    [1, 2],
    [3, 4]
  ]
>
iex> assert_shape_pattern Nx.tensor([1, 2, 3]), {_, _}
** (ArgumentError) expected tensor to match shape {_, _}, got tensor with shape {3}
```

Or even use variables to assert on properties such as square matrices:

```
iex> assert_shape_pattern Nx.tensor([[1, 2], [3, 4]]), {x, x}
#Nx.Tensor<
  s64[2][2]
  [
    [1, 2],
    [3, 4]
  ]
>
iex> assert_shape_pattern Nx.tensor([1, 2, 3]), {x, x}
** (ArgumentError) expected tensor to match shape {x, x}, got tensor with shape {3}
```

You can also use guards to specify tall matrices and so forth:

```
iex> assert_shape_pattern Nx.tensor([[1], [2]]), {x, y} when x > y
#Nx.Tensor<
  s64[2][1]
  [
    [1],
    [2]
  ]
>
iex> assert_shape_pattern Nx.tensor([1, 2]), {x, y} when x > y
** (ArgumentError) expected tensor to match shape {x, y} when x > y, got tensor with shape {2}
```

Attaches a token to an expression. See `hook/3`.

Evaluates the expression corresponding to the first clause that evaluates to a truthy value.

It has the format of:

```
cond do
  condition1 ->
    expr1

  condition2 ->
    expr2

  :otherwise ->
    expr3
end
```

Each condition must be a scalar. Zero is considered false, any other number is considered true.

All clauses are normalized to the same type and are broadcast
to the same shape. The last condition must always evaluate to
an atom, typically `:otherwise`.

## Examples

```
cond do
  Nx.all(Nx.greater(a, 0)) -> b * c
  Nx.all(Nx.less(a, 0)) -> b + c
  true -> b - c
end
```
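This minimal runnable sketch uses `cond` on a scalar argument. It assumes Nx `~> 0.2.0` is installed via `Mix.install/1`, and the module `CondExample` and its `sign/1` function are hypothetical. Each condition is a scalar comparison, and the last clause uses the atom `true`:

```elixir
# Assumes Nx ~> 0.2.0; Mix.install/1 requires Elixir 1.12 or later.
Mix.install([{:nx, "~> 0.2.0"}])

defmodule CondExample do
  import Nx.Defn

  # Returns 1 for positive, -1 for negative and 0 for zero scalars.
  defn sign(x) do
    cond do
      x > 0 -> 1
      x < 0 -> -1
      true -> 0
    end
  end
end

signs = Enum.map([5, -3, 0], fn x -> x |> CondExample.sign() |> Nx.to_flat_list() end)

IO.inspect(signs)
```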

Creates a token for hooks. See `hook/3`.

Defines a custom gradient for the given expression.

It expects a `fun` to compute the gradient. The function
will be called with the expression itself and the current
gradient. It must return a list of arguments and their
updated gradient to continue applying `grad` on.

## Examples

For example, if the gradient of `cos(t)` were to be
implemented by hand:

```
defn cos(t) do
  custom_grad(Nx.cos(t), fn _ans, g ->
    [{t, -g * Nx.sin(t)}]
  end)
end
```

Gets the element at the zero-based index in tuple.

It raises ArgumentError when index is negative or it is out of range of the tuple elements.

## Examples

```
iex> tuple = {1, 2, 3}
iex> elem(tuple, 0)
1
```

Shortcut for `hook/3`.

Defines a hook.

Hooks are a mechanism to execute an anonymous function for side-effects with runtime tensor values.

Let's see an example:

```
defmodule Hooks do
  import Nx.Defn

  defn add_and_mult(a, b) do
    add = hook(a + b, fn tensor -> IO.inspect({:add, tensor}) end)
    mult = hook(a * b, fn tensor -> IO.inspect({:mult, tensor}) end)
    {add, mult}
  end
end
```

The `defn` above defines two hooks: one is called with the
value of `a + b` and another with `a * b`. Once you invoke
the function above, you should see this printed:

```
Hooks.add_and_mult(2, 3)
{:add, #Nx.Tensor<
  s64
  5
>}
{:mult, #Nx.Tensor<
  s64
  6
>}
```

In other words, the `hook` function accepts a tensor
expression as argument and it will invoke a custom
function with a tensor value at runtime. `hook` returns
the result of the given expression. The expression can
be any tensor or a `Nx.Container`.

Note **you must return the result of the hook call**.
For example, the code below won't inspect the `:add`
tuple, because the hook is not returned from `defn`:

```
defn add_and_mult(a, b) do
  _add = hook(a + b, fn tensor -> IO.inspect({:add, tensor}) end)
  mult = hook(a * b, fn tensor -> IO.inspect({:mult, tensor}) end)
  mult
end
```

We will learn how to hook into a value that is not part of the result in the "Hooks and tokens" section.

## Named hooks

It is possible to give names to the hooks. This allows them
to be defined or overridden by calling `Nx.Defn.jit/3` or
`Nx.Defn.stream/3`. Let's see an example:

```
defmodule Hooks do
  import Nx.Defn

  defn add_and_mult(a, b) do
    add = hook(a + b, :hooks_add)
    mult = hook(a * b, :hooks_mult)
    {add, mult}
  end
end
```

Now you can pass the hook as argument as follows:

```
hooks = %{
  hooks_add: fn tensor ->
    IO.inspect({:add, tensor})
  end
}

args = [Nx.tensor(2), Nx.tensor(3)]
Nx.Defn.jit(&Hooks.add_and_mult/2, args, hooks: hooks)
```

**Important!** We recommend prefixing your hook names with the name of your project to avoid conflicts.

If no function is given for a named hook, compilers can optimize the hook away and not transfer the tensor from the device in the first place.
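This is a runnable sketch of supplying a named hook through `Nx.Defn.jit/3`. It assumes Nx `~> 0.2.0` is installed via `Mix.install/1`. Instead of printing, the callback sends the tensor back to the calling process so the value can be checked. Only `:hooks_add` receives a function, so `:hooks_mult` is left without one:

```elixir
# Assumes Nx ~> 0.2.0; Mix.install/1 requires Elixir 1.12 or later.
Mix.install([{:nx, "~> 0.2.0"}])

defmodule Hooks do
  import Nx.Defn

  defn add_and_mult(a, b) do
    add = hook(a + b, :hooks_add)
    mult = hook(a * b, :hooks_mult)
    {add, mult}
  end
end

parent = self()

# Callback for :hooks_add only; it forwards the runtime value to `parent`.
hooks = %{hooks_add: fn tensor -> send(parent, {:hooks_add, tensor}) end}

{add, mult} = Nx.Defn.jit(&Hooks.add_and_mult/2, [Nx.tensor(2), Nx.tensor(3)], hooks: hooks)

received =
  receive do
    {:hooks_add, tensor} -> Nx.to_flat_list(tensor)
  after
    1_000 -> :timeout
  end

IO.inspect(received, label: "hooks_add received")
```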

You can also mix named hooks with callbacks:

```
defn add_and_mult(a, b) do
  add = hook(a + b, :hooks_add, fn tensor -> IO.inspect({:add, tensor}) end)
  mult = hook(a * b, :hooks_mult, fn tensor -> IO.inspect({:mult, tensor}) end)
  {add, mult}
end
```

If a hook with the same name is given to `Nx.Defn.jit/3` or `Nx.Defn.stream/3`,
then it will override the default callback.

## Hooks and tokens

So far, we have always returned the result of the `hook`
call. However, what happens if the values we want to
hook are not part of the return value, such as below?

```
defn add_and_mult(a, b) do
  _add = hook(a + b, :hooks_add, &IO.inspect({:add, &1}))
  mult = hook(a * b, :hooks_mult, &IO.inspect({:mult, &1}))
  mult
end
```

In such cases, you must use tokens. Tokens are used to create an ordering over hooks, ensuring hooks execute in a certain sequence:

```
defn add_and_mult(a, b) do
  token = create_token()
  {token, _add} = hook_token(token, a + b, :hooks_add, &IO.inspect({:add, &1}))
  {token, mult} = hook_token(token, a * b, :hooks_mult, &IO.inspect({:mult, &1}))
  attach_token(token, mult)
end
```

The example above creates a token and uses `hook_token/4`
to create hooks attached to their respective tokens. By using a token,
we guarantee that those hooks will be invoked in the order
in which they were defined. Then, at the end of the function,
we attach the token (and its associated hooks) to the result `mult`.

In fact, the `hook/3` function is implemented roughly like this:

```
def hook(tensor_expr, name, function) do
  {token, result} = hook_token(create_token(), tensor_expr, name, function)
  attach_token(token, result)
end
```

Note you must attach the token at the end, otherwise the hooks will be "lost", as if they were not defined. This also applies to conditionals and loops: the token must be attached within the branch in which it is used. For example, this won't work:

```
token = create_token()

{token, result} =
  if Nx.any(value) do
    hook_token(token, some_value, &IO.inspect/1)
  else
    hook_token(token, another_value, &IO.inspect/1)
  end

attach_token(token, result)
```

Instead, you must write:

```
token = create_token()

if Nx.any(value) do
  {token, result} = hook_token(token, some_value, &IO.inspect/1)
  attach_token(token, result)
else
  {token, result} = hook_token(token, another_value, &IO.inspect/1)
  attach_token(token, result)
end
```

Shortcut for `hook_token/4`.

Defines a hook with an existing token. See `hook/3`.

Provides if/else expressions.

The first argument must be a scalar. Zero is considered false, any other number is considered true.

The second argument is a keyword list with `do` and `else`
blocks. The sides are broadcast to return the same shape
and normalized to return the same type.

## Examples

```
if Nx.any(Nx.equal(t, 0)) do
  0.0
else
  1 / t
end
```

If `else` is not given, it is assumed to be 0 with the
same shape as the `do` clause. If you want to nest multiple conditionals,
see `cond/1` instead.
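The example above can be wrapped in a `defn` to show how the branches are broadcast. This minimal sketch assumes Nx `~> 0.2.0` is installed via `Mix.install/1`, and the module `IfExample` and function `safe_reciprocal/1` are hypothetical. When the input contains a zero, the scalar `0.0` from the `do` branch is broadcast to the input's shape:

```elixir
# Assumes Nx ~> 0.2.0; Mix.install/1 requires Elixir 1.12 or later.
Mix.install([{:nx, "~> 0.2.0"}])

defmodule IfExample do
  import Nx.Defn

  # Same logic as the example above, wrapped in a defn.
  defn safe_reciprocal(t) do
    if Nx.any(Nx.equal(t, 0)) do
      0.0
    else
      1 / t
    end
  end
end

no_zeros = IfExample.safe_reciprocal(Nx.tensor([1, 2, 4]))
with_zero = IfExample.safe_reciprocal(Nx.tensor([0, 2]))

IO.inspect(no_zeros)
IO.inspect(with_zero)
```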

Imports functions and macros into the current scope,
as in `Kernel.SpecialForms.import/2`.

Imports are typically discouraged in favor of `alias/2`.

## Examples

```
defn some_fun(t) do
  import Math.Helpers
  fft(t)
end
```

Inspects the given expression to the terminal.

It returns the given expression.

## Examples

```
defn tanh_power(a, b) do
  inspect_expr(Nx.tanh(a) + Nx.power(b, 2))
end
```

When invoked, it will print the expression being built by `defn`:

```
#Nx.Tensor<
  Nx.Defn.Expr
  parameter a s64
  parameter c s64
  b = tanh [ a ] f64
  d = power [ c, 2 ] s64
  e = add [ b, d ] f64
>
```

Inspects the value at runtime to the terminal.

This function is implemented on top of `hook/3` and therefore
has the following restrictions:

- It can only inspect tensors and `Nx.Container`
- The return value of this function must be part of the output

All options are passed to `IO.inspect/2`.

## Examples

```
defn tanh_grad(t) do
  grad(t, fn t ->
    t
    |> Nx.tanh()
    |> inspect_value()
  end)
end

defn tanh_grad_with_label(t) do
  grad(t, fn t ->
    t
    |> Nx.tanh()
    |> inspect_value(label: "tanh")
  end)
end
```

Ensures the first argument is a `keyword` with the given
keys and default values.

The second argument must be a list of atoms, specifying
a given key, or tuples specifying a key and a default value.
If any of the keys in the `keyword` is not defined on
`values`, it raises an error.

## Examples

```
iex> keyword!([], [one: 1, two: 2]) |> Enum.sort()
[one: 1, two: 2]
iex> keyword!([two: 3], [one: 1, two: 2]) |> Enum.sort()
[one: 1, two: 3]
```

If atoms are given, they are supported as keys but do not provide a default value:

```
iex> keyword!([], [:one, two: 2]) |> Enum.sort()
[two: 2]
iex> keyword!([one: 1], [:one, two: 2]) |> Enum.sort()
[one: 1, two: 2]
```

Passing an unknown key raises:

```
iex> keyword!([three: 3], [one: 1, two: 2])
** (ArgumentError) unknown key :three in [three: 3], expected one of [:one, :two]
```

Element-wise maximum operation.

It delegates to `Nx.max/2` (supports broadcasting).

## Examples

```
defn min_max(a, b) do
  {min(a, b), max(a, b)}
end
```

Element-wise minimum operation.

It delegates to `Nx.min/2` (supports broadcasting).

## Examples

```
defn min_max(a, b) do
  {min(a, b), max(a, b)}
end
```
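This minimal runnable sketch of `min_max/2` broadcasts a scalar against a vector, which clamps from below and above respectively. It assumes Nx `~> 0.2.0` is installed via `Mix.install/1`, and the module `MinMaxExample` is hypothetical:

```elixir
# Assumes Nx ~> 0.2.0; Mix.install/1 requires Elixir 1.12 or later.
Mix.install([{:nx, "~> 0.2.0"}])

defmodule MinMaxExample do
  import Nx.Defn

  defn min_max(a, b) do
    {min(a, b), max(a, b)}
  end
end

# The scalar 2 is compared against every element of the vector.
{lower, upper} = MinMaxExample.min_max(Nx.tensor([1, 5, 3]), 2)

IO.inspect(Nx.to_flat_list(lower), label: "min")
IO.inspect(Nx.to_flat_list(upper), label: "max")
```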

Element-wise logical NOT operation.

Zero is considered false, all other numbers are considered true.

It delegates to `Nx.logical_not/1`.

## Examples

`defn logical_not(a), do: not a`

Element-wise logical OR operation.

Zero is considered false, all other numbers are considered true.

It delegates to `Nx.logical_or/2` (supports broadcasting).

## Examples

```
defn and_or(a, b) do
  {a and b, a or b}
end
```

Element-wise remainder operation.

It delegates to `Nx.remainder/2` (supports broadcasting).

## Examples

```
defn divides_by_5?(a) do
  rem(a, 5)
  |> Nx.equal(0)
  |> Nx.all()
end
```
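This is a minimal runnable sketch of the `divides_by_5?/1` check. It assumes Nx `~> 0.2.0` is installed via `Mix.install/1`, and the module `RemExample` is hypothetical. `Nx.all/1` reduces the 0/1 mask to a single scalar:

```elixir
# Assumes Nx ~> 0.2.0; Mix.install/1 requires Elixir 1.12 or later.
Mix.install([{:nx, "~> 0.2.0"}])

defmodule RemExample do
  import Nx.Defn

  # Returns 1 when every element of `a` is divisible by 5, 0 otherwise.
  defn divides_by_5?(a) do
    rem(a, 5)
    |> Nx.equal(0)
    |> Nx.all()
  end
end

all_divisible = RemExample.divides_by_5?(Nx.tensor([10, 25, 5]))
not_all_divisible = RemExample.divides_by_5?(Nx.tensor([10, 7]))

IO.inspect(all_divisible)
IO.inspect(not_all_divisible)
```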

Requires a module in order to use its macros, as in `Kernel.SpecialForms.require/2`.

## Examples

```
defn some_fun(t) do
  require NumericalMacros

  NumericalMacros.some_macro t do
    ...
  end
end
```

Rewrites the types of `expr` recursively according to `opts`.

## Options

- `:max_unsigned_type` - replaces all unsigned tensors with size equal to or greater than the given type by the given type
- `:max_signed_type` - replaces all signed tensors with size equal to or greater than the given type by the given type
- `:max_float_type` - replaces all float tensors with size equal to or greater than the given type by the given type

## Examples

`rewrite_types(expr, max_float_type: {:f, 32})`

Stops computing the gradient for the given expression.

It effectively treats the given expression as a constant, so no gradient flows through it to its inputs.

## Examples

`expr = stop_grad(expr)`
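To see the effect of `stop_grad`, compare the gradient of `x * x + x * x` with the same expression where the second term is wrapped in `stop_grad`. This minimal sketch assumes Nx `~> 0.2.0` is installed via `Mix.install/1`, and the module and function names are hypothetical. At `x = 3.0`, the full derivative `4x` gives `12.0`. Only the first term contributes `2x` once the second is stopped, which gives `6.0`:

```elixir
# Assumes Nx ~> 0.2.0; Mix.install/1 requires Elixir 1.12 or later.
Mix.install([{:nx, "~> 0.2.0"}])

defmodule StopGradExample do
  import Nx.Defn

  # d/dx (x * x + x * x) = 4x
  defn full_grad(x) do
    grad(x, fn x -> x * x + x * x end)
  end

  # The stop_grad term is treated as a constant, so d/dx = 2x
  defn partial_grad(x) do
    grad(x, fn x -> x * x + stop_grad(x * x) end)
  end
end

full = StopGradExample.full_grad(3.0)
partial = StopGradExample.partial_grad(3.0)

IO.inspect(full, label: "full")
IO.inspect(partial, label: "with stop_grad")
```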

Pipes `value` to the given `fun` and returns the `value` itself.

Useful for running synchronous side effects in a pipeline.

## Examples

Let's suppose you want to inspect an expression in the middle of a pipeline. You could write:

```
a
|> Nx.add(b)
|> tap(&inspect_expr/1)
|> Nx.multiply(c)
```

Pipes `value` into the given `fun`.

In other words, it invokes `fun` with `value` as argument.
This is most commonly used in pipelines, allowing you
to pipe a value to a function outside of its first argument.

## Examples

```
a
|> Nx.add(b)
|> then(&Nx.subtract(c, &1))
```

Defines a transform that executes the given `fun` with `arg`
when building `defn` expressions.

## Example

Take the following defn expression:

```
defn tanh_power(a, b) do
  Nx.tanh(a) + Nx.power(b, 2)
end
```

Let's see a trivial example, which is to use `IO.inspect/1` to
print a tensor expression at definition time:

```
defn tanh_power(a, b) do
  Nx.tanh(a) + Nx.power(b, 2) |> transform(&IO.inspect/1)
end
```

Or:

```
defn tanh_power(a, b) do
  res = Nx.tanh(a) + Nx.power(b, 2)
  transform(res, &IO.inspect/1)
  res
end
```

When invoked in both cases, it will print the expression being built
by `defn`:

```
#Nx.Defn.Expr<
  parameter a
  parameter c
  b = tanh [ a ] ()
  d = power [ c, 2 ] ()
  e = add [ b, d ] ()
>
```

For convenience, you may prefer `inspect_expr/2` instead.

## Pitfalls

Because `transform/2` is invoked inside `defn`, its scope is tied
to `defn`. For example, if you do this:

```
transform(tensor, fn tensor ->
  if Nx.shape(tensor) != {2, 2} do
    raise "bad"
  end
end)
```

it won't work because it will use the `!=` operator defined in
this module, which only works with tensors, instead of the operator
defined in Elixir's `Kernel`. Therefore, we recommend that all `transform/2`
calls simply dispatch to a separate function. The example above
could be rewritten as:

`transform(tensor, &assert_2x2_shape(&1))`

where:

```
defp assert_2x2_shape(tensor) do
  if Nx.shape(tensor) != {2, 2} do
    raise "bad"
  end
end
```
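Here is a runnable sketch of the recommended pattern. It assumes Nx `~> 0.2.0` is installed via `Mix.install/1`, and the module `TransformExample`, the function `double_2x2/1` and the error message are hypothetical. The private helper is ordinary Elixir, so `if` and `!=` come from `Kernel`. A wrong shape raises while the `defn` expression is being built, before any tensor computation runs:

```elixir
# Assumes Nx ~> 0.2.0; Mix.install/1 requires Elixir 1.12 or later.
Mix.install([{:nx, "~> 0.2.0"}])

defmodule TransformExample do
  import Nx.Defn

  defn double_2x2(t) do
    # Runs assert_2x2_shape/1 while the defn expression is being built.
    transform(t, &assert_2x2_shape(&1))
    t * 2
  end

  # A regular function: `if`, `!=` and `raise` here are Elixir's Kernel versions.
  defp assert_2x2_shape(tensor) do
    if Nx.shape(tensor) != {2, 2} do
      raise ArgumentError, "expected a 2x2 tensor, got shape #{inspect(Nx.shape(tensor))}"
    end

    tensor
  end
end

doubled = TransformExample.double_2x2(Nx.tensor([[1, 2], [3, 4]]))

error =
  try do
    TransformExample.double_2x2(Nx.tensor([1, 2, 3]))
    :no_error
  rescue
    e in ArgumentError -> e.message
  end

IO.inspect(doubled)
IO.inspect(error)
```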

Defines a `while` loop.

It expects the `initial` arguments, a `condition` expression, and
a `block`:

```
while initial, condition do
  block
end
```

`condition` must return a scalar tensor where 0 is false and any
other number is true. The given `block` will be executed while
`condition` is true. Each invocation of `block` must return a
value in the same shape as `initial` arguments.

`while` will return the value of the last execution of `block`.
If `block` is never executed because the initial `condition` is
false, it returns `initial`.

## Examples

A simple loop that increments `x` until it is `10` can be written as:

```
while x = 0, Nx.less(x, 10) do
  x + 1
end
```

However, it is important to note that all variables you intend
to use inside the "while" must be explicitly given as argument
to "while". For example, imagine the amount we want to increment
by in the example above is given by a variable `y`. The following
example is invalid:

```
while x = 0, Nx.less(x, 10) do
  x + y
end
```

Instead, both `x` and `y` must be passed as variables to `while`:

```
while {x = 0, y}, Nx.less(x, 10) do
  {x + y, y}
end
```

Similarly, to compute the factorial of `x` using `while`:

```
defn factorial(x) do
  {factorial, _} =
    while {factorial = 1, x}, Nx.greater(x, 1) do
      {factorial * x, x - 1}
    end

  factorial
end
```
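Here is one more sketch of the rule that every variable used inside the loop must be passed to `while`. The hypothetical `sum_up_to/1` below carries the counter `i`, the accumulator `acc` and the bound `n` through the loop. It assumes Nx `~> 0.2.0` is installed via `Mix.install/1`, and the module name `WhileExample` is also hypothetical:

```elixir
# Assumes Nx ~> 0.2.0; Mix.install/1 requires Elixir 1.12 or later.
Mix.install([{:nx, "~> 0.2.0"}])

defmodule WhileExample do
  import Nx.Defn

  # Sums the integers 1..n; `n` is threaded through the loop state.
  defn sum_up_to(n) do
    {_i, acc, _n} =
      while {i = 1, acc = 0, n}, i <= n do
        {i + 1, acc + i, n}
      end

    acc
  end
end

total = WhileExample.sum_up_to(10)

IO.inspect(total)
```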

Pipes the argument on the left to the function call on the right.

It delegates to `Kernel.|>/2`.

## Examples

```
defn exp_sum(t) do
  t
  |> Nx.exp()
  |> Nx.sum()
end
```

Element-wise bitwise OR operation.

Only integer tensors are supported.
It delegates to `Nx.bitwise_or/2` (supports broadcasting).

## Examples

```
defn and_or(a, b) do
  {a &&& b, a ||| b}
end
```

Element-wise bitwise NOT operation.

Only integer tensors are supported.
It delegates to `Nx.bitwise_not/1`.

## Examples

`defn bnot(a), do: ~~~a`