Nx.Container protocol (Nx v0.4.0)

A protocol that teaches Nx how to traverse data structures.

Nx and defn expect the arguments to be numbers, tensors, or one of the following composite data types:

  1. tuples of numbers/tensors
  2. maps of any key with numbers/tensors as values
  3. any struct that implements Nx.Container
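As a minimal sketch, assuming Nx ~> 0.4.0 is fetched with Mix.install and the default Nx.Defn.Evaluator compiler is used, the following passes a tuple and a map through Nx.Defn.jit/2. Inside the compiled function each tensor leaf arrives as a tensor expression, so ordinary Elixir pattern matching can destructure the composite. The function bodies and variable names are illustrative only.

```elixir
Mix.install([{:nx, "~> 0.4.0"}])

# A tuple of tensors: the jitted function receives a tuple of tensor expressions.
add_pair = Nx.Defn.jit(fn {a, b} -> Nx.add(a, b) end)
sum = add_pair.({Nx.tensor(1), Nx.tensor(2)})

# A map with tensor values: the keys are preserved and the values become tensor expressions.
scale = Nx.Defn.jit(fn %{x: x, factor: factor} -> Nx.multiply(x, factor) end)
scaled = scale.(%{x: Nx.tensor([1, 2, 3]), factor: Nx.tensor(10)})

IO.inspect(Nx.to_number(sum))
IO.inspect(Nx.to_flat_list(scaled))
```

Outside defn the Kernel arithmetic operators do not work on tensors, which is why the anonymous functions call Nx.add/2 and Nx.multiply/2 explicitly.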

If you need to pass additional values, you can implement or derive this protocol. For example:

@derive {Nx.Container,
         containers: [:field_name, :other_field]}
defstruct [:field_name, :other_field]
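A minimal sketch of a derived container, assuming Nx ~> 0.4.0 run as a standalone script. The Point struct and the norm function are hypothetical. Protocol consolidation is disabled in Mix.install because an implementation derived in a script after consolidation would otherwise be ignored; inside a regular Mix project the @derive is compiled before consolidation and this option is unnecessary.

```elixir
Mix.install([{:nx, "~> 0.4.0"}], consolidate_protocols: false)

defmodule Point do
  # Hypothetical struct: both fields hold tensors, so both are declared as containers.
  @derive {Nx.Container, containers: [:x, :y]}
  defstruct [:x, :y]
end

# Inside the jitted function, point.x and point.y are tensor expressions.
norm =
  Nx.Defn.jit(fn point ->
    Nx.sqrt(Nx.add(Nx.multiply(point.x, point.x), Nx.multiply(point.y, point.y)))
  end)

result = norm.(%Point{x: Nx.tensor(3.0), y: Nx.tensor(4.0)})
IO.inspect(Nx.to_number(result))
```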

The :containers option is required and must list the fields that contain tensors. Inside defn, the container fields are automatically converted to tensor expressions. All other fields are reset to their default values unless you explicitly declare them to be kept:

@derive {Nx.Container,
         containers: [:field_name, :other_field],
         keep: [:another_field]}
defstruct [:field_name, :other_field, :another_field]

Careful! If you keep a field, its value becomes part of the Nx.Defn compiler cache key. If you pass structs with two different values for a kept field, Nx.Defn has to compile and cache the function twice. Only keep fields that you are certain are used inside defn at compilation time.
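A minimal sketch of a kept field, assuming Nx ~> 0.4.0 run as a script (with protocol consolidation disabled so the derived implementation takes effect). The Summary struct and its :mode field are hypothetical. Because :mode is kept, it reaches the compiled function as a plain atom, so ordinary Elixir case can branch on it while the function is traced; each distinct :mode value is therefore part of the compiler cache key.

```elixir
Mix.install([{:nx, "~> 0.4.0"}], consolidate_protocols: false)

defmodule Summary do
  # Hypothetical struct: :values holds a tensor, while :mode is a plain atom
  # that must survive into the compiled function, so it is kept.
  @derive {Nx.Container, containers: [:values], keep: [:mode]}
  defstruct values: nil, mode: :sum
end

# :mode is still an atom here (not a tensor expression), so `case` is evaluated
# at trace time and only the selected branch ends up in the compiled function.
summarize =
  Nx.Defn.jit(fn %Summary{values: values, mode: mode} ->
    case mode do
      :sum -> Nx.sum(values)
      :mean -> Nx.mean(values)
    end
  end)

values = Nx.tensor([1.0, 2.0, 3.0])
total = summarize.(%Summary{values: values, mode: :sum})
average = summarize.(%Summary{values: values, mode: :mean})

IO.inspect(Nx.to_number(total))
IO.inspect(Nx.to_number(average))
```

A tensor-valued flag would not work here: a container field becomes a tensor expression, and its runtime value is unknown while the function is being traced.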

Summary

Functions

reduce(data, acc, fun)

Non-recursively reduces the tensors in a data structure with acc and fun.

traverse(data, acc, fun)

Non-recursively traverses the tensors in a data structure with acc and fun.

Types

Functions

reduce(data, acc, fun)

@spec reduce(t(), acc, (Nx.t() | t(), acc -> acc)) :: acc when acc: term()

Non-recursively reduces the tensors in a data structure with acc and fun.

fun is invoked with each tensor or tensor container in the data structure plus an accumulator. It must return the new accumulator.

This function returns the final accumulator.

Since fun may receive containers, this function is not recursive by default. See Nx.Defn.Composite.reduce/3 for a recursive variant.
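A minimal sketch of the non-recursive behaviour, assuming Nx ~> 0.4.0. The nested tuple is illustrative: Nx.Container.reduce/3 hands fun the first tensor and then the whole inner tuple, whereas the recursive Nx.Defn.Composite.reduce/3 is assumed here to visit each of the three leaf tensors.

```elixir
Mix.install([{:nx, "~> 0.4.0"}])

# A tuple whose second element is itself a container (a nested tuple).
data = {Nx.tensor(1), {Nx.tensor(2), Nx.tensor(3)}}

# Only the top level is visited: fun sees the first tensor and then the inner
# tuple as a single item, never the tensors inside that tuple.
kinds =
  data
  |> Nx.Container.reduce([], fn item, acc ->
    kind = if is_tuple(item), do: :container, else: :tensor
    [kind | acc]
  end)
  |> Enum.reverse()

# The recursive variant walks down into nested containers and visits every leaf.
leaf_count = Nx.Defn.Composite.reduce(data, 0, fn _leaf, acc -> acc + 1 end)

IO.inspect(kinds)
IO.inspect(leaf_count)
```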

traverse(data, acc, fun)

@spec traverse(t(), acc, (Nx.t() | t(), acc -> {Nx.t() | t(), acc})) :: {t(), acc}
      when acc: term()

Non-recursively traverses the tensors in a data structure with acc and fun.

fun is invoked with each tensor or tensor container in the data structure plus an accumulator. It must return a two element tuple with the updated value and accumulator.

This function returns the updated container and the accumulator.

Since fun may receive containers, this function is not recursive by default. See Nx.Defn.Composite.traverse/3 for a recursive variant.
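A minimal sketch of traverse/3 on a map, assuming Nx ~> 0.4.0. The params map is hypothetical. fun returns a two-element tuple: each value is doubled, and the accumulator counts how many values were visited. The result is the rebuilt map together with the final accumulator.

```elixir
Mix.install([{:nx, "~> 0.4.0"}])

params = %{bias: Nx.tensor(1), weight: Nx.tensor(2)}

# fun must return {updated_value, updated_acc}; the map keys are preserved.
{doubled, visited} =
  Nx.Container.traverse(params, 0, fn tensor, acc ->
    {Nx.multiply(tensor, 2), acc + 1}
  end)

IO.inspect(Nx.to_number(doubled.weight))
IO.inspect(visited)
```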