Nx.Container protocol (Nx v0.4.1)
A protocol that teaches Nx how to traverse data structures non-recursively.
Nx and defn expect the arguments to be numbers, tensors,
or one of the following composite data types:
- tuples of numbers/tensors
- maps of any key with numbers/tensors as values
- any struct that implements Nx.Container
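As a minimal sketch, assuming Nx ~> 0.4.1 is fetched with Mix.install and the default evaluator compiler is used, the built-in tuple and map containers can be passed straight to defn. The module ContainerArgs and its functions are hypothetical names chosen for illustration.

```elixir
# Assumes Nx ~> 0.4.1 can be fetched via Mix.install.
Mix.install([{:nx, "~> 0.4.1"}])

defmodule ContainerArgs do
  import Nx.Defn

  # A tuple of tensors is a valid defn argument and can be destructured.
  defn sum_pair({a, b}), do: a + b

  # A map whose values are tensors or numbers is also a valid argument.
  defn scale(%{x: x, factor: factor}), do: x * factor
end

pair = ContainerArgs.sum_pair({Nx.tensor([1, 2]), Nx.tensor([10, 20])})
scaled = ContainerArgs.scale(%{x: Nx.tensor([1.0, 2.0]), factor: 3})
```

Numbers inside a container, such as factor: 3 above, are converted to tensors just like top-level number arguments.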
If you need to pass additional values, you can implement or derive this protocol. For example:
@derive {Nx.Container,
         containers: [:field_name, :other_field]}
defstruct [:field_name, :other_field, ...]
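A runnable sketch of deriving the protocol, assuming Nx ~> 0.4.1 via Mix.install. The struct Point and its fields are hypothetical. Protocol consolidation is turned off in Mix.install so that the implementation derived inside the script is actually dispatched to; in a regular Mix project this is not needed.

```elixir
# Assumes Nx ~> 0.4.1. consolidate_protocols: false lets the derived
# implementation below take effect inside a script.
Mix.install([{:nx, "~> 0.4.1"}], consolidate_protocols: false)

defmodule Point do
  # Hypothetical struct: :x and :y hold tensors, :label is plain metadata.
  @derive {Nx.Container, containers: [:x, :y]}
  defstruct [:x, :y, :label]
end

point = %Point{x: Nx.tensor(1.0), y: Nx.tensor(2.0), label: "example"}

# Only the fields listed under :containers are handed to the callback,
# so :label is never visited.
visited =
  Nx.Container.reduce(point, [], fn tensor, acc ->
    [Nx.to_number(tensor) | acc]
  end)
```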
The :containers option is required and must specify a list of
fields that contain tensors. Inside defn, the container fields
will be automatically converted to tensor expressions. All other
fields will be reset to their default value, unless you explicitly
declare them to be kept:
@derive {Nx.Container,
         containers: [:field_name, :other_field],
         keep: [:another_field]}
defstruct [:field_name, :other_field, ...]
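The sketch below illustrates :keep under the assumption of Nx ~> 0.4.1 via Mix.install; Scaled, ScaledMath, :factor and :note are hypothetical names. The kept :factor reaches defn as a plain integer and is used as a compile-time constant, which is exactly why each distinct value becomes part of the compiler cache key.

```elixir
# Assumes Nx ~> 0.4.1. consolidate_protocols: false lets the derived
# implementation below take effect inside a script.
Mix.install([{:nx, "~> 0.4.1"}], consolidate_protocols: false)

defmodule Scaled do
  # Hypothetical struct: :value is a tensor, :factor is kept as-is,
  # :note is neither a container nor kept.
  @derive {Nx.Container, containers: [:value], keep: [:factor]}
  defstruct [:value, :factor, :note]
end

defmodule ScaledMath do
  import Nx.Defn

  # Inside defn, :value arrives as a tensor expression, :factor keeps its
  # original integer value, and :note is reset to its default (nil).
  defn apply_factor(%Scaled{value: value, factor: factor}) do
    value * factor
  end
end

result =
  ScaledMath.apply_factor(%Scaled{value: Nx.tensor([1, 2]), factor: 2, note: "dropped"})
```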
Note that the functions in this module are not recursive.
If you want to deeply traverse and reduce containers,
use the functions in Nx.Defn.Composite instead.
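To make the non-recursive behaviour concrete, here is a small sketch assuming Nx ~> 0.4.1 via Mix.install. The DeepCount module is a hypothetical helper written only for illustration; it is not part of Nx, and Nx.Defn.Composite remains the library's place for recursive traversal.

```elixir
# Assumes Nx ~> 0.4.1 can be fetched via Mix.install.
Mix.install([{:nx, "~> 0.4.1"}])

defmodule DeepCount do
  # Hypothetical helper, not part of Nx: counts tensor leaves by calling
  # Nx.Container.reduce/3 again on every nested container it meets.
  def count(%Nx.Tensor{}, acc), do: acc + 1
  def count(container, acc), do: Nx.Container.reduce(container, acc, &count/2)
end

nested = {Nx.tensor(1), {Nx.tensor(2), Nx.tensor(3)}}

# Non-recursive: the callback receives the first tensor and the inner
# tuple as a whole, so only two items are counted.
shallow = Nx.Container.reduce(nested, 0, fn _item, acc -> acc + 1 end)

# Recursive: all three tensor leaves are counted.
deep = DeepCount.count(nested, 0)
```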
Careful! If you keep a field, its value becomes part of the
Nx.Defn compiler cache key. Therefore, if you pass a struct with two
different values for a kept field, Nx.Defn has to compile and cache
the function twice. Only keep fields that you are certain will be
used inside defn at compilation time.
Summary
Functions
reduce/3
Non-recursively reduces the tensors in a data structure with acc and fun.
traverse/3
Non-recursively traverses the tensors in a data structure with acc and fun.
Types
@type t() :: term()
Functions
reduce/3
Non-recursively reduces the tensors in a data structure with acc and fun.
fun is invoked with each tensor or tensor container in the
data structure plus an accumulator. It must return the new
accumulator.
This function returns the final accumulator.
Since fun may receive containers, this function does not recurse
into them on its own. See Nx.Defn.Composite.reduce/3 for a
recursive variant.
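For example, counting the total number of elements across a map of tensors, assuming Nx ~> 0.4.1 via Mix.install; the params map and its keys are hypothetical. Because the callback only adds to the accumulator, the result does not depend on the order in which map values are visited.

```elixir
# Assumes Nx ~> 0.4.1 can be fetched via Mix.install.
Mix.install([{:nx, "~> 0.4.1"}])

# Hypothetical parameter map: a {2, 3} tensor and a {3} tensor.
params = %{w: Nx.iota({2, 3}), b: Nx.iota({3})}

# reduce/3 returns only the final accumulator: here, 6 + 3 elements.
total_size = Nx.Container.reduce(params, 0, fn tensor, acc -> acc + Nx.size(tensor) end)
```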
traverse/3
Non-recursively traverses the tensors in a data structure with acc and fun.
fun is invoked with each tensor or tensor container in the
data structure plus an accumulator. It must return a two-element
tuple with the updated value and the new accumulator.
This function returns the updated container and the final accumulator.
Since fun may receive containers, this function does not recurse
into them on its own. See Nx.Defn.Composite.traverse/3 for a
recursive variant.
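A short sketch of traverse/3, assuming Nx ~> 0.4.1 via Mix.install: every tensor in a hypothetical parameter map is cast to f32 while the accumulator counts how many tensors were visited. The returned map keeps the original keys.

```elixir
# Assumes Nx ~> 0.4.1 can be fetched via Mix.install.
Mix.install([{:nx, "~> 0.4.1"}])

# Hypothetical parameter map holding integer tensors.
params = %{weights: Nx.tensor([1, 2]), bias: Nx.tensor(3)}

# The callback returns {updated_value, updated_acc}; traverse/3 returns
# the rebuilt map together with the final accumulator.
{converted, count} =
  Nx.Container.traverse(params, 0, fn tensor, acc ->
    {Nx.as_type(tensor, {:f, 32}), acc + 1}
  end)
```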