# Nx.Defn.Kernel (Nx v0.5.2)

All imported functionality available inside defn blocks.

This module can be used in defn.

# Summary

## Functions

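• left &&& right: Element-wise bitwise AND operation.
• left ** right: Element-wise power operator.
• left * right: Element-wise multiplication operator.
• +tensor: Element-wise unary plus operator.
• left + right: Element-wise addition operator.
• -tensor: Element-wise unary minus operator.
• left - right: Element-wise subtraction operator.
• ..: Creates the full-slice range 0..-1//1.
• first..last: Builds a range.
• first..last//step: Builds a range with step.
• left / right: Element-wise division operator.
• left != right: Element-wise inequality operation.
• left < right: Element-wise less than operation.
• left <<< right: Element-wise left shift operation.
• left <= right: Element-wise less-equal operation.
• left <> right: Concatenates two strings.
• left == right: Element-wise equality operation.
• left > right: Element-wise greater than operation.
• left >= right: Element-wise greater-equal operation.
• left >>> right: Element-wise right shift operation.
• @expr: Reads a module attribute at compilation time.
• alias(module, opts \\ []): Defines an alias, as in Kernel.SpecialForms.alias/2.
• left and right: Element-wise logical AND operation.
• assert_keys(keyword, keys): Asserts the keyword list has the given keys.
• attach_token(token, expr): Attaches a token to an expression. See hook/3.
• case(expr, list): Pattern matches the result of expr against the given clauses.
• cond(opts): Evaluates the expression corresponding to the first clause that evaluates to a truthy value.
• create_token(): Creates a token for hooks. See hook/3.
• custom_grad(expr, inputs, fun): Defines a custom gradient for the given expression.
• div(left, right): Element-wise quotient operator.
• elem(tuple, index): Gets the element at the zero-based index in tuple.
• hook(expr, name_or_function): Shortcut for hook/3.
• hook(expr, name, function): Defines a hook.
• hook_token(token, expr, name_or_function): Shortcut for hook_token/4.
• hook_token(token, expr, name, function): Defines a hook with an existing token. See hook/3.
• if(pred, do_else): Provides if/else expressions.
• import(module, opts \\ []): Imports functions and macros into the current scope, as in Kernel.SpecialForms.import/2.
• inspect(expr, opts \\ []): Converts the given expression into a string.
• keyword!(keyword, values): Ensures the first argument is a keyword with the given keys and default values.
• max(left, right): Element-wise maximum operation.
• min(left, right): Element-wise minimum operation.
• not tensor: Element-wise logical NOT operation.
• left or right: Element-wise logical OR operation.
• print_expr(expr, opts \\ []): Prints the given expression to the terminal.
• print_value(expr, opts \\ []): Prints the value at runtime to the terminal.
• raise(message): Raises a runtime exception with the given message.
• raise(exception, arguments): Raises an exception with the given arguments.
• rem(left, right): Element-wise remainder operation.
• require(module, opts \\ []): Requires a module in order to use its macros, as in Kernel.SpecialForms.require/2.
• rewrite_types(expr, opts): Rewrites the types of expr recursively according to opts.
• stop_grad(expr): Stops computing the gradient for the given expression.
• tap(value, fun): Pipes value to the given fun and returns the value itself.
• then(value, fun): Pipes value into the given fun.
• while(initial, condition_or_generator, opts \\ [], do_block): Defines a while loop.
• left |> right: Pipes the argument on the left to the function call on the right.
• left ||| right: Element-wise bitwise OR operation.
• ~~~tensor: Element-wise bitwise NOT operation.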

# left &&& right

Element-wise bitwise AND operation.

Only integer tensors are supported. It delegates to Nx.bitwise_and/2 (supports broadcasting).

## Examples

```elixir
defn and_or(a, b) do
  {a &&& b, a ||| b}
end
```
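To see the bit patterns on concrete values, here is a minimal sketch assuming a standalone .exs script that fetches Nx ~> 0.5.2 with Mix.install/1 (network access required). The module name BitwiseDemo is hypothetical:

```elixir
# Assumes network access so Mix.install/1 can fetch Nx from Hex.
Mix.install([{:nx, "~> 0.5.2"}])

defmodule BitwiseDemo do
  import Nx.Defn

  # Element-wise AND and OR of two integer tensors.
  defn and_or(a, b) do
    {a &&& b, a ||| b}
  end
end

# 12 = 0b1100, 10 = 0b1010, 5 = 0b0101, 3 = 0b0011
{ands, ors} = BitwiseDemo.and_or(Nx.tensor([12, 5]), Nx.tensor([10, 3]))
IO.inspect(Nx.to_flat_list(ands), label: "and")  # [8, 1]
IO.inspect(Nx.to_flat_list(ors), label: "or")    # [14, 7]
```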

# left ** right

Element-wise power operator.

It delegates to Nx.pow/2 (supports broadcasting).

## Examples

```elixir
defn pow(a, b) do
  a ** b
end
```

# left * right

Element-wise multiplication operator.

It delegates to Nx.multiply/2 (supports broadcasting).

## Examples

```elixir
defn multiply(a, b) do
  a * b
end
```

# +tensor (macro)

Element-wise unary plus operator.

Simply returns the given argument.

## Examples

```elixir
defn plus_and_minus(a) do
  {+a, -a}
end
```

# left + right

Element-wise addition operator.

It delegates to Nx.add/2 (supports broadcasting).

## Examples

```elixir
defn add(a, b) do
  a + b
end
```
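Because + delegates to Nx.add/2, operands of different shapes are broadcast. A minimal sketch, assuming a standalone script with Nx ~> 0.5.2 installed via Mix.install/1; AddDemo is a hypothetical module name:

```elixir
# Assumes network access so Mix.install/1 can fetch Nx from Hex.
Mix.install([{:nx, "~> 0.5.2"}])

defmodule AddDemo do
  import Nx.Defn

  defn add(a, b), do: a + b
end

# A scalar is broadcast to every element of the vector.
vector = AddDemo.add(Nx.tensor([1, 2, 3]), Nx.tensor(10))
IO.inspect(Nx.to_flat_list(vector))  # [11, 12, 13]

# A {2, 1} column and a {3} row broadcast to a {2, 3} result.
grid = AddDemo.add(Nx.tensor([[1], [2]]), Nx.tensor([10, 20, 30]))
IO.inspect(Nx.shape(grid))           # {2, 3}
IO.inspect(Nx.to_flat_list(grid))    # [11, 21, 31, 12, 22, 32]
```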

# -tensor (macro)

Element-wise unary minus operator.

It delegates to Nx.negate/1.

## Examples

```elixir
defn plus_and_minus(a) do
  {+a, -a}
end
```

# left - right

Element-wise subtraction operator.

It delegates to Nx.subtract/2 (supports broadcasting).

## Examples

```elixir
defn subtract(a, b) do
  a - b
end
```

# ..

Creates the full-slice range 0..-1//1.

This function returns a range with the following properties:

• When enumerated, it is empty

• When used as a slice, it returns the sliced element as is

## Examples

```elixir
iex> t = Nx.tensor([1, 2, 3])
iex> t[..]
#Nx.Tensor<
  s64[3]
  [1, 2, 3]
>
```

# first..last

Builds a range.

Ranges are inclusive and both sides must be integers.

The step of the range is computed based on the first and last values of the range.

## Examples

```elixir
iex> t = Nx.tensor([1, 2, 3])
iex> t[1..2]
#Nx.Tensor<
  s64[2]
  [2, 3]
>
```

# first..last//step

Builds a range with step.

Ranges are inclusive and both sides must be integers.

## Examples

```elixir
iex> t = Nx.tensor([1, 2, 3])
iex> t[1..2//1]
#Nx.Tensor<
  s64[2]
  [2, 3]
>
```

# left / right

Element-wise division operator.

It delegates to Nx.divide/2 (supports broadcasting).

## Examples

```elixir
defn divide(a, b) do
  a / b
end
```

# left != right (macro)

Element-wise inequality operation.

It delegates to Nx.not_equal/2.

## Examples

```elixir
defn check_inequality(a, b) do
  a != b
end
```

# left < right (macro)

Element-wise less than operation.

It delegates to Nx.less/2.

## Examples

```elixir
defn check_less_than(a, b) do
  a < b
end
```

# left <<< right

Element-wise left shift operation.

Only integer tensors are supported. It delegates to Nx.left_shift/2 (supports broadcasting).

## Examples

```elixir
defn shift_left_and_right(a, b) do
  {a <<< b, a >>> b}
end
```
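Calling the shift example on concrete values makes the direction of each shift visible. A minimal sketch, assuming a standalone script with Nx ~> 0.5.2 installed via Mix.install/1; ShiftDemo is a hypothetical module name:

```elixir
# Assumes network access so Mix.install/1 can fetch Nx from Hex.
Mix.install([{:nx, "~> 0.5.2"}])

defmodule ShiftDemo do
  import Nx.Defn

  defn shift_left_and_right(a, b) do
    {a <<< b, a >>> b}
  end
end

# 1 <<< 3 = 8 and 16 <<< 2 = 64; 1 >>> 3 = 0 and 16 >>> 2 = 4
{left, right} = ShiftDemo.shift_left_and_right(Nx.tensor([1, 16]), Nx.tensor([3, 2]))
IO.inspect(Nx.to_flat_list(left), label: "left")    # [8, 64]
IO.inspect(Nx.to_flat_list(right), label: "right")  # [0, 4]
```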

# left <= right (macro)

Element-wise less-equal operation.

It delegates to Nx.less_equal/2.

## Examples

```elixir
defn check_less_equal(a, b) do
  a <= b
end
```

# left <> right (macro)

Concatenates two strings.

Equivalent to Kernel.<>/2.

# left == right (macro)

Element-wise equality operation.

It delegates to Nx.equal/2.

## Examples

```elixir
defn check_equality(a, b) do
  a == b
end
```

# left > right (macro)

Element-wise greater than operation.

It delegates to Nx.greater/2.

## Examples

```elixir
defn check_greater_than(a, b) do
  a > b
end
```

# left >= right (macro)

Element-wise greater-equal operation.

It delegates to Nx.greater_equal/2.

## Examples

```elixir
defn check_greater_equal(a, b) do
  a >= b
end
```
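The comparison operators return masks of 0s and 1s (type u8 in Nx) rather than Elixir booleans. A minimal sketch, assuming a standalone script with Nx ~> 0.5.2 installed via Mix.install/1; CompareDemo is a hypothetical module name:

```elixir
# Assumes network access so Mix.install/1 can fetch Nx from Hex.
Mix.install([{:nx, "~> 0.5.2"}])

defmodule CompareDemo do
  import Nx.Defn

  # Three element-wise comparisons of the same pair of tensors.
  defn compare(a, b) do
    {a != b, a < b, a >= b}
  end
end

{ne, lt, ge} = CompareDemo.compare(Nx.tensor([1, 5, 3]), Nx.tensor([3, 3, 3]))
IO.inspect(Nx.type(ne))                      # {:u, 8}
IO.inspect(Nx.to_flat_list(ne), label: "!=") # [1, 1, 0]
IO.inspect(Nx.to_flat_list(lt), label: "<")  # [1, 0, 0]
IO.inspect(Nx.to_flat_list(ge), label: ">=") # [0, 1, 1]
```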

# left >>> right

Element-wise right shift operation.

Only integer tensors are supported. It delegates to Nx.right_shift/2 (supports broadcasting).

## Examples

```elixir
defn shift_left_and_right(a, b) do
  {a <<< b, a >>> b}
end
```

# @expr (macro)

Reads a module attribute at compilation time.

It is useful to inject code constants into defn. It delegates to Kernel.@/1.

## Examples

```elixir
@two_per_two Nx.tensor([[1, 2], [3, 4]])
defn add_2x2_attribute(t), do: t + @two_per_two
```

# alias(module, opts \\ []) (macro)

Defines an alias, as in Kernel.SpecialForms.alias/2.

An alias allows you to refer to a module using its aliased name. For example:

```elixir
defn some_fun(t) do
  alias Math.Helpers, as: MH
  MH.fft(t)
end
```

If the :as option is not given, the alias defaults to the last part of the given alias. For example:

```elixir
alias Math.Helpers
```

is equivalent to:

```elixir
alias Math.Helpers, as: Helpers
```

Finally, note that aliases defined outside of a function also apply to the function, as they have lexical scope:

```elixir
alias Math.Helpers, as: MH

defn some_fun(t) do
  MH.fft(t)
end
```

# left and right (macro)

Element-wise logical AND operation.

Zero is considered false, all other numbers are considered true.

It delegates to Nx.logical_and/2 (supports broadcasting).

## Examples

```elixir
defn and_or(a, b) do
  {a and b, a or b}
end
```
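Since zero is false and any other number is true, the logical operators work on arbitrary integers and return 0/1 masks. A minimal sketch, assuming a standalone script with Nx ~> 0.5.2 installed via Mix.install/1; LogicalDemo is a hypothetical module name:

```elixir
# Assumes network access so Mix.install/1 can fetch Nx from Hex.
Mix.install([{:nx, "~> 0.5.2"}])

defmodule LogicalDemo do
  import Nx.Defn

  defn and_or(a, b) do
    {a and b, a or b}
  end
end

# Non-zero values such as 2, 3, 4 and 5 all count as true.
{ands, ors} = LogicalDemo.and_or(Nx.tensor([0, 0, 2, 3]), Nx.tensor([0, 4, 0, 5]))
IO.inspect(Nx.to_flat_list(ands), label: "and")  # [0, 0, 0, 1]
IO.inspect(Nx.to_flat_list(ors), label: "or")    # [0, 1, 1, 1]
```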

# assert_keys(keyword, keys)

Asserts the keyword list has the given keys.

If it succeeds, it returns the given keyword list. Raises an error otherwise.

## Examples

When all the keys are present, the keyword list is returned as is:

```elixir
iex> assert_keys([one: 1, two: 2], [:one, :two])
[one: 1, two: 2]
```

If the keys are not available, an error is raised:

```elixir
iex> assert_keys([one: 1, two: 2], [:three])
** (ArgumentError) expected key :three in keyword list, got: [one: 1, two: 2]
```

# attach_token(token, expr)

Attaches a token to an expression. See hook/3.

# case(expr, list) (macro)

Pattern matches the result of expr against the given clauses.

For example:

```elixir
case Nx.shape(tensor) do
  {_} -> implementation_for_rank_one(tensor)
  {_, _} -> implementation_for_rank_two(tensor)
  _ -> implementation_for_rank_n(tensor)
end
```

Unlike cond/1 and if/2, which can execute the branching on the device, case is always expanded when building the expression and never runs on the device. This allows case/2 to work very similarly to Elixir's own Kernel.SpecialForms.case/2, with only the following restrictions in place:

• case inside defn only accepts structs, atoms, integers, and tuples as arguments
• case can match on struct names but not on their fields
• guards in case inside defn can only access variables defined within the pattern

Here is an example of case with guards:

```elixir
case Nx.shape(tensor) do
  {x, y} when x > y -> implementation_for_tall(tensor)
  {x, y} when x < y -> implementation_for_wide(tensor)
  {x, x} -> implementation_for_square(tensor)
end
```
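Because case is resolved while the expression is built, one defn can pick a different implementation per input rank. A minimal sketch, assuming a standalone script with Nx ~> 0.5.2 installed via Mix.install/1; CaseDemo and sum_last_axis are hypothetical names:

```elixir
# Assumes network access so Mix.install/1 can fetch Nx from Hex.
Mix.install([{:nx, "~> 0.5.2"}])

defmodule CaseDemo do
  import Nx.Defn

  # The clause is chosen from the shape alone, before anything runs on the device.
  defn sum_last_axis(tensor) do
    case Nx.shape(tensor) do
      {_} -> Nx.sum(tensor)
      {_, _} -> Nx.sum(tensor, axes: [1])
    end
  end
end

IO.inspect(Nx.to_number(CaseDemo.sum_last_axis(Nx.tensor([1, 2, 3]))))              # 6
IO.inspect(Nx.to_flat_list(CaseDemo.sum_last_axis(Nx.tensor([[1, 2], [3, 4]]))))    # [3, 7]
```

A rank-3 input matches no clause and fails while the expression is being built, not on the device.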

# cond(opts) (macro)

Evaluates the expression corresponding to the first clause that evaluates to a truthy value.

It has the format of:

```elixir
cond do
  condition1 ->
    expr1

  condition2 ->
    expr2

  true ->
    expr3
end
```

Each condition must be a scalar. Zero is considered false, any other number is considered true. The booleans false and true are supported, but any other value will raise.

All clauses are normalized to the same type and are broadcast to the same shape. The last condition must always evaluate to true. All clauses are executed on the device, unless they can be determined to always be true/false while building the numerical expression.

## Examples

```elixir
cond do
  Nx.all(Nx.greater(a, 0)) -> b * c
  Nx.all(Nx.less(a, 0)) -> b + c
  true -> b - c
end
```

When a defn is invoked, all cond clauses are traversed and expanded in order to build their expressions. This means that, if you attempt to raise in any clause, then it will always raise. You can only raise in limited situations inside defn, see raise/2 for more information.
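A runnable variant with scalar conditions, sketched under the assumption of a standalone script with Nx ~> 0.5.2 installed via Mix.install/1; CondDemo and classify are hypothetical names. Each clause returns a constant, and the last clause is the required catch-all true:

```elixir
# Assumes network access so Mix.install/1 can fetch Nx from Hex.
Mix.install([{:nx, "~> 0.5.2"}])

defmodule CondDemo do
  import Nx.Defn

  # Returns 1, -1 or 0 depending on the sign of the scalar a.
  defn classify(a) do
    cond do
      a > 0 -> 1
      a < 0 -> -1
      true -> 0
    end
  end
end

results =
  for x <- [5, -2, 0] do
    x |> Nx.tensor() |> CondDemo.classify() |> Nx.to_number()
  end

IO.inspect(results)  # [1, -1, 0]
```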

# create_token()

Creates a token for hooks. See hook/3.

# custom_grad(expr, inputs, fun)

Defines a custom gradient for the given expression.

It also expects a list of inputs of the gradient and a fun to compute the gradient. The function will be called with the current gradient. It must return a list with the updated gradient of each input, in the same order as inputs, to continue applying grad on.

## Examples

For example, if the gradient of cos(t) were to be implemented by hand:

```elixir
defn cos(t) do
  custom_grad(Nx.cos(t), [t], fn g ->
    [-g * Nx.sin(t)]
  end)
end
```

# div(left, right)

Element-wise quotient operator.

It delegates to Nx.quotient/2 (supports broadcasting).

## Examples

```elixir
defn quotient(a, b) do
  div(a, b)
end
```
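The difference between div/2 and / is easiest to see side by side: / always produces floats, while div/2 keeps integers. A minimal sketch, assuming a standalone script with Nx ~> 0.5.2 installed via Mix.install/1; DivDemo is a hypothetical module name:

```elixir
# Assumes network access so Mix.install/1 can fetch Nx from Hex.
Mix.install([{:nx, "~> 0.5.2"}])

defmodule DivDemo do
  import Nx.Defn

  # Float division next to integer (quotient) division.
  defn divide_both(a, b) do
    {a / b, div(a, b)}
  end
end

{floats, ints} = DivDemo.divide_both(Nx.tensor([7, 9]), Nx.tensor(2))
IO.inspect({Nx.type(floats), Nx.to_flat_list(floats)})  # {{:f, 32}, [3.5, 4.5]}
IO.inspect({Nx.type(ints), Nx.to_flat_list(ints)})      # {{:s, 64}, [3, 4]}
```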

# elem(tuple, index)

Gets the element at the zero-based index in tuple.

It raises ArgumentError when index is negative or it is out of range of the tuple elements.

## Examples

```elixir
iex> tuple = {1, 2, 3}
iex> elem(tuple, 0)
1
```
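Inside defn, elem/2 is typically applied to tuples of tensors, which defn accepts as arguments. A minimal sketch, assuming a standalone script with Nx ~> 0.5.2 installed via Mix.install/1; ElemDemo and swap are hypothetical names:

```elixir
# Assumes network access so Mix.install/1 can fetch Nx from Hex.
Mix.install([{:nx, "~> 0.5.2"}])

defmodule ElemDemo do
  import Nx.Defn

  # Swaps the two tensors in the tuple by reading each element by index.
  defn swap(pair) do
    {elem(pair, 1), elem(pair, 0)}
  end
end

{first, second} = ElemDemo.swap({Nx.tensor(1), Nx.tensor([2, 3])})
IO.inspect(Nx.to_flat_list(first))  # [2, 3]
IO.inspect(Nx.to_number(second))    # 1
```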

# hook(expr, name_or_function)

Shortcut for hook/3.

# hook(expr, name, function)

Defines a hook.

Hooks are a mechanism to execute an anonymous function for side-effects with runtime tensor values.

Let's see an example:

```elixir
defmodule Hooks do
  import Nx.Defn

  defn add_and_mult(a, b) do
    add = hook(a + b, fn tensor -> IO.inspect({:add, tensor}) end)
    mult = hook(a * b, fn tensor -> IO.inspect({:mult, tensor}) end)
    {add, mult}
  end
end
```

Note a hook can only access the variables passed as arguments to the hook. It cannot access any other variable defined in defn outside of the hook.

The defn above defines two hooks: one is called with the value of a + b and another with a * b. Once you invoke the function above, you should see this printed:

```elixir
Hooks.add_and_mult(2, 3)
{:add, #Nx.Tensor<
  s64
  5
>}
{:mult, #Nx.Tensor<
  s64
  6
>}
```

In other words, hook/3 accepts a tensor expression as argument and invokes a custom function with the tensor value at runtime. hook returns the result of the given expression. The expression can be any tensor or an Nx.Container.

Note you must return the result of the hook call. For example, the code below won't inspect the :add tuple, because the hook is not returned from defn:

```elixir
defn add_and_mult(a, b) do
  _add = hook(a + b, fn tensor -> IO.inspect({:add, tensor}) end)
  mult = hook(a * b, fn tensor -> IO.inspect({:mult, tensor}) end)
  mult
end
```

We will learn how to hook into a value that is not part of the result in the "Hooks and tokens" section.

## Named hooks

It is possible to give names to the hooks. This allows them to be defined or overridden by calling Nx.Defn.jit/2 or Nx.Defn.stream/2. Let's see an example:

```elixir
defmodule Hooks do
  import Nx.Defn

  defn add_and_mult(a, b) do
    add = hook(a + b, :hooks_add)
    mult = hook(a * b, :hooks_mult)
    {add, mult}
  end
end
```

Now you can pass the hook as argument as follows:

```elixir
hooks = %{
  hooks_add: fn tensor ->
    IO.inspect({:add, tensor})
  end
}

fun = Nx.Defn.jit(&Hooks.add_and_mult/2, hooks: hooks)
fun.(Nx.tensor(2), Nx.tensor(3))
```

Important! We recommend prefixing your hook names with the name of your project to avoid conflicts.

If no function is given for a named hook, compilers can optimize it away and not transfer the tensor from the device in the first place.

You can also mix named hooks with callbacks:

```elixir
defn add_and_mult(a, b) do
  add = hook(a + b, :hooks_add, fn tensor -> IO.inspect({:add, tensor}) end)
  mult = hook(a * b, :hooks_mult, fn tensor -> IO.inspect({:mult, tensor}) end)
  {add, mult}
end
```

If a hook with the same name is given to Nx.Defn.jit/2 or Nx.Defn.stream/2, then it will override the default callback.

## Hooks and tokens

So far, we have always returned the result of the hook call. However, what happens if the values we want to hook are not part of the return value, such as below?

```elixir
defn add_and_mult(a, b) do
  _add = hook(a + b, :hooks_add, &IO.inspect({:add, &1}))
  mult = hook(a * b, :hooks_mult, &IO.inspect({:mult, &1}))
  mult
end
```

In such cases, you must use tokens. Tokens are used to create an ordering over hooks, ensuring hooks execute in a certain sequence:

```elixir
defn add_and_mult(a, b) do
  token = create_token()
  {token, _add} = hook_token(token, a + b, :hooks_add, &IO.inspect({:add, &1}))
  {token, mult} = hook_token(token, a * b, :hooks_mult, &IO.inspect({:mult, &1}))
  attach_token(token, mult)
end
```

The example above creates a token and uses hook_token/4 to create hooks attached to their respective tokens. By using a token, we guarantee that those hooks will be invoked in the order in which they were defined. Then, at the end of the function, we attach the token (and its associated hooks) to the result mult.

In fact, the hook/3 function is implemented roughly like this:

```elixir
def hook(tensor_expr, name, function) do
  {token, result} = hook_token(create_token(), tensor_expr, name, function)
  attach_token(token, result)
end
```

Note you must attach the token at the end, otherwise the hooks will be "lost", as if they were not defined. This also applies to conditionals and loops: the token must be attached within the branch where it is used. For example, this won't work:

```elixir
token = create_token()

{token, result} =
  if Nx.any(value) do
    hook_token(token, some_value, &IO.inspect/1)
  else
    hook_token(token, another_value, &IO.inspect/1)
  end

attach_token(token, result)
```

Instead, attach the token inside each branch:

```elixir
token = create_token()

if Nx.any(value) do
  {token, result} = hook_token(token, some_value, &IO.inspect/1)
  attach_token(token, result)
else
  {token, result} = hook_token(token, another_value, &IO.inspect/1)
  attach_token(token, result)
end
```

# hook_token(token, expr, name_or_function)

Shortcut for hook_token/4.

# hook_token(token, expr, name, function)

Defines a hook with an existing token. See hook/3.

# if(pred, do_else) (macro)

Provides if/else expressions.

The first argument must be a scalar. Zero is considered false, any other number is considered true. The booleans false and true are supported, but any other value will raise.

The second argument is a keyword list with do and else blocks. The sides are broadcast to return the same shape and normalized to return the same type.

## Examples

```elixir
if Nx.any(Nx.equal(t, 0)) do
  0.0
else
  1 / t
end
```

In case else is not given, it is assumed to be 0 with the same shape as the do clause. If you want to nest multiple conditionals, see cond/1 instead.

When a defn is invoked, both do/else clauses are traversed and expanded in order to build their expressions. This means that, if you attempt to raise in any clause, then it will always raise. You can only raise in limited situations inside defn, see raise/2 for more information.
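The if example above can be wrapped in a defn and run; note how the scalar 0.0 in the do branch is broadcast to the shape of the else branch. A minimal sketch, assuming a standalone script with Nx ~> 0.5.2 installed via Mix.install/1; IfDemo and safe_reciprocal are hypothetical names:

```elixir
# Assumes network access so Mix.install/1 can fetch Nx from Hex.
Mix.install([{:nx, "~> 0.5.2"}])

defmodule IfDemo do
  import Nx.Defn

  # Returns 1 / t, or all zeros if any element of t is zero.
  defn safe_reciprocal(t) do
    if Nx.any(Nx.equal(t, 0)) do
      0.0
    else
      1 / t
    end
  end
end

IO.inspect(Nx.to_flat_list(IfDemo.safe_reciprocal(Nx.tensor([2.0, 4.0]))))  # [0.5, 0.25]
IO.inspect(Nx.to_flat_list(IfDemo.safe_reciprocal(Nx.tensor([0.0, 4.0]))))  # [0.0, 0.0]
```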

# import(module, opts \\ []) (macro)

Imports functions and macros into the current scope, as in Kernel.SpecialForms.import/2.

Imports are typically discouraged in favor of alias/2.

## Examples

```elixir
defn some_fun(t) do
  import Math.Helpers
  fft(t)
end
```

# inspect(expr, opts \\ [])

Converts the given expression into a string.

inspect/2 is used to convert expressions into strings, typically to be used as part of error messages. If you want to inspect for debugging, consider using print_expr/2 to print the underlying expression, or print_value/2 to print the value during execution.

```elixir
defn square_shape(tensor) do
  case Nx.shape(tensor) do
    {n, n} -> n
    shape -> raise ArgumentError, "expected a square tensor: #{inspect(shape)}"
  end
end
```

# keyword!(keyword, values)

Ensures the first argument is a keyword with the given keys and default values.

The second argument must be a list of atoms, specifying a given key, or tuples specifying a key and a default value. If any of the keys in the keyword is not defined on values, it raises an error.

This does not validate required keys. To check for required keys, use assert_keys/2 instead.

This is equivalent to Elixir's Keyword.validate!/2.

## Examples

```elixir
iex> keyword!([], [one: 1, two: 2]) |> Enum.sort()
[one: 1, two: 2]

iex> keyword!([two: 3], [one: 1, two: 2]) |> Enum.sort()
[one: 1, two: 3]
```

If atoms are given, they are supported as keys but do not provide a default value:

```elixir
iex> keyword!([], [:one, two: 2]) |> Enum.sort()
[two: 2]

iex> keyword!([one: 1], [:one, two: 2]) |> Enum.sort()
[one: 1, two: 2]
```

Passing an unknown key raises:

```elixir
iex> keyword!([three: 3], [one: 1, two: 2])
** (ArgumentError) unknown key :three in [three: 3], expected one of [:one, :two]
```
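A common place for keyword!/2 is the options argument of a defn, where it fills in defaults and rejects unknown keys. A minimal sketch, assuming a standalone script with Nx ~> 0.5.2 installed via Mix.install/1; KeywordDemo, scale and the :factor option are hypothetical names:

```elixir
# Assumes network access so Mix.install/1 can fetch Nx from Hex.
Mix.install([{:nx, "~> 0.5.2"}])

defmodule KeywordDemo do
  import Nx.Defn

  # Scales t by the hypothetical :factor option, which defaults to 2.
  defn scale(t, opts \\ []) do
    opts = keyword!(opts, factor: 2)
    t * opts[:factor]
  end
end

IO.inspect(Nx.to_flat_list(KeywordDemo.scale(Nx.tensor([1, 2]))))             # [2, 4]
IO.inspect(Nx.to_flat_list(KeywordDemo.scale(Nx.tensor([1, 2]), factor: 3)))  # [3, 6]
```

Passing an option other than :factor would raise an ArgumentError while the expression is being built.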

# max(left, right)

Element-wise maximum operation.

It delegates to Nx.max/2 (supports broadcasting).

## Examples

```elixir
defn min_max(a, b) do
  {min(a, b), max(a, b)}
end
```

# min(left, right)

Element-wise minimum operation.

It delegates to Nx.min/2 (supports broadcasting).

## Examples

```elixir
defn min_max(a, b) do
  {min(a, b), max(a, b)}
end
```

# not tensor (macro)

Element-wise logical NOT operation.

Zero is considered false, all other numbers are considered true.

It delegates to Nx.logical_not/1.

## Examples

```elixir
defn logical_not(a), do: not a
```

# left or right (macro)

Element-wise logical OR operation.

Zero is considered false, all other numbers are considered true.

It delegates to Nx.logical_or/2 (supports broadcasting).

## Examples

```elixir
defn and_or(a, b) do
  {a and b, a or b}
end
```

# raise(message) (macro)

Raises a runtime exception with the given message.

See raise/2 for more information on exceptions inside defn.

# raise(exception, arguments) (macro)

Raises an exception with the given arguments.

raise/2 is invoked while building the numerical expression, not inside the device. This means that raise may be invoked in unexpected situations, as we build the numerical expression. To better understand those cases, let's see some examples.

First, let's start with a valid use case for raise/2: raise on mismatched shapes. Inside defn, we know the tensor shapes and types, but not their values, so we can assert on the shape while building the numerical expression:

```elixir
defn square_shape(tensor) do
  case Nx.shape(tensor) do
    {n, n} -> n
    shape -> raise ArgumentError, "expected a square tensor: #{inspect(shape)}"
  end
end
```

In the example above, only the matching branch of the case is executed, so if you give it a 2x2 tensor, it will return 2. However, if you give it a non-square tensor, it will raise.

Now consider this code:

```elixir
defn some_check(a, b) do
  if a != b do
    a * b
  else
    raise "expected different tensors, got: #{inspect(a)} and #{inspect(b)}"
  end
end
```

In this case, both a and b are tensors and we are comparing their values. However, their values are unknown, which means we need to convert the whole if to a numerical expression and run it on the device. Therefore, converting the else branch executes raise/2, so the code above always raises!

In such cases, there is no workaround. Exceptions cannot be raised on the CPU/GPU, so you need to approach the problem from a different perspective.

# rem(left, right)

Element-wise remainder operation.

It delegates to Nx.remainder/2 (supports broadcasting).

## Examples

```elixir
defn divides_by_5?(a) do
  # 1 if every element of a is divisible by 5, 0 otherwise
  rem(a, 5)
  |> Nx.equal(0)
  |> Nx.all()
end
```
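To see the element-wise remainders themselves, along with a per-element divisibility mask, here is a minimal sketch assuming a standalone script with Nx ~> 0.5.2 installed via Mix.install/1; RemDemo and remainders_by_5 are hypothetical names:

```elixir
# Assumes network access so Mix.install/1 can fetch Nx from Hex.
Mix.install([{:nx, "~> 0.5.2"}])

defmodule RemDemo do
  import Nx.Defn

  # Returns the remainders and a 0/1 mask of the elements divisible by 5.
  defn remainders_by_5(a) do
    r = rem(a, 5)
    {r, r == 0}
  end
end

{r, mask} = RemDemo.remainders_by_5(Nx.tensor([10, 7, 15, 3]))
IO.inspect(Nx.to_flat_list(r), label: "rem")      # [0, 2, 0, 3]
IO.inspect(Nx.to_flat_list(mask), label: "mask")  # [1, 0, 1, 0]
```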

# require(module, opts \\ []) (macro)

Requires a module in order to use its macros, as in Kernel.SpecialForms.require/2.

## Examples

```elixir
defn some_fun(t) do
  require NumericalMacros

  NumericalMacros.some_macro t do
    ...
  end
end
```

# rewrite_types(expr, opts)

Rewrites the types of expr recursively according to opts.

## Options

• :max_unsigned_type - replaces all unsigned tensors with size equal to or greater than the given type by the given type

• :max_signed_type - replaces all signed tensors with size equal to or greater than the given type by the given type

• :max_float_type - replaces all float tensors with size equal to or greater than the given type by the given type

## Examples

```elixir
rewrite_types(expr, max_float_type: {:f, 32})
```

# stop_grad(expr)

Stops computing the gradient for the given expression.

It effectively annotates that the gradient for the given expression is 1.0.

## Examples

```elixir
expr = stop_grad(expr)
```

# tap(value, fun) (macro)

Pipes value to the given fun and returns the value itself.

Useful for running synchronous side effects in a pipeline.

## Examples

Let's suppose you want to inspect an expression in the middle of a pipeline. You could write:

```elixir
a
|> tap(&print_expr/1)
|> Nx.multiply(c)
```

# then(value, fun) (macro)

Pipes value into the given fun.

In other words, it invokes fun with value as argument. This is most commonly used in pipelines, allowing you to pipe a value to a function outside of its first argument.

## Examples

```elixir
a
|> then(&Nx.subtract(c, &1))
```

# while(initial, condition_or_generator, opts \\ [], do_block) (macro)

Defines a while loop.

It expects the initial arguments, a condition expression, and a block:

```elixir
while initial, condition do
  block
end
```

condition must return a scalar tensor where 0 is false and any other number is true. The given block will be executed while condition is true. Each invocation of block must return a value in the same shape as initial arguments.

while will return the value of the last execution of block. If block is never executed because the initial condition is false, it returns initial.

Note: whenever available, prefer the operations in the Nx module over writing your own loops.

## Examples

A simple loop that increments x until it is 10 can be written as:

```elixir
while x = 0, Nx.less(x, 10) do
  x + 1
end
```

However, it is important to note that all variables you intend to use inside the while must be explicitly given as arguments to while. For example, imagine the amount we want to increment by in the example above is given by a variable y. The following example is invalid:

```elixir
while x = 0, Nx.less(x, 10) do
  x + y
end
```

Instead, both x and y must be passed as variables to while:

```elixir
while {x = 0, y}, Nx.less(x, 10) do
  {x + y, y}
end
```

Similarly, to compute the factorial of x using while:

```elixir
defn factorial(x) do
  {factorial, _} =
    while {factorial = 1, x}, Nx.greater(x, 1) do
      {factorial * x, x - 1}
    end

  factorial
end
```
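Running the factorial example confirms both paths described above: the loop body repeats while x > 1, and when the condition is false from the start, while returns the initial values unchanged. A minimal sketch, assuming a standalone script with Nx ~> 0.5.2 installed via Mix.install/1; WhileDemo is a hypothetical module name:

```elixir
# Assumes network access so Mix.install/1 can fetch Nx from Hex.
Mix.install([{:nx, "~> 0.5.2"}])

defmodule WhileDemo do
  import Nx.Defn

  defn factorial(x) do
    # The loop state is {accumulated product, remaining counter}.
    {factorial, _} =
      while {factorial = 1, x}, Nx.greater(x, 1) do
        {factorial * x, x - 1}
      end

    factorial
  end
end

IO.inspect(Nx.to_number(WhileDemo.factorial(Nx.tensor(5))))  # 120
IO.inspect(Nx.to_number(WhileDemo.factorial(Nx.tensor(1))))  # 1 (body never runs)
```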

## Generators

Inspired by Elixir's for-comprehensions, while in defn supports generators. Generators may be tensors or ranges.

### Tensor generators

When the generator is a tensor, Nx will traverse its leading (first) dimension. For example, you could sum a one-dimensional tensor as follows:

```elixir
while acc = 0, i <- tensor do
  acc + i
end
```

Note: implementing sum using while, as above, is done only as an example. In practice, whenever available, prefer the operations in the Nx module over writing your own loops.

One advantage of using generators is that you can also unroll the loop for performance:

```elixir
while acc = 0, i <- tensor, unroll: true do
  acc + i
end
```

Or unroll it in batches:

```elixir
while acc = 0, i <- tensor, unroll: 4 do
  acc + i
end
```

Unrolling means that the while body is automatically duplicated a certain number of times, as if you wrote all iterations by hand. This makes the final expression larger, which causes a longer compilation time; however, it enables additional compile-time optimizations (such as fusion), improving the runtime efficiency.

### Range generators

A range can also be given as a generator. The range may be increasing or decreasing. Also remember that ranges in Elixir are inclusive on both begin and end. The sum example from the previous section could also be written with ranges:

```elixir
while {tensor, acc = 0}, i <- 0..(Nx.axis_size(tensor, 0) - 1) do
  {tensor, acc + tensor[i]}
end
```

# left |> right (macro)

Pipes the argument on the left to the function call on the right.

It delegates to Kernel.|>/2.

## Examples

```elixir
defn exp_sum(t) do
  t
  |> Nx.exp()
  |> Nx.sum()
end
```

# left ||| right

Element-wise bitwise OR operation.

Only integer tensors are supported. It delegates to Nx.bitwise_or/2 (supports broadcasting).

## Examples

```elixir
defn and_or(a, b) do
  {a &&& b, a ||| b}
end
```

# ~~~tensor

Element-wise bitwise NOT operation.

Only integer tensors are supported. It delegates to Nx.bitwise_not/1.

## Examples

```elixir
defn bnot(a), do: ~~~a
```
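The result of a bitwise NOT depends on the integer type: signed tensors follow two's complement, while unsigned tensors flip bits within their width. A minimal sketch, assuming a standalone script with Nx ~> 0.5.2 installed via Mix.install/1; BnotDemo is a hypothetical module name:

```elixir
# Assumes network access so Mix.install/1 can fetch Nx from Hex.
Mix.install([{:nx, "~> 0.5.2"}])

defmodule BnotDemo do
  import Nx.Defn

  defn bnot(a), do: ~~~a
end

# Signed 64-bit: ~~~x == -x - 1
IO.inspect(Nx.to_flat_list(BnotDemo.bnot(Nx.tensor([5, 0]))))                 # [-6, -1]
# Unsigned 8-bit: ~~~x == 255 - x
IO.inspect(Nx.to_flat_list(BnotDemo.bnot(Nx.tensor([5], type: {:u, 8}))))     # [250]
```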