Nx.Defn.Kernel (Nx v0.2.0)

All imported functionality available inside defn blocks.

Summary

Functions

!=/2: Element-wise inequality operation.

&&&/2: Element-wise bitwise AND operation.

**/2: Element-wise power operator.

*/2: Element-wise multiplication operator.

+/1: Element-wise unary plus operator.

+/2: Element-wise addition operator.

-/1: Element-wise unary minus operator.

-/2: Element-wise subtraction operator.

../2: Builds a range.

..///3: Builds a range with step.

//2: Element-wise division operator.

</2: Element-wise less than operation.

<<</2: Element-wise left shift operation.

<=/2: Element-wise less-equal operation.

==/2: Element-wise equality operation.

>/2: Element-wise greater than operation.

>=/2: Element-wise greater-equal operation.

>>>/2: Element-wise right shift operation.

@/1: Reads a module attribute at compilation time.

alias/2: Defines an alias, as in Kernel.SpecialForms.alias/2.

and/2: Element-wise logical AND operation.

assert_shape/2: Asserts the tensor has a certain shape.

assert_shape_pattern/2: Asserts the tensor has a certain shape pattern.

attach_token/2: Attaches a token to an expression. See hook/3.

cond/1: Evaluates the expression corresponding to the first clause that evaluates to a truthy value.

create_token/0: Creates a token for hooks. See hook/3.

custom_grad/2: Defines a custom gradient for the given expression.

elem/2: Gets the element at the zero-based index in tuple.

hook/2: Shortcut for hook/3.

hook/3: Defines a hook.

hook_token/3: Shortcut for hook_token/4.

hook_token/4: Defines a hook with an existing token. See hook/3.

if/2: Provides if/else expressions.

import/2: Imports functions and macros into the current scope, as in Kernel.SpecialForms.import/2.

inspect_expr/2: Inspects the given expression to the terminal.

inspect_value/2: Inspects the value at runtime to the terminal.

keyword!/2: Ensures the first argument is a keyword with the given keys and default values.

max/2: Element-wise maximum operation.

min/2: Element-wise minimum operation.

not/1: Element-wise logical NOT operation.

or/2: Element-wise logical OR operation.

rem/2: Element-wise remainder operation.

require/2: Requires a module in order to use its macros, as in Kernel.SpecialForms.require/2.

rewrite_types/2: Rewrites the types of expr recursively according to opts.

stop_grad/1: Stops computing the gradient for the given expression.

tap/2: Pipes value to the given fun and returns the value itself.

then/2: Pipes value into the given fun.

transform/2: Defines a transform that executes the given fun with arg when building defn expressions.

while/3: Defines a while loop.

|>/2: Pipes the argument on the left to the function call on the right.

|||/2: Element-wise bitwise OR operation.

~~~/1: Element-wise bitwise NOT operation.

Functions

!=/2

Element-wise inequality operation.

It delegates to Nx.not_equal/2.

Examples

defn check_inequality(a, b) do
  a != b
end

&&&/2

Element-wise bitwise AND operation.

Only integer tensors are supported. It delegates to Nx.bitwise_and/2 (supports broadcasting).

Examples

defn and_or(a, b) do
  {a &&& b, a ||| b}
end
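
To make the results concrete, the sketch below runs the and_or/2 example on small integer tensors. It assumes Nx 0.2.0 is fetched with Mix.install/1 in a standalone .exs script, and the module name BitwiseExample is hypothetical.

```elixir
# Standalone script: fetches Nx. Drop this line inside a Mix project that already depends on :nx.
Mix.install([{:nx, "~> 0.2.0"}])

# Hypothetical module used only for this illustration
defmodule BitwiseExample do
  import Nx.Defn

  # Inside defn, &&& and ||| delegate to Nx.bitwise_and/2 and Nx.bitwise_or/2
  defn and_or(a, b) do
    {a &&& b, a ||| b}
  end
end

{band, bor} = BitwiseExample.and_or(Nx.tensor([12, 5]), Nx.tensor([10, 3]))

# 12 = 0b1100 and 10 = 0b1010, so AND is 0b1000 (8) and OR is 0b1110 (14)
IO.inspect(Nx.to_flat_list(band)) # [8, 1]
IO.inspect(Nx.to_flat_list(bor))  # [14, 7]
```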

**/2

Element-wise power operator.

It delegates to Nx.power/2 (supports broadcasting).

Examples

defn power(a, b) do
  a ** b
end

*/2

Element-wise multiplication operator.

It delegates to Nx.multiply/2 (supports broadcasting).

Examples

defn multiply(a, b) do
  a * b
end

+/1

Element-wise unary plus operator.

Simply returns the given argument.

Examples

defn plus_and_minus(a) do
  {+a, -a}
end

+/2

Element-wise addition operator.

It delegates to Nx.add/2 (supports broadcasting).

Examples

defn add(a, b) do
  a + b
end

-/1

Element-wise unary minus operator.

It delegates to Nx.negate/1.

Examples

defn plus_and_minus(a) do
  {+a, -a}
end

-/2

Element-wise subtraction operator.

It delegates to Nx.subtract/2 (supports broadcasting).

Examples

defn subtract(a, b) do
  a - b
end

../2

Builds a range.

Ranges are inclusive and both sides must be integers.

The step of the range is computed based on the first and last values of the range.

Examples

iex> t = Nx.tensor([1, 2, 3])
iex> t[1..2]
#Nx.Tensor<
  s64[2]
  [2, 3]
>

..///3

Builds a range with step.

Ranges are inclusive and both sides must be integers.

Examples

iex> t = Nx.tensor([1, 2, 3])
iex> t[1..2//1]
#Nx.Tensor<
  s64[2]
  [2, 3]
>

//2

Element-wise division operator.

It delegates to Nx.divide/2 (supports broadcasting).

Examples

defn divide(a, b) do
  a / b
end

</2

Element-wise less than operation.

It delegates to Nx.less/2.

Examples

defn check_less_than(a, b) do
  a < b
end
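
Comparison operators do not return booleans: they return a mask tensor with one 1 (true) or 0 (false) per element, of type {:u, 8}. A minimal sketch, assuming Nx 0.2.0 installed via Mix.install/1 and a hypothetical module name CompareExample:

```elixir
# Standalone script: fetches Nx. Drop this line inside a Mix project that already depends on :nx.
Mix.install([{:nx, "~> 0.2.0"}])

# Hypothetical module used only for this illustration
defmodule CompareExample do
  import Nx.Defn

  # Inside defn, < delegates to Nx.less/2 and returns a 0/1 mask
  defn check_less_than(a, b) do
    a < b
  end
end

mask = CompareExample.check_less_than(Nx.tensor([1, 2, 3]), Nx.tensor([2, 2, 2]))

IO.inspect(Nx.type(mask))         # {:u, 8}
IO.inspect(Nx.to_flat_list(mask)) # [1, 0, 0]
```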

<<</2

Element-wise left shift operation.

Only integer tensors are supported. It delegates to Nx.left_shift/2 (supports broadcasting).

Examples

defn shift_left_and_right(a, b) do
  {a <<< b, a >>> b}
end
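
The sketch below evaluates the shift example on concrete integers, showing that each element of a is shifted by the corresponding element of b. It assumes Nx 0.2.0 installed via Mix.install/1; the module name ShiftExample is hypothetical.

```elixir
# Standalone script: fetches Nx. Drop this line inside a Mix project that already depends on :nx.
Mix.install([{:nx, "~> 0.2.0"}])

# Hypothetical module used only for this illustration
defmodule ShiftExample do
  import Nx.Defn

  # <<< and >>> delegate to Nx.left_shift/2 and Nx.right_shift/2
  defn shift_left_and_right(a, b) do
    {a <<< b, a >>> b}
  end
end

{left, right} = ShiftExample.shift_left_and_right(Nx.tensor([1, 16]), Nx.tensor([3, 2]))

# 1 <<< 3 = 8 and 16 <<< 2 = 64; 1 >>> 3 = 0 and 16 >>> 2 = 4
IO.inspect(Nx.to_flat_list(left))  # [8, 64]
IO.inspect(Nx.to_flat_list(right)) # [0, 4]
```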

<=/2

Element-wise less-equal operation.

It delegates to Nx.less_equal/2.

Examples

defn check_less_equal(a, b) do
  a <= b
end

==/2

Element-wise equality operation.

It delegates to Nx.equal/2.

Examples

defn check_equality(a, b) do
  a == b
end

>/2

Element-wise greater than operation.

It delegates to Nx.greater/2.

Examples

defn check_greater_than(a, b) do
  a > b
end

>=/2

Element-wise greater-equal operation.

It delegates to Nx.greater_equal/2.

Examples

defn check_greater_equal(a, b) do
  a >= b
end

>>>/2

Element-wise right shift operation.

Only integer tensors are supported. It delegates to Nx.right_shift/2 (supports broadcasting).

Examples

defn shift_left_and_right(a, b) do
  {a <<< b, a >>> b}
end

@/1

Reads a module attribute at compilation time.

It is useful for injecting constants into defn. It delegates to Kernel.@/1.

Examples

@two_per_two Nx.tensor([[1, 2], [3, 4]])
defn add_2x2_attribute(t), do: t + @two_per_two

alias(module, opts \\ []) (macro)

Defines an alias, as in Kernel.SpecialForms.alias/2.

An alias allows you to refer to a module using its aliased name. For example:

defn some_fun(t) do
  alias Math.Helpers, as: MH
  MH.fft(t)
end

If the :as option is not given, the alias defaults to the last part of the given alias. For example:

alias Math.Helpers

is equivalent to:

alias Math.Helpers, as: Helpers

Finally, note that aliases defined outside of a function also apply to the function, as they have lexical scope:

alias Math.Helpers, as: MH

defn some_fun(t) do
  MH.fft(t)
end

and/2

Element-wise logical AND operation.

Zero is considered false; all other numbers are considered true.

It delegates to Nx.logical_and/2 (supports broadcasting).

Examples

defn and_or(a, b) do
  {a and b, a or b}
end
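
The following sketch shows the "zero is false, anything else is true" rule on concrete values; note that 2 and 3 both count as true. It assumes Nx 0.2.0 installed via Mix.install/1, and the module name LogicalExample is hypothetical.

```elixir
# Standalone script: fetches Nx. Drop this line inside a Mix project that already depends on :nx.
Mix.install([{:nx, "~> 0.2.0"}])

# Hypothetical module used only for this illustration
defmodule LogicalExample do
  import Nx.Defn

  # Inside defn, `and` and `or` delegate to Nx.logical_and/2 and Nx.logical_or/2
  defn and_or(a, b) do
    {a and b, a or b}
  end
end

{both, either} = LogicalExample.and_or(Nx.tensor([0, 1, 2]), Nx.tensor([1, 0, 3]))

IO.inspect(Nx.to_flat_list(both))   # [0, 0, 1]
IO.inspect(Nx.to_flat_list(either)) # [1, 1, 1]
```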

assert_shape(tensor, shape)

Asserts the tensor has a certain shape.

If it succeeds, it returns the given tensor; otherwise, it raises an error.

Examples

To assert the tensor is a scalar, you can pass the empty tuple shape:

iex> assert_shape Nx.tensor(13), {}
#Nx.Tensor<
  s64
  13
>

If the shapes do not match, an error is raised:

iex> assert_shape Nx.tensor([1, 2, 3]), {}
** (ArgumentError) expected tensor to be a scalar, got tensor with shape {3}

iex> assert_shape Nx.tensor([1, 2, 3]), {4}
** (ArgumentError) expected tensor to have shape {4}, got tensor with shape {3}

If you want to assert on the rank or shape patterns, use assert_shape_pattern/2 instead.

assert_shape_pattern(tensor, shape) (macro)

Asserts the tensor has a certain shape pattern.

If it succeeds, it returns the given tensor; otherwise, it raises an error.

Examples

Unlike assert_shape/2, where the given shape is a value, assert_shape_pattern/2 allows the shape to be any Elixir pattern. We can use this to match on ranks:

iex> assert_shape_pattern Nx.tensor([[1, 2], [3, 4]]), {_, _}
#Nx.Tensor<
  s64[2][2]
  [
    [1, 2],
    [3, 4]
  ]
>

iex> assert_shape_pattern Nx.tensor([1, 2, 3]), {_, _}
** (ArgumentError) expected tensor to match shape {_, _}, got tensor with shape {3}

Or even use variables to assert on properties such as square matrices:

iex> assert_shape_pattern Nx.tensor([[1, 2], [3, 4]]), {x, x}
#Nx.Tensor<
  s64[2][2]
  [
    [1, 2],
    [3, 4]
  ]
>

iex> assert_shape_pattern Nx.tensor([1, 2, 3]), {x, x}
** (ArgumentError) expected tensor to match shape {x, x}, got tensor with shape {3}

You can also use guards to specify tall matrices and so forth:

iex> assert_shape_pattern Nx.tensor([[1], [2]]), {x, y} when x > y
#Nx.Tensor<
  s64[2][1]
  [
    [1],
    [2]
  ]
>

iex> assert_shape_pattern Nx.tensor([1, 2]), {x, y} when x > y
** (ArgumentError) expected tensor to match shape {x, y} when x > y, got tensor with shape {2}

attach_token(token, expr)

Attaches a token to an expression. See hook/3.

cond/1

Evaluates the expression corresponding to the first clause that evaluates to a truthy value.

It has the following format:

cond do
  condition1 ->
    expr1

  condition2 ->
    expr2

  :otherwise ->
    expr3
end

Each condition must be a scalar. Zero is considered false; any other number is considered true.

All clauses are normalized to the same type and are broadcast to the same shape. The last condition must always evaluate to an atom, typically :otherwise.

Examples

cond do
  Nx.all(Nx.greater(a, 0)) -> b * c
  Nx.all(Nx.less(a, 0)) -> b + c
  true -> b - c
end
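
Here is a minimal runnable sketch of the cond example above wrapped in a defn, so each clause can be exercised. It assumes Nx 0.2.0 installed via Mix.install/1; the module CondExample and function pick/3 are hypothetical names.

```elixir
# Standalone script: fetches Nx. Drop this line inside a Mix project that already depends on :nx.
Mix.install([{:nx, "~> 0.2.0"}])

# Hypothetical module used only for this illustration
defmodule CondExample do
  import Nx.Defn

  # The first clause whose scalar condition is non-zero wins;
  # the final clause uses the atom `true` as the catch-all
  defn pick(a, b, c) do
    cond do
      Nx.all(Nx.greater(a, 0)) -> b * c
      Nx.all(Nx.less(a, 0)) -> b + c
      true -> b - c
    end
  end
end

CondExample.pick(Nx.tensor([1, 2]), 2, 3) |> Nx.to_flat_list() |> IO.inspect()   # [6]
CondExample.pick(Nx.tensor([-1, -2]), 2, 3) |> Nx.to_flat_list() |> IO.inspect() # [5]
CondExample.pick(Nx.tensor([1, -1]), 2, 3) |> Nx.to_flat_list() |> IO.inspect()  # [-1]
```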

create_token/0

Creates a token for hooks. See hook/3.

custom_grad/2

Defines a custom gradient for the given expression.

It expects a fun to compute the gradient. The function is called with the expression itself and the current gradient, and it must return a list of {argument, gradient} tuples on which grad continues to be applied.

Examples

For example, if the gradient of cos(t) were to be implemented by hand:

defn cos(t) do
  custom_grad(Nx.cos(t), fn _ans, g ->
    [{t, -g * Nx.sin(t)}]
  end)
end
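
To check the hand-written gradient, the sketch below differentiates the custom cos/1 with grad/2 and compares the result with -sin(t). It assumes Nx 0.2.0 installed via Mix.install/1; CustomGradExample and grad_cos/1 are hypothetical names.

```elixir
# Standalone script: fetches Nx. Drop this line inside a Mix project that already depends on :nx.
Mix.install([{:nx, "~> 0.2.0"}])

# Hypothetical module used only for this illustration
defmodule CustomGradExample do
  import Nx.Defn

  # cos/1 with a hand-written gradient: d/dt cos(t) = -sin(t), scaled by the incoming gradient g
  defn cos(t) do
    custom_grad(Nx.cos(t), fn _ans, g ->
      [{t, -g * Nx.sin(t)}]
    end)
  end

  # Differentiates the custom cos/1 above with respect to t
  defn grad_cos(t) do
    grad(t, fn t -> cos(t) end)
  end
end

[g] = Nx.to_flat_list(CustomGradExample.grad_cos(1.0))

# Should be close to -sin(1.0), roughly -0.8415
IO.inspect(g)
```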

elem/2

Gets the element at the zero-based index in tuple.

It raises ArgumentError when index is negative or out of range for the tuple.

Examples

iex> tuple = {1, 2, 3}
iex> elem(tuple, 0)
1

hook(expr, name_or_function)

Shortcut for hook/3.

hook(expr, name, function)

Defines a hook.

Hooks are a mechanism to execute an anonymous function for side-effects with runtime tensor values.

Let's see an example:

defmodule Hooks do
  import Nx.Defn

  defn add_and_mult(a, b) do
    add = hook(a + b, fn tensor -> IO.inspect({:add, tensor}) end)
    mult = hook(a * b, fn tensor -> IO.inspect({:mult, tensor}) end)
    {add, mult}
  end
end

The defn above defines two hooks: one is called with the value of a + b and the other with the value of a * b. Once you invoke the function above, you should see this printed:

Hooks.add_and_mult(2, 3)
{:add, #Nx.Tensor<
   s64
   5
>}
{:mult, #Nx.Tensor<
   s64
   6
>}

In other words, hook accepts a tensor expression as argument and invokes a custom function with the tensor value at runtime. hook returns the result of the given expression, which can be any tensor or an Nx.Container.

Note you must return the result of the hook call. For example, the code below won't inspect the :add tuple, because the hook is not returned from defn:

defn add_and_mult(a, b) do
  _add = hook(a + b, fn tensor -> IO.inspect({:add, tensor}) end)
  mult = hook(a * b, fn tensor -> IO.inspect({:mult, tensor}) end)
  mult
end

We will learn how to hook into a value that is not part of the result in the "Hooks and tokens" section.

Named hooks

It is possible to give names to the hooks. This allows them to be defined or overridden by calling Nx.Defn.jit/3 or Nx.Defn.stream/3. Let's see an example:

defmodule Hooks do
  import Nx.Defn

  defn add_and_mult(a, b) do
    add = hook(a + b, :hooks_add)
    mult = hook(a * b, :hooks_mult)
    {add, mult}
  end
end

Now you can pass the hook as argument as follows:

hooks = %{
  hooks_add: fn tensor ->
    IO.inspect {:add, tensor}
  end
}

args = [Nx.tensor(2), Nx.tensor(3)]
Nx.Defn.jit(&Hooks.add_and_mult/2, args, hooks: hooks)

Important! We recommend prefixing your hook names with the name of your project to avoid conflicts.

If no function is given for a named hook, either in the defn itself or at invocation time, compilers can optimize the hook away and avoid transferring the tensor from the device in the first place.

You can also mix named hooks with callbacks:

defn add_and_mult(a, b) do
  add = hook(a + b, :hooks_add, fn tensor -> IO.inspect({:add, tensor}) end)
  mult = hook(a * b, :hooks_mult, fn tensor -> IO.inspect({:mult, tensor}) end)
  {add, mult}
end

If a hook with the same name is given to Nx.Defn.jit/3 or Nx.Defn.stream/3, then it will override the default callback.

Hooks and tokens

So far, we have always returned the result of the hook call. However, what happens if the values we want to hook are not part of the return value, such as below?

defn add_and_mult(a, b) do
  _add = hook(a + b, :hooks_add, &IO.inspect({:add, &1}))
  mult = hook(a * b, :hooks_mult, &IO.inspect({:mult, &1}))
  mult
end

In such cases, you must use tokens. Tokens are used to create an ordering over hooks, ensuring hooks execute in a certain sequence:

defn add_and_mult(a, b) do
  token = create_token()
  {token, _add} = hook_token(token, a + b, :hooks_add, &IO.inspect({:add, &1}))
  {token, mult} = hook_token(token, a * b, :hooks_mult, &IO.inspect({:mult, &1}))
  attach_token(token, mult)
end

The example above creates a token and uses hook_token/4 to create hooks attached to their respective tokens. By using a token, we guarantee that those hooks will be invoked in the order in which they were defined. Then, at the end of the function, we attach the token (and its associated hooks) to the result mult.

In fact, the hook/3 function is implemented roughly like this:

def hook(tensor_expr, name, function) do
  {token, result} = hook_token(create_token(), tensor_expr, name, function)
  attach_token(token, result)
end

Note you must attach the token at the end, otherwise the hooks will be "lost", as if they were not defined. This also applies to conditionals and loops: the token must be attached within the branch in which it is used. For example, this won't work:

token = create_token()

{token, result} =
  if Nx.any(value) do
    hook_token(token, some_value, &IO.inspect/1)
  else
    hook_token(token, another_value, &IO.inspect/1)
  end

attach_token(token, result)

Instead, you must write:

token = create_token()

if Nx.any(value) do
  {token, result} = hook_token(token, some_value, &IO.inspect/1)
  attach_token(token, result)
else
  {token, result} = hook_token(token, another_value, &IO.inspect/1)
  attach_token(token, result)
end

hook_token(token, expr, name_or_function)

Shortcut for hook_token/4.

hook_token(token, expr, name, function)

Defines a hook with an existing token. See hook/3.

if(pred, do_else) (macro)

Provides if/else expressions.

The first argument must be a scalar. Zero is considered false; any other number is considered true.

The second argument is a keyword list with do and else blocks. The sides are broadcast to return the same shape and normalized to return the same type.

Examples

if Nx.any(Nx.equal(t, 0)) do
  0.0
else
  1 / t
end

If else is not given, it is assumed to be 0 with the same shape as the do clause. If you want to nest multiple conditionals, see cond/1 instead.
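
The sketch below relies on that default: with no else branch, the result is 0 whenever the predicate is zero. It assumes Nx 0.2.0 installed via Mix.install/1; IfExample and positive_or_zero/1 are hypothetical names.

```elixir
# Standalone script: fetches Nx. Drop this line inside a Mix project that already depends on :nx.
Mix.install([{:nx, "~> 0.2.0"}])

# Hypothetical module used only for this illustration
defmodule IfExample do
  import Nx.Defn

  # No else branch: when x > 0 evaluates to 0 (false), the result is 0
  defn positive_or_zero(x) do
    if x > 0 do
      x
    end
  end
end

IfExample.positive_or_zero(5) |> Nx.to_flat_list() |> IO.inspect()  # [5]
IfExample.positive_or_zero(-3) |> Nx.to_flat_list() |> IO.inspect() # [0]
```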

import(module, opts \\ []) (macro)

Imports functions and macros into the current scope, as in Kernel.SpecialForms.import/2.

Imports are typically discouraged in favor of alias/2.

Examples

defn some_fun(t) do
  import Math.Helpers
  fft(t)
end

inspect_expr(expr, opts \\ [])

Inspects the given expression to the terminal.

It returns the given expression.

Examples

defn tanh_power(a, b) do
  inspect_expr(Nx.tanh(a) + Nx.power(b, 2))
end

When invoked, it will print the expression being built by defn:

#Nx.Tensor<
  Nx.Defn.Expr
  parameter a s64
  parameter c s64
  b = tanh [ a ] f64
  d = power [ c, 2 ] s64
  e = add [ b, d ] f64
>

inspect_value(expr, opts \\ [])

Inspects the value at runtime to the terminal.

This function is implemented on top of hook/3 and therefore has the following restrictions:

  • It can only inspect tensors and Nx.Container
  • The return value of this function must be part of the output

All options are passed to IO.inspect/2.

Examples

defn tanh_grad(t) do
  grad(t, fn t ->
    t
    |> Nx.tanh()
    |> inspect_value()
  end)
end

Options such as :label are forwarded to IO.inspect/2:

defn tanh_grad(t) do
  grad(t, fn t ->
    t
    |> Nx.tanh()
    |> inspect_value(label: "tanh")
  end)
end

keyword!(keyword, values)

Ensures the first argument is a keyword with the given keys and default values.

The second argument must be a list of atoms, specifying a given key, or tuples specifying a key and a default value. If any of the keys in the keyword is not defined on values, it raises an error.

Examples

iex> keyword!([], [one: 1, two: 2]) |> Enum.sort()
[one: 1, two: 2]

iex> keyword!([two: 3], [one: 1, two: 2]) |> Enum.sort()
[one: 1, two: 3]

If atoms are given, they are supported as keys but do not provide a default value:

iex> keyword!([], [:one, two: 2]) |> Enum.sort()
[two: 2]

iex> keyword!([one: 1], [:one, two: 2]) |> Enum.sort()
[one: 1, two: 2]

Passing an unknown key raises:

iex> keyword!([three: 3], [one: 1, two: 2])
** (ArgumentError) unknown key :three in [three: 3], expected one of [:one, :two]

max/2

Element-wise maximum operation.

It delegates to Nx.max/2 (supports broadcasting).

Examples

defn min_max(a, b) do
  {min(a, b), max(a, b)}
end

min/2

Element-wise minimum operation.

It delegates to Nx.min/2 (supports broadcasting).

Examples

defn min_max(a, b) do
  {min(a, b), max(a, b)}
end

not/1

Element-wise logical NOT operation.

Zero is considered false; all other numbers are considered true.

It delegates to Nx.logical_not/1.

Examples

defn logical_not(a), do: not a

or/2

Element-wise logical OR operation.

Zero is considered false; all other numbers are considered true.

It delegates to Nx.logical_or/2 (supports broadcasting).

Examples

defn and_or(a, b) do
  {a and b, a or b}
end

rem/2

Element-wise remainder operation.

It delegates to Nx.remainder/2 (supports broadcasting).

Examples

defn divides_by_5?(a) do
  rem(a, 5)
  |> Nx.equal(0)
  |> Nx.all()
end
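
As a runnable sketch, the divisibility check below returns 1 only when every element has a remainder of 0 modulo 5. It assumes Nx 0.2.0 installed via Mix.install/1; the module name RemExample is hypothetical.

```elixir
# Standalone script: fetches Nx. Drop this line inside a Mix project that already depends on :nx.
Mix.install([{:nx, "~> 0.2.0"}])

# Hypothetical module used only for this illustration
defmodule RemExample do
  import Nx.Defn

  # 1 if every element of `a` is divisible by 5, 0 otherwise
  defn divides_by_5?(a) do
    rem(a, 5)
    |> Nx.equal(0)
    |> Nx.all()
  end
end

RemExample.divides_by_5?(Nx.tensor([5, 10, 15])) |> Nx.to_flat_list() |> IO.inspect() # [1]
RemExample.divides_by_5?(Nx.tensor([5, 7])) |> Nx.to_flat_list() |> IO.inspect()     # [0]
```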

require(module, opts \\ []) (macro)

Requires a module in order to use its macros, as in Kernel.SpecialForms.require/2.

Examples

defn some_fun(t) do
  require NumericalMacros

  NumericalMacros.some_macro t do
    ...
  end
end

rewrite_types(expr, opts)

Rewrites the types of expr recursively according to opts.

Options

  • :max_unsigned_type - replaces all unsigned tensors with size equal to or greater than the given type by the given type

  • :max_signed_type - replaces all signed tensors with size equal to or greater than the given type by the given type

  • :max_float_type - replaces all float tensors with size equal to or greater than the given type by the given type

Examples

rewrite_types(expr, max_float_type: {:f, 32})

stop_grad/1

Stops computing the gradient for the given expression.

The expression is treated as a constant during differentiation, so no gradient flows back through it to its inputs.

Examples

expr = stop_grad(expr)
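
To see the effect, the sketch below compares the gradient of x * x (which is 2x) with the gradient of x * stop_grad(x), where the second factor is treated as a constant, so the gradient is just x. It assumes Nx 0.2.0 installed via Mix.install/1; StopGradExample and its functions are hypothetical names.

```elixir
# Standalone script: fetches Nx. Drop this line inside a Mix project that already depends on :nx.
Mix.install([{:nx, "~> 0.2.0"}])

# Hypothetical module used only for this illustration
defmodule StopGradExample do
  import Nx.Defn

  # d/dx (x * x) = 2x
  defn grad_square(x), do: grad(x, fn x -> x * x end)

  # The second factor is a constant for grad, so d/dx (x * stop_grad(x)) = x
  defn grad_square_stopped(x), do: grad(x, fn x -> x * stop_grad(x) end)
end

StopGradExample.grad_square(3.0) |> Nx.to_flat_list() |> IO.inspect()         # [6.0]
StopGradExample.grad_square_stopped(3.0) |> Nx.to_flat_list() |> IO.inspect() # [3.0]
```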

tap(value, fun) (macro)

Pipes value to the given fun and returns the value itself.

Useful for running synchronous side effects in a pipeline.

Examples

Let's suppose you want to inspect an expression in the middle of a pipeline. You could write:

a
|> Nx.add(b)
|> tap(&inspect_expr/1)
|> Nx.multiply(c)

then(value, fun) (macro)

Pipes value into the given fun.

In other words, it invokes fun with value as argument. This is most commonly used in pipelines, allowing you to pipe a value to a function outside of its first argument.

Examples

a
|> Nx.add(b)
|> then(&Nx.subtract(c, &1))

transform/2

Defines a transform that executes the given fun with arg when building defn expressions.

Example

Take the following defn expression:

defn tanh_power(a, b) do
  Nx.tanh(a) + Nx.power(b, 2)
end

Let's see a trivial example, which is to use IO.inspect/1 to print a tensor expression at definition time:

defn tanh_power(a, b) do
  Nx.tanh(a) + Nx.power(b, 2) |> transform(&IO.inspect/1)
end

Or:

defn tanh_power(a, b) do
  res = Nx.tanh(a) + Nx.power(b, 2)
  transform(res, &IO.inspect/1)
  res
end

In both cases, when invoked, it will print the expression being built by defn:

#Nx.Defn.Expr<
  parameter a
  parameter c
  b = tanh [ a ] ()
  d = power [ c, 2 ] ()
  e = add [ b, d ] ()
>

For convenience, you might use inspect_expr/2 instead.

Pitfalls

Because transform/2 is invoked inside defn, its scope is tied to defn. For example, if you do this:

transform(tensor, fn tensor ->
  if Nx.shape(tensor) != {2, 2} do
    raise "bad"
  end
end)

it won't work because it will use the != operator defined in this module, which only works with tensors, instead of the operator defined in Elixir's Kernel. Therefore, we recommend that all transform/2 calls simply dispatch to a separate function. The example above could be rewritten as:

transform(tensor, &assert_2x2_shape/1)

where:

defp assert_2x2_shape(tensor) do
  if Nx.shape(tensor) != {2, 2} do
    raise "bad"
  end
end

while(initial, condition, other) (macro)

Defines a while loop.

It expects the initial arguments, a condition expression, and a block:

while initial, condition do
  block
end

condition must return a scalar tensor where 0 is false and any other number is true. The given block will be executed while condition is true. Each invocation of block must return a value in the same shape as initial arguments.

while will return the value of the last execution of block. If block is never executed because the initial condition is false, it returns initial.

Examples

A simple loop that increments x until it is 10 can be written as:

while x = 0, Nx.less(x, 10) do
  x + 1
end

However, it is important to note that all variables you intend to use inside the "while" must be explicitly given as argument to "while". For example, imagine the amount we want to increment by in the example above is given by a variable y. The following example is invalid:

while x = 0, Nx.less(x, 10) do
  x + y
end

Instead, both x and y must be passed as variables to while:

while {x = 0, y}, Nx.less(x, 10) do
  {x + y, y}
end

Similarly, to compute the factorial of x using while:

defn factorial(x) do
  {factorial, _} =
    while {factorial = 1, x}, Nx.greater(x, 1) do
      {factorial * x, x - 1}
    end

  factorial
end
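
A runnable sketch of the factorial example: for x = 5 the loop multiplies 5 * 4 * 3 * 2 and stops once x reaches 1. It assumes Nx 0.2.0 installed via Mix.install/1; the module name WhileExample is hypothetical.

```elixir
# Standalone script: fetches Nx. Drop this line inside a Mix project that already depends on :nx.
Mix.install([{:nx, "~> 0.2.0"}])

# Hypothetical module used only for this illustration
defmodule WhileExample do
  import Nx.Defn

  defn factorial(x) do
    # Both the accumulator and x are loop state, so both are passed to while
    {factorial, _} =
      while {factorial = 1, x}, Nx.greater(x, 1) do
        {factorial * x, x - 1}
      end

    factorial
  end
end

WhileExample.factorial(5) |> Nx.to_flat_list() |> IO.inspect() # [120]
```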

|>/2

Pipes the argument on the left to the function call on the right.

It delegates to Kernel.|>/2.

Examples

defn exp_sum(t) do
  t
  |> Nx.exp()
  |> Nx.sum()
end

|||/2

Element-wise bitwise OR operation.

Only integer tensors are supported. It delegates to Nx.bitwise_or/2 (supports broadcasting).

Examples

defn and_or(a, b) do
  {a &&& b, a ||| b}
end

~~~/1

Element-wise bitwise NOT operation.

Only integer tensors are supported. It delegates to Nx.bitwise_not/1.

Examples

defn bnot(a), do: ~~~a