ExBurn.Tensor (ex_burn v0.1.0)


Tensor conversion utilities between Nx and Burn formats.

Handles marshaling of tensor data between Elixir's Nx tensor representation and the Rust/Burn tensor references used by the NIF.

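Type Mapping

| Nx type   | Burn type |
|-----------|-----------|
| {:f, 32}  | :f32      |
| {:f, 64}  | :f64      |
| {:f, 16}  | :f32      |
| {:bf, 16} | :f32      |
| {:s, 32}  | :i32      |
| {:s, 64}  | :i64      |
| {:s, 16}  | :i32      |
| {:s, 8}   | :i32      |
| {:u, 8}   | :f32      |

Several Nx types share a Burn tag: {:f, 16} and {:bf, 16} are widened to :f32, {:s, 16} and {:s, 8} are widened to :i32, and {:u, 8} is stored as :f32. Because burn_type() has only four tags, converting back with burn_type_to_nx/1 cannot distinguish, for example, {:f, 16} from {:f, 32}.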
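The table can be read as a pair of lookup functions. The sketch below mirrors it in a hypothetical TypeMappingSketch module so that it runs without ExBurn; the forward map is copied from the table, while the reverse map (each Burn tag back to its full-width Nx type) is an assumption about burn_type_to_nx/1, not its documented behaviour.

```elixir
defmodule TypeMappingSketch do
  # Hypothetical mirror of the Type Mapping table; not the ExBurn source.
  @nx_to_burn %{
    {:f, 32} => :f32,
    {:f, 64} => :f64,
    {:f, 16} => :f32,
    {:bf, 16} => :f32,
    {:s, 32} => :i32,
    {:s, 64} => :i64,
    {:s, 16} => :i32,
    {:s, 8} => :i32,
    {:u, 8} => :f32
  }

  # Assumed inverse: each Burn tag maps back to one canonical Nx type.
  @burn_to_nx %{f32: {:f, 32}, f64: {:f, 64}, i32: {:s, 32}, i64: {:s, 64}}

  # Raises KeyError for types outside the table; how the real function
  # handles such types is not documented on this page.
  def nx_type_to_burn(type), do: Map.fetch!(@nx_to_burn, type)
  def burn_type_to_nx(tag), do: Map.fetch!(@burn_to_nx, tag)

  # True when an Nx type survives the Nx -> Burn -> Nx round trip unchanged.
  def lossless?(type), do: burn_type_to_nx(nx_type_to_burn(type)) == type
end

TypeMappingSketch.lossless?({:f, 32})  # true
TypeMappingSketch.lossless?({:f, 16})  # false: comes back as {:f, 32}
```

Under this assumption only {:f, 32}, {:f, 64}, {:s, 32} and {:s, 64} round-trip exactly.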

Summary

Types

burn_type()

Burn element type tag

t()

NIF tensor reference

Functions

burn_type_to_nx(type)

Converts a Burn type tag back to an Nx type tuple.

free(tensor)

Frees the underlying Rust tensor.

from_binary(data, shape, type)

Creates a Burn tensor from raw binary data, shape, and type.

from_nx(tensor)

Converts an Nx.Tensor.t() into an ExBurn.Tensor.t().

from_nx_batch(tensors)

Batch converts a list of Nx tensors to Burn tensors.

numel(tensor)

Returns the total number of elements.

nx_type_to_burn(type)

Converts an Nx type tuple to a Burn type tag.

rank(tensor)

Returns the rank (number of dimensions).

ref(tensor)

Returns the NIF reference for a tensor.

shape(tensor)

Returns the shape of a Burn tensor.

to_nx(tensor)

Converts an ExBurn.Tensor.t() back into an Nx.Tensor.t().

to_nx_batch(tensors)

Batch converts a list of Burn tensors to Nx tensors.

type(tensor)

Returns the Burn element type of a tensor.

Types

burn_type()

@type burn_type() :: :f32 | :f64 | :i32 | :i64

Burn element type tag

t()

@type t() :: %ExBurn.Tensor{
  ref: reference(),
  shape: [non_neg_integer()],
  type: burn_type()
}

Struct holding the NIF tensor reference (ref) together with the tensor's shape and Burn element type.

Functions

burn_type_to_nx(type)

@spec burn_type_to_nx(burn_type()) :: Nx.Type.t()

Converts a Burn type tag back to an Nx type tuple.

free(tensor)

@spec free(t()) :: :ok

Frees the underlying Rust tensor.
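When a tensor is only needed for a bounded piece of work, a try/after wrapper makes sure free/1 runs even if that work raises. The sketch below assumes a freed tensor must not be used again; whether ExBurn also releases tensors on garbage collection is not stated on this page. To run without ExBurn, the hypothetical FreeSketch module uses a stand-in free/1 that only sends a message, and a plain map shaped like ExBurn.Tensor.t() in place of a real tensor.

```elixir
defmodule FreeSketch do
  # Stand-in for ExBurn.Tensor.free/1: it only notifies the calling process.
  # With ExBurn you would call ExBurn.Tensor.free(tensor) here instead.
  def free(tensor) do
    send(self(), {:freed, tensor})
    :ok
  end

  # Runs `fun` on `tensor` and returns its result, freeing the tensor
  # afterwards even when `fun` raises.
  def with_tensor(tensor, fun) do
    try do
      fun.(tensor)
    after
      free(tensor)
    end
  end
end

# A map shaped like ExBurn.Tensor.t(), standing in for a real tensor.
tensor = %{ref: make_ref(), shape: [2, 2], type: :f32}
FreeSketch.with_tensor(tensor, fn t -> t.shape end)
```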

from_binary(data, shape, type)

@spec from_binary(binary(), [non_neg_integer()], burn_type()) ::
  {:ok, t()} | {:error, String.t()}

Creates a Burn tensor from raw binary data, shape, and type.
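This page does not say whether from_binary/3 checks that data matches shape and type. A caller can run a pre-flight size check; the sketch below assumes a dense buffer with no padding, where each element takes its natural width (4 bytes for :f32 and :i32, 8 bytes for :f64 and :i64). FromBinarySketch is a hypothetical module, not part of ExBurn.

```elixir
defmodule FromBinarySketch do
  # Element widths in bytes for each Burn type tag.
  @bytes_per_element %{f32: 4, f64: 8, i32: 4, i64: 8}

  # Bytes a dense buffer needs for `shape` and `type`; a scalar ([]) has 1 element.
  def expected_bytes(shape, type) do
    Enum.reduce(shape, 1, &(&1 * &2)) * Map.fetch!(@bytes_per_element, type)
  end

  # Returns :ok when `data` has exactly the expected size, otherwise an
  # {:error, String.t()} tuple in the same style as from_binary/3.
  def check(data, shape, type) do
    expected = expected_bytes(shape, type)

    if byte_size(data) == expected do
      :ok
    else
      {:error,
       "expected #{expected} bytes for shape #{inspect(shape)}, got #{byte_size(data)}"}
    end
  end
end
```

For example, 24 bytes fit shape [2, 3] as :f32 but not as :f64, so `:ok = FromBinarySketch.check(data, [2, 3], :f32)` could guard a call to ExBurn.Tensor.from_binary(data, [2, 3], :f32).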

from_nx(tensor)

@spec from_nx(Nx.Tensor.t()) :: {:ok, t()} | {:error, String.t()}

Converts an Nx.Tensor.t() into an ExBurn.Tensor.t().

The tensor data is sent to the Rust NIF layer as a flat binary. Returns {:ok, tensor} or {:error, reason}.
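A flat binary loses the nesting, so the shape has to travel alongside it. In Nx terms, this flat buffer is the kind Nx.to_binary/1 produces (row-major), and the shape would come from Tuple.to_list(Nx.shape(tensor)), since ExBurn stores shapes as lists while Nx uses tuples. The sketch below reproduces the idea on plain nested lists so it runs without Nx; the FlattenSketch module is hypothetical, and it assumes float values and native-endian :f32 packing.

```elixir
defmodule FlattenSketch do
  # Infers the shape of a rectangular nested list: [[1, 2, 3], [4, 5, 6]] -> [2, 3].
  # Only the first element of each level is inspected, so ragged input is not detected.
  def shape([first | _] = list), do: [length(list) | shape(first)]
  def shape([]), do: [0]
  def shape(_scalar), do: []

  # Flattens in row-major order (the last axis varies fastest) and packs each
  # value as a native-endian 32-bit float, giving a flat :f32 buffer.
  def to_flat_f32(nested) when is_list(nested) do
    for v <- List.flatten(nested), into: <<>>, do: <<v::float-32-native>>
  end
end

nested = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
{FlattenSketch.to_flat_f32(nested), FlattenSketch.shape(nested)}
```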

from_nx_batch(tensors)

@spec from_nx_batch([Nx.Tensor.t()]) :: {:ok, [t()]} | {:error, String.t()}

Batch converts a list of Nx tensors to Burn tensors.

More efficient than calling from_nx/1 individually when you need to convert many tensors.
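from_nx_batch/1 returns a single {:ok, list} or {:error, reason} rather than one tuple per tensor. How the real function handles a partial failure (for example, whether tensors already converted are freed) is not documented here. The hypothetical BatchSketch below only shows how per-item {:ok, _} / {:error, _} results collapse into that combined shape, stopping at the first error; any converter with that return style, such as from_nx/1, could be passed in.

```elixir
defmodule BatchSketch do
  # Applies `convert` to each item. Returns {:ok, results} in input order,
  # or the first {:error, reason}; items after a failure are not converted.
  def convert_all(items, convert) do
    items
    |> Enum.reduce_while({:ok, []}, fn item, {:ok, acc} ->
      case convert.(item) do
        {:ok, value} -> {:cont, {:ok, [value | acc]}}
        {:error, _reason} = error -> {:halt, error}
      end
    end)
    |> case do
      {:ok, acc} -> {:ok, Enum.reverse(acc)}
      error -> error
    end
  end
end

# Stand-in converter; with ExBurn this could be &ExBurn.Tensor.from_nx/1.
BatchSketch.convert_all([1, 2, 3], fn x -> {:ok, x * 2} end)
```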

numel(tensor)

@spec numel(t()) :: non_neg_integer()

Returns the total number of elements.
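The element count is the product of the dimensions in shape. The empty product is 1, so a scalar (shape []) has one element, and any zero-sized dimension gives zero. The hypothetical NumelSketch below computes this from a map shaped like ExBurn.Tensor.t(); it is not the library's code.

```elixir
defmodule NumelSketch do
  # Product of all dimensions; Enum.reduce with initial 1 handles shape [].
  def numel(%{shape: shape}), do: Enum.reduce(shape, 1, &(&1 * &2))
end

NumelSketch.numel(%{shape: [2, 3, 4]})  # 24
NumelSketch.numel(%{shape: []})         # 1 (a scalar)
```

Multiplying numel by the element width (4 bytes for :f32, for instance) gives the size of the tensor's dense data buffer.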

nx_type_to_burn(type)

@spec nx_type_to_burn(Nx.Type.t()) :: burn_type()

Converts an Nx type tuple to a Burn type tag.

rank(tensor)

@spec rank(t()) :: non_neg_integer()

Returns the rank (number of dimensions).
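Because the shape is stored as a list, the rank is its length: 0 for a scalar, 1 for a vector, 2 for a matrix. A minimal sketch, using a hypothetical RankSketch module over a map shaped like ExBurn.Tensor.t():

```elixir
defmodule RankSketch do
  # The rank is the number of entries in the shape list.
  def rank(%{shape: shape}), do: length(shape)
end

RankSketch.rank(%{shape: []})      # 0 (scalar)
RankSketch.rank(%{shape: [2, 3]})  # 2 (matrix)
```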

ref(tensor)

@spec ref(t()) :: reference()

Returns the NIF reference for a tensor.

shape(tensor)

@spec shape(t()) :: [non_neg_integer()]

Returns the shape of a Burn tensor.

to_nx(tensor)

@spec to_nx(t()) :: {:ok, Nx.Tensor.t()} | {:error, String.t()}

Converts an ExBurn.Tensor.t() back into an Nx.Tensor.t().

Reads the raw data from the Rust NIF layer and reshapes it to the tensor's shape. Returns {:ok, tensor} or {:error, reason}.
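With Nx available, the read-and-reshape step would correspond to Nx.from_binary/2 followed by Nx.reshape/2 (with List.to_tuple/1 on the stored shape); this page does not say whether ExBurn does exactly that. The hypothetical ReshapeSketch below shows the same idea on plain lists, assuming a row-major, native-endian :f32 buffer and no zero-sized dimensions.

```elixir
defmodule ReshapeSketch do
  # Decodes a native-endian f32 binary into a flat list of floats.
  def decode_f32(bin), do: for(<<x::float-32-native <- bin>>, do: x)

  # Rebuilds nested lists from a flat row-major list and a shape.
  def reshape([x], []), do: x
  def reshape(flat, [_dim]), do: flat

  def reshape(flat, [_dim | rest]) do
    # Each chunk holds one full sub-tensor of shape `rest`.
    row_size = Enum.reduce(rest, 1, &(&1 * &2))

    flat
    |> Enum.chunk_every(row_size)
    |> Enum.map(&reshape(&1, rest))
  end
end

bin = <<1.0::float-32-native, 2.5::float-32-native, 3.0::float-32-native, 4.0::float-32-native>>
ReshapeSketch.reshape(ReshapeSketch.decode_f32(bin), [2, 2])
```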

to_nx_batch(tensors)

@spec to_nx_batch([t()]) :: {:ok, [Nx.Tensor.t()]} | {:error, String.t()}

Batch converts a list of Burn tensors to Nx tensors.

type(tensor)

@spec type(t()) :: burn_type()

Returns the Burn element type of a tensor.