Nx.Container protocol (Nx v0.5.3)
A protocol that teaches defn how to traverse data structures.
When you invoke a defn, its arguments must implement Nx.LazyContainer and it must return a data structure that implements Nx.Container. Inside defn, you can work with any container data structure, such as:
- numbers/tensors
- tuples
- maps of any key
- any struct that implements Nx.Container
The easiest way to implement Nx.Container is by deriving it. For example:
    @derive {Nx.Container,
             containers: [:field_name, :other_field]}
    defstruct [:field_name, :other_field, ...]
The :containers option is required and must specify a list of fields that contain tensors (or other containers). Inside defn, the container fields will be automatically converted to tensor expressions. All other fields will be reset to their default value, unless you explicitly declare them to be kept:
    @derive {Nx.Container,
             containers: [:field_name, :other_field],
             keep: [:another_field]}
    defstruct [:field_name, :other_field, ...]
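To make the effect of :containers and :keep concrete, here is a minimal sketch. The Layer struct and its field names are hypothetical, and the script assumes Nx can be fetched from Hex with Mix.install/2 (protocol consolidation is turned off so the derived implementation is picked up inside a script). It only shows that the derived implementation visits the container fields and nothing else.

```elixir
# Assumption: Nx is available from Hex; consolidation is disabled so that
# implementations defined in this script are found at runtime.
Mix.install([{:nx, "~> 0.5"}], consolidate_protocols: false)

defmodule Layer do
  # Hypothetical struct: :weights and :bias hold tensors, :name is kept,
  # and :notes is neither a container nor a kept field.
  @derive {Nx.Container, containers: [:weights, :bias], keep: [:name]}
  defstruct [:weights, :bias, :name, :notes]
end

layer = %Layer{
  weights: Nx.tensor([[1.0, 2.0], [3.0, 4.0]]),
  bias: Nx.tensor([0.5, 0.5]),
  name: "dense",
  notes: "never visited"
}

# The reducer is called once per container field (:weights, then :bias).
element_count =
  Nx.Container.reduce(layer, 0, fn tensor, acc -> acc + Nx.size(tensor) end)

IO.inspect(element_count, label: "elements in container fields")
```

The :name and :notes fields never reach the reducer, because only fields listed under :containers are treated as tensor data.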
Note: Nx.LazyContainer is automatically implemented for all data structures that implement Nx.Container. This also means that you can convert any Nx.Container to a tensor by using Nx.stack/2 or Nx.concatenate/2.
Careful! If you keep a field, its value will be part of the Nx.Defn compiler cache key. Therefore, if you pass a struct with two different values for a kept field, Nx.Defn will have to compile and cache it twice. You must only keep fields that you are certain will be used inside defn at compilation time.
Serialization
If you @derive {Nx.Container, ...}, it will automatically define a serialization function with the container and keep fields you declare. If you expect a struct to be serialized, then you must be careful to evolve its schema over time in a compatible way. In particular, removing fields will lead to crashes. If you change the type of a field value, previously serialized structs may still hold the old type. And if you add new fields, previously serialized structs won't have such fields, so those fields will be deserialized with their default values.
Summary
Functions
reduce/3
Reduces non-recursively tensors in a data structure with acc and fun.
serialize/1
Defines how this container must be serialized to disk.
traverse/3
Traverses non-recursively tensors in a data structure with acc and fun.
Types
@type t() :: term()
All the types that implement this protocol.
Functions
reduce/3
Reduces non-recursively tensors in a data structure with acc and fun.
fun is invoked with each tensor or tensor container in the data structure plus an accumulator. It must return the new accumulator.
This function returns the final accumulator.
Since fun may receive containers, this function is not recursive by default. See Nx.Defn.Composite.reduce/3 for a recursive variant.
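The difference between the non-recursive reduce and its recursive counterpart can be sketched with a nested tuple. This is an illustrative example only; the variable names are made up, and it assumes Nx is installed through Mix.install/2.

```elixir
# Assumption: Nx is available from Hex.
Mix.install([{:nx, "~> 0.5"}])

# A tuple holding one tensor and one nested tuple of two tensors.
nested = {Nx.tensor([1, 2]), {Nx.tensor([3]), Nx.tensor([4, 5, 6])}}

# Nx.Container.reduce/3 only walks the top level, so the inner tuple
# is handed to the function as a single item.
top_level_items =
  Nx.Container.reduce(nested, 0, fn _item, acc -> acc + 1 end)

# Nx.Defn.Composite.reduce/3 keeps descending until it reaches tensors.
tensor_count =
  Nx.Defn.Composite.reduce(nested, 0, fn _tensor, acc -> acc + 1 end)

IO.inspect({top_level_items, tensor_count}, label: "top-level items, tensors")
```

Here the top-level count is 2 (a tensor and a tuple), while the recursive count is 3 (every tensor).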
serialize/1
Defines how this container must be serialized to disk.
It receives the container and it must return a three-element tuple of {module, list_of_container_tuples, metadata} where:
- the module to deserialize the container
- a list of tuples in the shape {key, container} with containers to be further serialized
- additional metadata for serialization/deserialization
On deserialization, module.deserialize(list_of_container_tuples, metadata) will be invoked.
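As a rough illustration of this contract, the sketch below implements the protocol by hand for a hypothetical Point struct and round-trips it through serialize and deserialize directly. The struct, its fields, and the choice to store :label as metadata are all assumptions made for the example; a real serializer would also process each container in the list further, which this sketch skips.

```elixir
# Assumption: Nx is available from Hex; consolidation is disabled so the
# hand-written implementation below is found at runtime.
Mix.install([{:nx, "~> 0.5"}], consolidate_protocols: false)

defmodule Point do
  # Hypothetical struct: :x and :y hold tensors, :label is plain metadata.
  defstruct [:x, :y, :label]

  # Called with the {key, container} pairs and the metadata from serialize/1.
  def deserialize(pairs, metadata) do
    struct!(__MODULE__, pairs ++ [label: metadata.label])
  end
end

defimpl Nx.Container, for: Point do
  def traverse(point, acc, fun) do
    {x, acc} = fun.(point.x, acc)
    {y, acc} = fun.(point.y, acc)
    {%{point | x: x, y: y}, acc}
  end

  def reduce(point, acc, fun) do
    acc = fun.(point.x, acc)
    fun.(point.y, acc)
  end

  # Returns {module, list_of_container_tuples, metadata}.
  def serialize(point) do
    {Point, [x: point.x, y: point.y], %{label: point.label}}
  end
end

point = %Point{x: Nx.tensor(1.0), y: Nx.tensor(2.0), label: "sample"}

{module, pairs, metadata} = Nx.Container.serialize(point)
restored = module.deserialize(pairs, metadata)

IO.inspect(restored == point, label: "round trip preserved the struct")
```

Keeping non-tensor data in the metadata element means deserialize/2 can rebuild the struct without treating that data as a container.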
traverse/3
Traverses non-recursively tensors in a data structure with acc and fun.
fun is invoked with each tensor or tensor container in the data structure plus an accumulator. It must return a two-element tuple with the updated value and accumulator.
This function returns the updated container and the accumulator.
Since fun may receive containers, this function is not recursive by default. See Nx.Defn.Composite.traverse/3 for a recursive variant.
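A small sketch of a traversal over a flat tuple may help here. The variable names are illustrative and the script assumes Nx is installed through Mix.install/2. The function doubles each tensor and counts how many it has seen.

```elixir
# Assumption: Nx is available from Hex.
Mix.install([{:nx, "~> 0.5"}])

pair = {Nx.tensor([1, 2]), Nx.tensor([3, 4])}

# The function returns {updated_value, new_accumulator} for each item.
{doubled, visited} =
  Nx.Container.traverse(pair, 0, fn tensor, acc ->
    {Nx.multiply(tensor, 2), acc + 1}
  end)

# The result keeps the shape of the input container (a two-element tuple).
{first, second} = doubled

IO.inspect({Nx.to_flat_list(first), Nx.to_flat_list(second), visited})
```

The returned container has the same structure as the input, with each tensor replaced by the value the function produced.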