Nx.Container protocol (Nx v0.8.0)

A protocol that teaches defn how to traverse data structures.

When you invoke a defn, each of its arguments must implement the Nx.LazyContainer protocol, which turns it into a data structure that implements Nx.Container. Inside defn, you can work with any container data structure, such as:

  1. numbers/tensors
  2. tuples
  3. maps with any keys
  4. any struct that implements Nx.Container

In other words, Nx.LazyContainer is how you convert data structures that are not meant to work inside defn into an Nx.Container, and an Nx.Container is a data structure that can be manipulated inside defn itself.
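As a minimal sketch, assuming a standalone Elixir script that fetches Nx 0.8 through Mix.install, the hypothetical ContainerDemo module below defines a defn that receives a tuple and returns a map. Both are containers, so inside defn the tuple elements are tensor expressions and the map values come back as tensors.

```elixir
# Assumption: run as a standalone script; Mix.install fetches Nx.
Mix.install([{:nx, "~> 0.8.0"}])

defmodule ContainerDemo do
  # Hypothetical module used only for illustration.
  import Nx.Defn

  # The tuple argument is a container: `a` and `b` are tensor
  # expressions inside defn. The returned map is a container too.
  defn combine({a, b}) do
    %{sum: a + b, product: a * b}
  end
end

result = ContainerDemo.combine({Nx.tensor(3), Nx.tensor(4)})
IO.inspect(Nx.to_number(result.sum))     # 7
IO.inspect(Nx.to_number(result.product)) # 12
```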

The easiest way to implement Nx.Container is by deriving it. For example:

@derive {Nx.Container,
         containers: [:field_name, :other_field]}
defstruct [:field_name, :other_field]  # plus any other fields

The :containers option is required and must list the fields that contain tensors (or other containers). Inside defn, the container fields will be automatically converted to tensor expressions. All other fields will be reset to their default values, unless you explicitly declare them to be kept:

@derive {Nx.Container,
         containers: [:field_name, :other_field],
         keep: [:another_field]}
defstruct [:field_name, :other_field, :another_field]  # plus any other fields
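The derive options above can be exercised directly through Nx.Container.traverse/3. In this sketch the struct Scaled and its fields are hypothetical, and protocol consolidation is disabled in Mix.install because an implementation defined in a script after consolidation would not be dispatched to. Only the two container fields are passed to fun, while the kept field comes through unchanged.

```elixir
# Assumption: standalone script; consolidation is disabled so the
# derived implementation below is picked up at runtime.
Mix.install([{:nx, "~> 0.8.0"}], consolidate_protocols: false)

defmodule Scaled do
  # Hypothetical struct: :weight and :bias hold tensors, :name is kept.
  @derive {Nx.Container, containers: [:weight, :bias], keep: [:name]}
  defstruct [:weight, :bias, :name]
end

s = %Scaled{weight: Nx.tensor(2.0), bias: Nx.tensor(0.5), name: "layer"}

# fun receives each container field and returns {new_value, new_acc}.
{doubled, visited} =
  Nx.Container.traverse(s, 0, fn tensor, count ->
    {Nx.multiply(tensor, 2), count + 1}
  end)

IO.inspect(visited)                      # 2
IO.inspect(Nx.to_number(doubled.weight)) # 4.0
IO.inspect(doubled.name)                 # "layer"
```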

Note that Nx.LazyContainer is automatically provided for all data structures that implement Nx.Container.

Careful! If you keep a field, its value becomes part of the Nx.Defn compiler cache key. Therefore, if you pass structs with two different values for a kept field, Nx.Defn will compile and cache the function twice. Only keep fields that you are certain will be used inside defn at compilation time.

Serialization

If you @derive {Nx.Container, ...}, a serialization function is automatically defined using the container and keep fields you declare. If you expect a struct to be serialized, you must evolve its schema over time in a compatible way. In particular, removing fields will lead to crashes. If you change the type of a field's value, previously serialized structs may still hold the old type. And if you add new fields, previously serialized structs won't have them and will therefore be deserialized with their default values.
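A round trip through Nx.serialize/1 and Nx.deserialize/1 can show the derived serialization in action. This is a sketch under the assumption that the hypothetical Checkpoint struct is defined in the same script that reads the data back; the container field (itself a map of tensors) and the kept field should both be restored.

```elixir
# Assumption: standalone script; consolidation is disabled so the
# derived implementation below is picked up at runtime.
Mix.install([{:nx, "~> 0.8.0"}], consolidate_protocols: false)

defmodule Checkpoint do
  # Hypothetical struct: :params is a container (a map of tensors),
  # :step is kept and travels with the serialized data.
  @derive {Nx.Container, containers: [:params], keep: [:step]}
  defstruct [:params, :step]
end

ckpt = %Checkpoint{params: %{w: Nx.tensor([1.0, 2.0])}, step: 10}

# Nx.serialize/1 returns iodata; convert it to a binary as if writing to disk.
binary = ckpt |> Nx.serialize() |> IO.iodata_to_binary()

# Reading it back requires the Checkpoint module to be available.
restored = Nx.deserialize(binary)
```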

Summary

Types

t()

All the types that implement this protocol.

Functions

reduce/3

Non-recursively reduces the tensors in a data structure with acc and fun.

serialize/1

Defines how this container must be serialized to disk.

traverse/3

Non-recursively traverses the tensors in a data structure with acc and fun.

Types

@type t() :: term()

All the types that implement this protocol.

Functions

@spec reduce(t(), acc, (t(), acc -> acc)) :: acc when acc: term()

Non-recursively reduces the tensors in a data structure with acc and fun.

fun is invoked with each tensor or tensor container in the data structure plus an accumulator. It must return the new accumulator.

This function returns the final accumulator.

Since fun may receive containers, this function is not recursive. See Nx.Defn.Composite.reduce/3 for a recursive variant.
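To illustrate the difference between this function and Nx.Defn.Composite.reduce/3, the sketch below reduces a nested tuple both ways, collecting every value fun receives. The data is an arbitrary example and the script assumes Nx is fetched with Mix.install.

```elixir
# Assumption: standalone script; Mix.install fetches Nx.
Mix.install([{:nx, "~> 0.8.0"}])

data = {Nx.tensor(1), {Nx.tensor(2), Nx.tensor(3)}}

# Non-recursive: fun sees the first tensor and the inner tuple as a whole.
shallow = Nx.Container.reduce(data, [], fn item, acc -> [item | acc] end)

# Recursive: fun sees every tensor, including those in the inner tuple.
deep = Nx.Defn.Composite.reduce(data, [], fn item, acc -> [item | acc] end)

IO.inspect(length(shallow))                 # 2
IO.inspect(Enum.map(deep, &Nx.to_number/1)) # [3, 2, 1]
```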

@spec serialize(t()) :: {module(), [{term(), t()}], term()}

Defines how this container must be serialized to disk.

It receives the container and must return a three-element tuple {module, list_of_container_tuples, metadata}, where:

  • module is the module used to deserialize the container
  • list_of_container_tuples is a list of {key, container} tuples whose containers will be serialized further
  • metadata is any additional term needed for serialization and deserialization

On deserialization, module.deserialize(list_of_container_tuples, metadata) will be invoked.
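A hand-written implementation may help clarify the shape of this tuple. The sketch below assumes a hypothetical Pair struct and a hypothetical :pair_v1 metadata tag; it implements all three protocol functions and places deserialize/2 on Pair, the module returned from serialize/1.

```elixir
# Assumption: standalone script; consolidation is disabled so the
# implementation below is picked up at runtime.
Mix.install([{:nx, "~> 0.8.0"}], consolidate_protocols: false)

defmodule Pair do
  # Hypothetical struct holding two tensors.
  defstruct [:left, :right]

  # Invoked as Pair.deserialize(list, metadata) because serialize/1
  # below returns Pair as the module.
  def deserialize(pairs, :pair_v1) do
    %__MODULE__{left: Keyword.fetch!(pairs, :left), right: Keyword.fetch!(pairs, :right)}
  end
end

defimpl Nx.Container, for: Pair do
  # Visit each field in order, threading the accumulator.
  def traverse(%Pair{left: left, right: right}, acc, fun) do
    {left, acc} = fun.(left, acc)
    {right, acc} = fun.(right, acc)
    {%Pair{left: left, right: right}, acc}
  end

  def reduce(%Pair{left: left, right: right}, acc, fun) do
    fun.(right, fun.(left, acc))
  end

  # {module, [{key, container}], metadata}; :pair_v1 is a hypothetical tag.
  def serialize(%Pair{left: left, right: right}) do
    {Pair, [left: left, right: right], :pair_v1}
  end
end

pair = %Pair{left: Nx.tensor(1), right: Nx.tensor(2)}
{Pair, [left: _, right: _], :pair_v1} = Nx.Container.serialize(pair)

restored = pair |> Nx.serialize() |> IO.iodata_to_binary() |> Nx.deserialize()
```

Carrying a version tag in metadata is one way to keep a schema evolvable, in line with the serialization caveats described earlier.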


traverse(data, acc, fun)

@spec traverse(t(), acc, (t(), acc -> {t(), acc})) :: {t(), acc} when acc: term()

Non-recursively traverses the tensors in a data structure with acc and fun.

fun is invoked with each tensor or tensor container in the data structure plus an accumulator. It must return a two-element tuple with the updated value and accumulator.

This function returns the updated container and the accumulator.

Since fun may receive containers, this function is not recursive. See Nx.Defn.Composite.traverse/3 for a recursive variant.
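The following sketch contrasts this function with Nx.Defn.Composite.traverse/3 on a map holding a tensor and a tuple of tensors. The data is illustrative only, and the script assumes Nx is fetched with Mix.install.

```elixir
# Assumption: standalone script; Mix.install fetches Nx.
Mix.install([{:nx, "~> 0.8.0"}])

data = %{a: Nx.tensor(1), b: {Nx.tensor(2), Nx.tensor(3)}}

# Non-recursive: fun sees the tensor under :a and the whole tuple under :b,
# returning each value unchanged while counting visits.
{_same, seen} =
  Nx.Container.traverse(data, 0, fn item, count -> {item, count + 1} end)

# Recursive: fun sees each tensor and replaces it with tensor + 10.
{deep, leaves} =
  Nx.Defn.Composite.traverse(data, 0, fn tensor, count ->
    {Nx.add(tensor, 10), count + 1}
  end)

IO.inspect(seen)                             # 2
IO.inspect(leaves)                           # 3
IO.inspect(Nx.to_number(elem(deep.b, 1)))    # 13
```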