Torchx (Torchx v0.8.0)

Bindings and Nx integration for PyTorch.

Torchx provides an Nx backend through Torchx.Backend, which allows for integration with both the CPU and GPU functionality that PyTorch provides. To enable Torchx as the default backend, add the following line to the desired config environment (config/config.exs, config/test.exs, etc.):

import Config
config :nx, :default_backend, Torchx.Backend
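In a standalone script or Livebook notebook there is no config/config.exs. There, the same setting can be passed through the :config option of Mix.install/2. This is a minimal sketch, assuming Torchx 0.8 (and the LibTorch build it downloads) can be installed on the machine:

```elixir
# Standalone equivalent of the config line above: Mix.install/2 accepts the
# same configuration format as config/config.exs through its :config option.
Mix.install(
  [{:torchx, "~> 0.8.0"}],
  config: [nx: [default_backend: Torchx.Backend]]
)

# No :backend option is given, so the configured default backend is used.
t = Nx.tensor([1, 2, 3])
IO.inspect(t.data.__struct__)
```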

This ensures that, by default, all tensors are created as PyTorch tensors. It's important to keep in mind that the default device is the CPU. If you wish to allocate tensors on the GPU by default, pass the :device option in the config line, as follows:

import Config
config :nx, :default_backend, {Torchx.Backend, device: :cuda}
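The config line only sets a default. As a sketch (assuming a CPU-only machine; swap :cpu for :cuda when a GPU is available), an individual tensor can also be placed on Torchx through the :backend option accepted by Nx's tensor-creation functions:

```elixir
Mix.install([{:torchx, "~> 0.8.0"}])

# Choose the backend and device for individual tensors, regardless of the
# configured default. Replace :cpu with :cuda on a CUDA-enabled machine.
a = Nx.tensor([1.0, 2.0, 3.0], backend: {Torchx.Backend, device: :cpu})
b = Nx.tensor([4.0, 5.0, 6.0], backend: {Torchx.Backend, device: :cpu})

# Both operands live on Torchx, so the addition is executed by PyTorch.
total = Nx.add(a, b)

IO.inspect(Nx.to_flat_list(total))
```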

The device_available?/1 function can be used to determine whether :cuda is available. If you have CUDA installed but it doesn't show as available, see the Installation section of the README.
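One way to combine the settings above is to pick the device at runtime and make it the default for the current process with Nx.default_backend/1. This is a sketch rather than a prescribed pattern; falling back to the CPU when CUDA is missing is an assumption about what you want:

```elixir
Mix.install([{:torchx, "~> 0.8.0"}])

# Fall back to the CPU when CUDA is not usable on this machine.
device = if Torchx.device_available?(:cuda), do: :cuda, else: :cpu

# Make Torchx on the chosen device the default backend for this process only.
Nx.default_backend({Torchx.Backend, device: device})

t = Nx.iota({2, 2})
IO.puts("Allocated on #{device}: #{inspect(Nx.to_flat_list(t))}")
```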


Torchx uses PyTorch's names for tensor types. Each one has an Nx counterpart, as shown in the following table:

| Nx Type    | Torchx Type     | Description                                            |
|------------|-----------------|--------------------------------------------------------|
| {:u, 8}    | :byte           | Unsigned 8-bit integer                                 |
| {:s, 8}    | :char           | Signed 8-bit integer                                   |
| {:s, 16}   | :short          | Signed 16-bit integer                                  |
| {:s, 32}   | :int            | Signed 32-bit integer                                  |
| {:s, 64}   | :long           | Signed 64-bit integer                                  |
| {:bf, 16}  | :brain          | 16-bit brain floating-point number                     |
| {:f, 16}   | :half           | 16-bit floating-point number                           |
| {:f, 32}   | :float          | 32-bit floating-point number                           |
| {:f, 64}   | :double         | 64-bit floating-point number                           |
| {:c, 64}   | :complex        | 64-bit complex number, with two 32-bit float components |
| {:c, 128}  | :complex_double | 128-bit complex number, with two 64-bit float components |
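If your own code needs to translate between the two naming schemes, a small lookup table is enough. The TypeTable module below is hypothetical (it is not part of Torchx) and simply encodes the rows of the table above in both directions:

```elixir
defmodule TypeTable do
  # Hypothetical helper mirroring the Nx/Torchx type table; not a Torchx API.
  @nx_to_torchx %{
    {:u, 8} => :byte,
    {:s, 8} => :char,
    {:s, 16} => :short,
    {:s, 32} => :int,
    {:s, 64} => :long,
    {:bf, 16} => :brain,
    {:f, 16} => :half,
    {:f, 32} => :float,
    {:f, 64} => :double,
    {:c, 64} => :complex,
    {:c, 128} => :complex_double
  }

  # Invert the map at compile time for lookups in the other direction.
  @torchx_to_nx Map.new(@nx_to_torchx, fn {nx, torchx} -> {torchx, nx} end)

  # Returns {:ok, torchx_type}, or :error if the Nx type is not in the table.
  def to_torchx(nx_type), do: Map.fetch(@nx_to_torchx, nx_type)

  # Returns {:ok, nx_type}, or :error if the Torchx name is not in the table.
  def to_nx(torchx_type), do: Map.fetch(@torchx_to_nx, torchx_type)
end

IO.inspect(TypeTable.to_torchx({:f, 32}))
IO.inspect(TypeTable.to_nx(:brain))
```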


PyTorch supports a variety of device types, listed below:

  • :cpu
  • :cuda
  • :mkldnn
  • :opengl
  • :opencl
  • :ideep
  • :hip
  • :fpga
  • :msnpu
  • :xla
  • :vulkan
  • :metal
  • :xpu
  • :mps



all(tensor, axes, keep_axes)

all_close(tensor_a, tensor_b, rtol, atol, equal_nan)

amax(tensor, axes, keep_axes)

amin(tensor, axes, keep_axes)

any(tensor, axes, keep_axes)

arange(from, to, step, type, device)

arange(from, to, step, type, device, shape)

argmax(tensor, axis, keep_axes)

argmin(tensor, axis, keep_axes)

argsort(tensor, axis, is_descending, stable)

bitwise_and(tensorA, tensorB)

bitwise_or(tensorA, tensorB)

bitwise_xor(tensorA, tensorB)

broadcast_to(tensor, shape)

clip(tensor, tensor_min, tensor_max)

concatenate(tensors, axis)

conv(tensor_input, tensor_kernel, strides, padding, dilation, transposed, groups)

cumulative_max(tensor, axis)

cumulative_min(tensor, axis)

cumulative_product(tensor, axis)

cumulative_sum(tensor, axis)

default_device()

Returns the default device.

The device is selected in the following order of priority, depending on availability:

  • :cuda
  • :cpu

The default can also be set via the application environment by setting the :default_device option under the :torchx application, although this is not recommended.
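In a config file, the application-environment setting would read config :torchx, :default_device, :cpu. The module below is a hypothetical sketch of the selection rule described above, not Torchx's own code; it assumes that an explicitly configured value takes precedence over the availability check, and it takes the availability check as a function so the rule can be exercised without a GPU:

```elixir
defmodule DefaultDeviceSketch do
  # Hypothetical re-statement of the documented priority rule.
  # `configured` stands in for the :default_device value of the :torchx
  # application (nil when unset); `available_fun` stands in for
  # Torchx.device_available?/1.
  def pick(configured, available_fun) when is_function(available_fun, 1) do
    cond do
      # Assumption: an explicit configuration wins over detection.
      configured != nil -> configured
      # Otherwise prefer CUDA when it is available...
      available_fun.(:cuda) -> :cuda
      # ...and fall back to the CPU.
      true -> :cpu
    end
  end
end

IO.inspect(DefaultDeviceSketch.pick(nil, fn d -> d in [:cuda, :cpu] end))
IO.inspect(DefaultDeviceSketch.pick(nil, fn d -> d == :cpu end))
IO.inspect(DefaultDeviceSketch.pick(:cpu, fn _ -> true end))
```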

device_available?(device)

Checks whether a device of the given type is available to Torchx.

You can currently check the availability of:

  • :cuda
  • :mps
  • :cpu
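A quick way to see what the current machine supports is to query each of these device types in turn. A minimal sketch:

```elixir
Mix.install([{:torchx, "~> 0.8.0"}])

# Query every device type that device_available?/1 documents support for.
availability =
  for device <- [:cuda, :mps, :cpu], into: %{} do
    {device, Torchx.device_available?(device)}
  end

IO.inspect(availability)
```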

device_count(device)

Returns the number of devices of the given type.

Currently, only :cuda is supported.
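Because only :cuda is supported, one reasonable pattern (an assumption, not something the API requires) is to check availability first and only then ask for the count:

```elixir
Mix.install([{:torchx, "~> 0.8.0"}])

# device_count/1 is only documented for :cuda, so guard it with an
# availability check and report zero devices otherwise.
count =
  if Torchx.device_available?(:cuda) do
    Torchx.device_count(:cuda)
  else
    0
  end

IO.puts("Visible CUDA devices: #{count}")
```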

divide(tensorA, tensorB)

fft2(tensor, lengths, axes)

fft(tensor, length, axis)

from_blob(blob, shape, type, device)

from_nx(tensor)

Gets a Torchx tensor from an Nx tensor.
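from_nx/1 bridges Nx tensors and the low-level functions listed on this page. The sketch below, assuming a CPU device, unwraps an Nx tensor, squares it with multiply/2, and wraps the result back with to_nx/1:

```elixir
Mix.install([{:torchx, "~> 0.8.0"}])

# An Nx tensor stored on Torchx.Backend (CPU for portability).
nx_tensor = Nx.tensor([1, 2, 3], backend: {Torchx.Backend, device: :cpu})

# from_nx/1 returns the low-level Torchx tensor that functions such as
# multiply/2 accept directly.
torchx_tensor = Torchx.from_nx(nx_tensor)
squared = Torchx.multiply(torchx_tensor, torchx_tensor)

# to_nx/1 wraps the low-level result back into an Nx tensor.
result = Torchx.to_nx(squared)
IO.inspect(Nx.to_flat_list(result))
```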

full(shape, scalar, type, device)

gather(tensor_input, tensor_indices, axis)

greater(tensorA, tensorB)

greater_equal(tensorA, tensorB)

ifft2(tensor, lengths, axes)

ifft(tensor, length, axis)

index(tensor_input, tensor_indices)

index_put(tensor_input, tensor_indices, tensor_updates, accumulate)

is_tensor(dev, ref)

left_shift(tensorA, tensorB)

less_equal(tensorA, tensorB)

logical_and(tensorA, tensorB)

logical_or(tensorA, tensorB)

logical_xor(tensorA, tensorB)

matmul(tensorA, tensorB)

max_pool_3d(tensor_input, kernel_size, strides, padding, dilation)

multiply(tensorA, tensorB)

normal(mu, sigma, shape, type, device)

not_equal(tensorA, tensorB)

ones(shape, type, device)

pad(tensor, tensor_scalar, config)

product(tensor, axes, keep_axes)

put(tensor_input, index, tensor_source)

quotient(tensorA, tensorB)

rand(min, max, shape, type, device)

randint(min, max, shape, type, device)

remainder(tensorA, tensorB)

right_shift(tensorA, tensorB)

scalar_tensor(scalar, type, device)

slice(tensor, starts, lengths, strides)

solve(tensor_a, tensor_b)

sort(tensor, axis, descending, stable)

split(tensor, split_size)

subtract(tensorA, tensorB)

sum(tensor, axes, keep_axes)
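The reductions take a list of axes and a keep_axes flag. The sketch below assumes that axes is a list of axis indices and keep_axes a boolean, mirroring the :axes and :keep_axes options of Nx.sum/2:

```elixir
Mix.install([{:torchx, "~> 0.8.0"}])

t =
  Nx.tensor([[1, 2], [3, 4]], backend: {Torchx.Backend, device: :cpu})
  |> Torchx.from_nx()

# Sum over axis 0, dropping the reduced axis: shape {2}.
col = t |> Torchx.sum([0], false) |> Torchx.to_nx()

# Same reduction, but keep the reduced axis with size 1: shape {1, 2}.
kept = t |> Torchx.sum([0], true) |> Torchx.to_nx()

IO.inspect({Nx.to_flat_list(col), Nx.shape(col), Nx.shape(kept)})
```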

svd(tensor, full_matrices)

tensordot(tensorA, tensorB, axesA, axesB)

tensordot(tensorA, tensorB, axesA, batchA, axesB, batchB)

to_device(tensor, device)

to_nx(tensor)

Converts a Torchx tensor to an Nx tensor.
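to_nx/1 is also how the results of the low-level constructors become ordinary Nx tensors. A minimal sketch, assuming that ones/3 takes a shape tuple, a Torchx type name from the type table at the top of this page, and a device:

```elixir
Mix.install([{:torchx, "~> 0.8.0"}])

# Low-level constructors take a Torchx type name (:float) and a device.
raw = Torchx.ones({2, 3}, :float, :cpu)

# to_nx/1 reads the type and shape from the Torchx tensor: {:f, 32}, {2, 3}.
t = Torchx.to_nx(raw)
IO.inspect({Nx.type(t), Nx.shape(t)})

# Copy to Nx's pure-Elixir backend, e.g. to hand the data to code that
# does not depend on Torchx.
t_binary = Nx.backend_transfer(t, Nx.BinaryBackend)
```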

transpose(tensor, dim0, dim1)

triangular_solve(tensor_a, tensor_b, transpose, upper)

unfold(tensor, dimension, size, step)

where(tensorA, tensorB, tensorC)

