Torchx (Torchx v0.9.0)

Bindings and Nx integration for PyTorch.

Torchx provides an Nx backend through Torchx.Backend, which allows for integration with both the CPU and GPU functionality that PyTorch provides. To enable Torchx as the default backend, add the following lines to the desired config environment (config/config.exs, config/test.exs, etc.):

import Config
config :nx, :default_backend, Torchx.Backend

This ensures that, by default, all tensors are created as PyTorch tensors. Keep in mind that the default device is the CPU. If you wish to allocate tensors on the GPU by default, pass the :device option in the config line, as follows:

import Config
config :nx, :default_backend, {Torchx.Backend, device: :cuda}
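When editing config files is not practical, for example in a standalone script, the backend can also be chosen at runtime. A minimal sketch, assuming Torchx 0.9 is installed with Mix.install/1: Nx.default_backend/1 sets the default only for the calling process, and the :backend option of Nx.tensor/2 overrides it for a single tensor.

```elixir
Mix.install([{:torchx, "~> 0.9.0"}])

# Set Torchx as the default backend for the current process only
Nx.default_backend({Torchx.Backend, device: :cpu})

# Created with the process default, so it is backed by a PyTorch tensor
t = Nx.tensor([1.0, 2.0, 3.0])

# The :backend option overrides the default for this one tensor
b = Nx.tensor([1.0, 2.0, 3.0], backend: Nx.BinaryBackend)

IO.inspect({t.data.__struct__, b.data.__struct__})
```

The process-level default is convenient for scripts and notebooks, while the config entry applies to the whole application.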

The device_available?/1 function can be used to determine whether :cuda is available. If you have CUDA installed but it doesn't show as available, check out the Installation README section.
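A minimal sketch of using that check to pick the device at runtime instead of hard-coding :cuda, assuming Torchx 0.9 is installed with Mix.install/1 and that falling back to the CPU is acceptable:

```elixir
Mix.install([{:torchx, "~> 0.9.0"}])

# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU
device = if Torchx.device_available?(:cuda), do: :cuda, else: :cpu

# Make that device the default for tensors created by this process
Nx.default_backend({Torchx.Backend, device: device})

t = Nx.iota({3})
IO.inspect(t)
```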

Types

Torchx uses specific names for PyTorch types; their Nx counterparts are listed in the following table:

Nx Type      Torchx Type        Description
{:u, 8}      :byte              Unsigned 8-bit integer
{:s, 8}      :char              Signed 8-bit integer
{:s, 16}     :short             Signed 16-bit integer
{:s, 32}     :int               Signed 32-bit integer
{:s, 64}     :long              Signed 64-bit integer
{:bf, 16}    :brain             16-bit brain floating-point number
{:f, 8}      :float8_e5m2       8-bit floating-point number
{:f, 16}     :half              16-bit floating-point number
{:f, 32}     :float             32-bit floating-point number
{:f, 64}     :double            64-bit floating-point number
{:c, 64}     :complex           64-bit complex number, with two 32-bit float components
{:c, 128}    :complex_double    128-bit complex number, with two 64-bit float components
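The mapping applies whenever an Nx type is used on the Torchx backend. A minimal sketch, assuming Torchx 0.9 is installed with Mix.install/1, that creates a bfloat16 tensor and casts it to a 64-bit integer type without leaving the backend:

```elixir
Mix.install([{:torchx, "~> 0.9.0"}])

# {:bf, 16} corresponds to the :brain type in the table above
t = Nx.tensor([1.5, 2.5], type: {:bf, 16}, backend: Torchx.Backend)

# Casting to {:s, 64} corresponds to :long; the result stays on Torchx
i = Nx.as_type(t, {:s, 64})

IO.inspect({Nx.type(t), Nx.type(i)})
```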

Devices

PyTorch supports a variety of devices, listed below.

  • :cpu
  • :cuda
  • :mkldnn
  • :opengl
  • :opencl
  • :ideep
  • :hip
  • :fpga
  • :msnpu
  • :xla
  • :vulkan
  • :metal
  • :xpu
  • :mps
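A minimal sketch of moving an existing tensor to another device with Nx.backend_transfer/2, assuming Torchx 0.9 installed with Mix.install/1 and an Apple silicon machine for :mps; the code falls back to :cpu when :mps is not available:

```elixir
Mix.install([{:torchx, "~> 0.9.0"}])

# :mps is typically only available on Apple silicon; fall back to :cpu elsewhere
target = if Torchx.device_available?(:mps), do: :mps, else: :cpu

t = Nx.iota({2, 3}, backend: {Torchx.Backend, device: :cpu})

# Transfer the data to the target device; the original tensor is released
moved = Nx.backend_transfer(t, {Torchx.Backend, device: target})

IO.inspect(moved)
```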

Summary

Functions

Returns the default device.

Checks whether a device of the given type is available to Torchx.

Returns the number of devices of the given type.

Gets a Torchx tensor from a Nx tensor.

Converts a Torchx tensor to a Nx tensor.

Functions

all(tensor, axes, keep_axes)

all_close(tensor_a, tensor_b, rtol, atol, equal_nan)

amax(tensor, axes, keep_axes)

amin(tensor, axes, keep_axes)

any(tensor, axes, keep_axes)

arange(from, to, step, type, device)

arange(from, to, step, type, device, shape)

argmax(tensor, axis, keep_axes)

argmin(tensor, axis, keep_axes)

argsort(tensor, axis, is_descending, stable)

bitwise_and(tensorA, tensorB)

bitwise_or(tensorA, tensorB)

bitwise_xor(tensorA, tensorB)

broadcast_to(tensor, shape)

clip(tensor, tensor_min, tensor_max)

concatenate(tensors, axis)

conv(tensor_input, tensor_kernel, strides, padding, dilation, transposed, groups)

cumulative_max(tensor, axis)

cumulative_min(tensor, axis)

cumulative_product(tensor, axis)

cumulative_sum(tensor, axis)

Returns the default device.

Devices are chosen in the following order of priority, depending on availability:

  • :cuda
  • :cpu

The default can also be set (albeit not recommended) via the application environment by setting the :default_device option under the :torchx application.
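For reference, a sketch of that setting in a config file such as config/config.exs, assuming the value is a plain device atom like the ones accepted by the :device option of Torchx.Backend:

```elixir
import Config

# Not recommended: pin the Torchx default device instead of relying on detection
config :torchx, :default_device, :cpu
```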

device_available?(device)

Checks whether a device of the given type is available to Torchx.

You can currently check the availability of:

  • :cuda
  • :mps
  • :cpu
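A minimal sketch that reports the availability of every checkable device type, assuming Torchx 0.9 is installed with Mix.install/1:

```elixir
Mix.install([{:torchx, "~> 0.9.0"}])

# Build a map from each checkable device type to its availability
availability =
  for device <- [:cuda, :mps, :cpu], into: %{} do
    {device, Torchx.device_available?(device)}
  end

IO.inspect(availability)
```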

Returns the number of devices of the given type.

You can check the device count of :cuda for now.
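A minimal sketch of counting GPUs, assuming Torchx 0.9 installed with Mix.install/1 and assuming this function is exposed as Torchx.device_count/1. The call is guarded by an availability check so it only runs when CUDA is usable:

```elixir
Mix.install([{:torchx, "~> 0.9.0"}])

# Only :cuda can be counted; report 0 when CUDA is not usable
# (Torchx.device_count/1 is an assumed name for the function documented above)
cuda_count =
  if Torchx.device_available?(:cuda), do: Torchx.device_count(:cuda), else: 0

IO.puts("CUDA devices: #{cuda_count}")
```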

divide(tensorA, tensorB)

fft2(tensor, lengths, axes)

fft(tensor, length, axis)

from_blob(blob, shape, type, device)

Gets a Torchx tensor from a Nx tensor.
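A minimal sketch of unwrapping an Nx tensor, assuming Torchx 0.9 installed with Mix.install/1 and assuming this function is Torchx.from_nx/1. Consistent with the is_tensor(dev, ref) macro listed on this page, the raw Torchx tensor is treated here as a {device, reference} pair:

```elixir
Mix.install([{:torchx, "~> 0.9.0"}])

t = Nx.tensor([[1.0, 2.0], [3.0, 4.0]], backend: Torchx.Backend)

# Unwrap into the raw Torchx tensor used by the low-level functions on this page
{device, ref} = Torchx.from_nx(t)

IO.inspect(device)
```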

full(shape, scalar, type, device)

gather(tensor_input, tensor_indices, axis)

greater(tensorA, tensorB)

greater_equal(tensorA, tensorB)

ifft2(tensor, lengths, axes)

ifft(tensor, length, axis)

index(tensor_input, tensor_indices)

index_put(tensor_input, tensor_indices, tensor_updates, accumulate)

is_tensor(dev, ref) (macro)

left_shift(tensorA, tensorB)

less_equal(tensorA, tensorB)

logical_and(tensorA, tensorB)

logical_or(tensorA, tensorB)

logical_xor(tensorA, tensorB)

matmul(tensorA, tensorB)

max_pool_3d(tensor_input, kernel_size, strides, padding, dilation)

multiply(tensorA, tensorB)

normal(mu, sigma, shape, type, device)

not_equal(tensorA, tensorB)

ones(shape, type, device)

pad(tensor, tensor_scalar, config)

product(tensor, axes, keep_axes)

put(tensor_input, index, tensor_source)

quotient(tensorA, tensorB)

rand(min, max, shape, type, device)

randint(min, max, shape, type, device)

remainder(tensorA, tensorB)

right_shift(tensorA, tensorB)

scalar_tensor(scalar, type, device)

slice(tensor, starts, lengths, strides)

solve(tensor_a, tensor_b)

sort(tensor, axis, descending, stable)

split(tensor, split_size)

subtract(tensorA, tensorB)

sum(tensor, axes, keep_axes)

svd(tensor, full_matrices)

tensordot(tensorA, tensorB, axesA, axesB)

tensordot(tensorA, tensorB, axesA, batchA, axesB, batchB)

to_device(tensor, device)

Converts a Torchx tensor to a Nx tensor.
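A minimal round-trip sketch, assuming Torchx 0.9 installed with Mix.install/1 and assuming the conversion functions are Torchx.from_nx/1 and Torchx.to_nx/1. It applies the low-level matmul(tensorA, tensorB) function from this page to the raw tensor and wraps the result back into Nx:

```elixir
Mix.install([{:torchx, "~> 0.9.0"}])

t = Nx.tensor([[1.0, 2.0], [3.0, 4.0]], backend: Torchx.Backend)

# Unwrap, multiply the matrix by itself with a low-level Torchx function,
# then convert back so the regular Nx API can be used again
squared =
  t
  |> Torchx.from_nx()
  |> then(&Torchx.matmul(&1, &1))
  |> Torchx.to_nx()

IO.inspect(Nx.to_list(squared))
```

Dropping to the low-level functions is only useful for operations Nx does not expose; for everything else the Nx API on Torchx.Backend is the simpler route.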

transpose(tensor, dim0, dim1)

triangular_solve(tensor_a, tensor_b, transpose, upper)

unfold(tensor, dimension, size, step)

where(tensorA, tensorB, tensorC)