NumPy -> Nx
This cheatsheet is designed to assist Python developers in transitioning to Elixir, specifically by providing equivalent commands and code examples between NumPy and Nx.
Tensor Creation
From list or nested list
NumPy
>>> np.array([1, 2, 3])
array([1, 2, 3])
Nx
iex> Nx.tensor([1, 2, 3])
#Nx.Tensor<
s32[3]
[1, 2, 3]
>
2D Arrays/Tensors
NumPy
>>> np.array([[1, 2], [3, 4]])
array([[1, 2],
[3, 4]])
Nx
iex> Nx.tensor([[1, 2], [3, 4]])
#Nx.Tensor<
s32[2][2]
[
[1, 2],
[3, 4]
]
>
Zeros and Ones
NumPy
>>> np.zeros((2, 3))
array([[0., 0., 0.],
[0., 0., 0.]])
>>> np.ones((2, 3))
array([[1., 1., 1.],
[1., 1., 1.]])
Nx
iex> Nx.broadcast(0, {2, 3})
#Nx.Tensor<
s32[2][3]
[
[0, 0, 0],
[0, 0, 0]
]
>
iex> Nx.broadcast(1, {2, 3})
#Nx.Tensor<
s32[2][3]
[
[1, 1, 1],
[1, 1, 1]
]
>
Range of Numbers
NumPy
>>> np.arange(0, 10, 2)
array([0, 2, 4, 6, 8])
Nx
iex> Nx.iota({5}, axis: 0) |> Nx.multiply(2)
#Nx.Tensor<
s32[5]
[0, 2, 4, 6, 8]
>
Linearly Spaced Values
NumPy
>>> np.linspace(0, 1, 5)
array([0. , 0.25, 0.5 , 0.75, 1. ])
Nx
iex> Nx.iota({5}) |> Nx.divide(4)
#Nx.Tensor<
f32[5]
[0.0, 0.25, 0.5, 0.75, 1.0]
>
Tensor Inspection
Shape
NumPy
>>> a = np.array([[1, 2, 3], [4, 5, 6]])
>>> a.shape
(2, 3)
Nx
iex> a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
#Nx.Tensor<
s32[2][3]
[
[1, 2, 3],
[4, 5, 6]
]
>
iex> Nx.shape(a)
{2, 3}
Number of dimensions
NumPy
>>> a.ndim
2
Nx
iex> Nx.rank(a)
2
Data Type
NumPy
>>> a.dtype
dtype('int64')
Nx
iex> Nx.type(a)
{:s, 32}
Total Number of Elements
NumPy
>>> a.size
6
Nx
iex> Nx.size(a)
6
Indexing and Slicing
Indexing a Single Element
NumPy
>>> a = np.array([[10, 20], [30, 40]])
>>> a[0, 1]
np.int64(20)
Nx
# Indexing a Single Element
iex> tensor = Nx.tensor([[10, 20], [30, 40]])
iex> tensor[[0, 1]]
#Nx.Tensor<
s32
20
>
Slicing a Range
NumPy
>>> a = np.array([10, 20, 30, 40, 50])
>>> a[1:4]
array([20, 30, 40])
Nx
# Slicing a Range
iex> a = Nx.tensor([10, 20, 30, 40, 50])
iex> a[1..3]
#Nx.Tensor<
s32[3]
[20, 30, 40]
>
Selecting Along a Specific Axis
NumPy
>>> a = np.array([[1, 2], [3, 4], [5, 6]])
>>> a[:, 1]
array([2, 4, 6])
Nx
# Selecting Along a Specific Axis
iex> a = Nx.tensor([[1, 2], [3, 4], [5, 6]])
iex> a[[.., 1]]
#Nx.Tensor<
s32[3]
[2, 4, 6]
>
Boolean Masking
NumPy
>>> x = np.arange(10)
>>> x[x % 2 == 0]
array([0, 2, 4, 6, 8])
Nx
Boolean masking requires dynamic shapes, which Nx does not support: Nx compiles all operations ahead of time (like XLA or JAX), and to do so every tensor must have a static shape.
Linear Algebra Operations
Matrix Multiplication
NumPy
>>> A = np.array([[1, 2], [3, 4]])
>>> B = np.array([[5, 6], [7, 8]])
>>> np.matmul(A, B)
array([[19, 22],
[43, 50]])
Nx
iex> a = Nx.tensor([[1, 2], [3, 4]])
iex> b = Nx.tensor([[5, 6], [7, 8]])
iex> Nx.dot(a, b)
#Nx.Tensor<
s32[2][2]
[
[19, 22],
[43, 50]
]
>
Transpose
NumPy
>>> A.T
array([[1, 3],
[2, 4]])
Nx
iex> Nx.transpose(a)
#Nx.Tensor<
s32[2][2]
[
[1, 3],
[2, 4]
]
>
Identity Matrix
NumPy
>>> np.eye(3)
array([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
Nx
iex> Nx.eye({3, 3})
#Nx.Tensor<
s32[3][3]
[
[1, 0, 0],
[0, 1, 0],
[0, 0, 1]
]
>
Determinant
NumPy
>>> np.linalg.det(A)
np.float64(-2.0000000000000004)
Nx
iex> Nx.LinAlg.determinant(a)
#Nx.Tensor<
f32
-2.0
>
Inverse
NumPy
>>> np.linalg.inv(A)
array([[-2. , 1. ],
[ 1.5, -0.5]])
Nx
iex> Nx.LinAlg.invert(a)
#Nx.Tensor<
f32[2][2]
[
[-2.000000476837158, 1.0000003576278687],
[1.5000004768371582, -0.5000002384185791]
]
>
Solve a System of Linear Equations
NumPy
>>> A = np.array([[3, 1], [1, 2]])
>>> b = np.array([9, 8])
>>> np.linalg.solve(A, b)
array([2., 3.])
Nx
iex> a = Nx.tensor([[3.0, 1.0], [1.0, 2.0]])
iex> b = Nx.tensor([9.0, 8.0])
iex> Nx.LinAlg.solve(a, b)
#Nx.Tensor<
f32[2]
[2.0, 3.0]
>
Eigenvalues and Eigenvectors
NumPy
>>> np.linalg.eigh(A)
EighResult(
eigenvalues=array([1.38196601, 3.61803399]),
eigenvectors=array([
[ 0.52573111, -0.85065081],
[-0.85065081, -0.52573111]
]))
Nx
iex> Nx.LinAlg.eigh(a)
{#Nx.Tensor<
f32[2]
[3.618025779724121, 1.381974220275879]
>,
#Nx.Tensor<
f32[2][2]
[
[0.8516583442687988, -0.5240974426269531],
[0.5240974426269531, 0.8516583442687988]
]
>}