Nx.Random (Nx v0.7.3)

Pseudo-random number generators.

Unlike the stateful pseudo-random number generators (PRNGs) that users of most programming languages and numerical libraries may be accustomed to, Nx random functions require an explicit PRNG key to be passed as the first argument (see below for more info). That key is defined by an Nx.Tensor composed of 2 unsigned 32-bit integers, usually generated by the Nx.Random.key/1 function:

iex> Nx.Random.key(12)
#Nx.Tensor<
  u32[2]
  [0, 12]
>

This key can then be used in any of Nx’s random number generation routines:

iex> key = Nx.Random.key(12)
iex> {uniform, _new_key} = Nx.Random.uniform(key)
iex> uniform
#Nx.Tensor<
  f32
  0.7691127061843872
>

To generate another random number, pass the returned new_key to the next call. Calling a function again with the same key returns the same number.
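
A minimal script sketch of threading keys, assuming Nx 0.7.3 is fetched with Mix.install (outside a Mix project). It reuses the seed from the example above; the variable names are illustrative.

```elixir
# Script mode: fetch Nx (in a Mix project, list :nx in deps instead).
Mix.install([{:nx, "~> 0.7.3"}])

key = Nx.Random.key(12)

# Each call returns {sample, new_key}; pass new_key to the next call.
{first, key} = Nx.Random.uniform(key)
{second, _key} = Nx.Random.uniform(key)

# Starting again from the same seed reproduces the first sample.
{again, _key} = Nx.Random.uniform(Nx.Random.key(12))

IO.inspect(Nx.to_number(first), label: "first")
IO.inspect(Nx.to_number(second), label: "second")
IO.inspect(Nx.to_number(again), label: "again")
```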

The functions in this module also have a *_split variant, which is used when the key has been split beforehand.
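
As a sketch of the split-then-sample pattern (assuming Nx 0.7.3 installed via Mix.install; the seed and shapes are arbitrary), the key is split once with Nx.Random.split/2 and each part is handed to a *_split function. Unlike the non-split functions, these return only the sampled tensor rather than a {tensor, new_key} tuple.

```elixir
# Script mode: fetch Nx (in a Mix project, list :nx in deps instead).
Mix.install([{:nx, "~> 0.7.3"}])

key = Nx.Random.key(42)

# Split up front: a u32[2][2] tensor holding two independent keys.
keys = Nx.Random.split(key)

# Each *_split call consumes one of those keys and returns only the sample.
noise = Nx.Random.normal_split(keys[0], 0.0, 1.0, shape: {2, 3})
mask = Nx.Random.uniform_split(keys[1], 0.0, 1.0, shape: {2, 3})

IO.inspect(noise)
IO.inspect(mask)
```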

Design and Context

In short, Nx's PRNGs are based on a Threefry counter PRNG associated with a functional array-oriented splitting model. To summarize, among other requirements, Nx's PRNG aims to:

  1. Ensure reproducibility

  2. Parallelize well, both in terms of vectorization (generating array values) and multi-replica, multi-core computation. In particular, it should not impose sequencing constraints between random function calls.

The key to understanding Nx.Random keys

Most Elixir users might be used to not having to keep track of the PRNG state while their code executes.

This works fine on the CPU, but on GPUs we can think of explicitly tracking the Nx.Random key as a way to isolate multiple users, much like the PRNG state on different BEAM nodes is isolated. Each key is updated in its own isolated sequence of calls, so a process that starts from a given key always gets the same results, instead of results that depend on how its calls interleave with other users of a shared, stateful PRNG.

The fact that the key is a parameter for the functions also helps with the caching and operator fusion of the computational graphs. Because the PRNG functions themselves are stateless, compilers can take advantage of this to further improve execution times.
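
To make the isolation idea concrete on the BEAM, the sketch below gives each of four hypothetical workers (plain Tasks standing in for separate devices or users) its own key derived with Nx.Random.split/2. Because each worker only reads its own key, repeated runs produce the same per-worker values regardless of scheduling. This is an illustration assuming Nx 0.7.3 installed via Mix.install, not a prescribed pattern.

```elixir
# Script mode: fetch Nx (in a Mix project, list :nx in deps instead).
Mix.install([{:nx, "~> 0.7.3"}])

# Four keys for four hypothetical workers, all derived from one seed.
worker_keys = Nx.Random.split(Nx.Random.key(7), parts: 4)

# Each Task reads only its own key, so its draws do not depend on
# how the scheduler interleaves the Tasks.
sample_all = fn ->
  0..3
  |> Task.async_stream(
    fn i ->
      {sample, _next_key} = Nx.Random.uniform(worker_keys[i], shape: {3})
      Nx.to_flat_list(sample)
    end,
    timeout: :infinity
  )
  |> Enum.map(fn {:ok, values} -> values end)
end

run1 = sample_all.()
run2 = sample_all.()

IO.inspect(run1, label: "per-worker samples")
```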

Summary

Functions

  • choice/3 - Generates random samples from a tensor.

  • choice/4 - Generates random samples from a tensor with specified probabilities.

  • fold_in/2 - Folds in new data to a PRNG key.

  • gumbel/2 - Sample Gumbel random values with given shape and float dtype.

  • gumbel_split/2 - Same as gumbel/2, but assumes the key has been split beforehand.

  • key/1 - Create a pseudo-random number generator (PRNG) key given an integer seed.

  • multivariate_normal/4 - Returns a sample from a multivariate normal distribution with given mean and covariance (matrix). The function assumes that the covariance is a positive semi-definite matrix. Otherwise, the result will not be normally distributed.

  • multivariate_normal_split/4 - Same as multivariate_normal/4 but assumes the key has already been split.

  • normal/2 - Shortcut for normal(key, 0.0, 1.0, opts).

  • normal/4 - Returns a sample from a normal distribution with the given mean and standard_deviation.

  • normal_split/4 - Same as normal/4 but assumes the key has already been split.

  • randint/4 - Sample uniform random integer values in the semi-open interval [min_val, max_val).

  • randint_split/4 - Same as randint/4 but assumes the key has already been split.

  • shuffle/3 - Randomly shuffles tensor elements along an axis.

  • split/2 - Splits a PRNG key into :parts new keys (2 by default) by adding a leading axis.

  • uniform/2 - Shortcut for uniform(key, 0.0, 1.0, opts).

  • uniform/4 - Sample uniform float values in the semi-open interval [min_val, max_val).

  • uniform_split/4 - Same as uniform/4 but assumes the key has already been split.

Functions

choice(key, tensor, opts \\ [])

Generates random samples from a tensor.

Options

  • :samples - The number of samples to take. Defaults to 1.

  • :axis - The axis along which to take samples. If nil, the tensor is flattened beforehand.

  • :replace - a boolean that specifies if samples will be taken with or without replacement. Defaults to true.

Examples

iex> k = Nx.Random.key(1)
iex> t = Nx.iota({4, 3})
iex> {result, _key} = Nx.Random.choice(k, t, samples: 4, axis: 0)
iex> result
#Nx.Tensor<
  s64[4][3]
  [
    [6, 7, 8],
    [3, 4, 5],
    [6, 7, 8],
    [3, 4, 5]
  ]
>
iex> {result, _key} = Nx.Random.choice(k, t, samples: 4, axis: 0, replace: false)
iex> result
#Nx.Tensor<
  s64[4][3]
  [
    [3, 4, 5],
    [9, 10, 11],
    [6, 7, 8],
    [0, 1, 2]
  ]
>

If no axis is specified, the tensor is flattened:

iex> k = Nx.Random.key(2)
iex> t = Nx.iota({3, 2})
iex> {result, _key} = Nx.Random.choice(k, t)
iex> result
#Nx.Tensor<
  s64[1]
  [3]
>
iex> {result, _key} = Nx.Random.choice(k, t, samples: 6, replace: false)
iex> result
#Nx.Tensor<
  s64[6]
  [2, 0, 4, 5, 1, 3]
>

choice(key, tensor, p, opts)

Generates random samples from a tensor with specified probabilities.

The probabilities tensor must have the same size as the axis along which the samples are being taken. If no axis is given, the size must be equal to the input tensor's size.

Options

  • :samples - The number of samples to take. Defaults to 1.

  • :axis - The axis along which to take samples. If nil, the tensor is flattened beforehand.

  • :replace - a boolean that specifies if samples will be taken with or without replacement. Defaults to true.

Examples

iex> k = Nx.Random.key(1)
iex> t = Nx.iota({4, 3})
iex> p = Nx.tensor([0.1, 0.7, 0.2])
iex> {result, _key} = Nx.Random.choice(k, t, p, samples: 3, axis: 1)
iex> result
#Nx.Tensor<
  s64[4][3]
  [
    [1, 0, 1],
    [4, 3, 4],
    [7, 6, 7],
    [10, 9, 10]
  ]
>
iex> {result, _key} = Nx.Random.choice(k, t, p, samples: 3, axis: 1, replace: false)
iex> result
#Nx.Tensor<
  s64[4][3]
  [
    [1, 2, 0],
    [4, 5, 3],
    [7, 8, 6],
    [10, 11, 9]
  ]
>

If no axis is specified, the tensor is flattened. Notice that in the first case we get a higher occurrence of the entries with bigger probabilities, while in the second case, without replacement, we get those samples first.

iex> k = Nx.Random.key(2)
iex> t = Nx.iota({2, 3})
iex> p = Nx.tensor([0.01, 0.1, 0.19, 0.6, 0.05, 0.05])
iex> {result, _key} = Nx.Random.choice(k, t, p)
iex> result
#Nx.Tensor<
  s64[1]
  [3]
>
iex> {result, _key} = Nx.Random.choice(k, t, p, samples: 6)
iex> result
#Nx.Tensor<
  s64[6]
  [3, 3, 3, 0, 3, 3]
>
iex> {result, _key} = Nx.Random.choice(k, t, p, samples: 6, replace: false)
iex> result
#Nx.Tensor<
  s64[6]
  [3, 1, 2, 5, 4, 0]
>
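
When sampling along axis 0 of a {4, 3} tensor, the probabilities must therefore have 4 entries, one per row, and each sample is a whole row. A small sketch, assuming Nx 0.7.3 via Mix.install; the seed and probability values are arbitrary:

```elixir
# Script mode: fetch Nx (in a Mix project, list :nx in deps instead).
Mix.install([{:nx, "~> 0.7.3"}])

key = Nx.Random.key(3)
t = Nx.iota({4, 3})

# Sampling rows (axis 0): p needs one probability per row, i.e. 4 entries.
p = Nx.tensor([0.4, 0.3, 0.2, 0.1])
{rows, _key} = Nx.Random.choice(key, t, p, samples: 2, axis: 0)

IO.inspect(rows)
```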

fold_in(key, data)

Folds in new data to a PRNG key.

Examples

iex> key = Nx.Random.key(42)
iex> Nx.Random.fold_in(key, 99)
#Nx.Tensor<
  u32[2]
  [2015327502, 1351855566]
>

iex> key = Nx.Random.key(42)
iex> Nx.Random.fold_in(key, 1234)
#Nx.Tensor<
  u32[2]
  [1356445167, 2917756949]
>

iex> key = Nx.Random.key(42)
iex> Nx.Random.fold_in(key, Nx.tensor([[1, 99], [1234, 13]]))
#Nx.Tensor<
  u32[2][2][2]
  [
    [
      [64467757, 2916123636],
      [2015327502, 1351855566]
    ],
    [
      [1356445167, 2917756949],
      [3514951389, 229662949]
    ]
  ]
>
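
One way to use fold_in/2, sketched here as an assumption rather than a recommended recipe, is to derive a separate key per loop step from a fixed base key, so step n can be regenerated without replaying steps 0 to n - 1. The step indices and the use of normal/2 are illustrative; Nx 0.7.3 is assumed to be installed via Mix.install.

```elixir
# Script mode: fetch Nx (in a Mix project, list :nx in deps instead).
Mix.install([{:nx, "~> 0.7.3"}])

base_key = Nx.Random.key(42)

# One reproducible key per step, obtained by folding the step index into base_key.
step_keys = for step <- 0..2, do: Nx.Random.fold_in(base_key, step)

# Draw one standard normal value per step from its own key.
draws =
  for k <- step_keys do
    {x, _next_key} = Nx.Random.normal(k)
    Nx.to_number(x)
  end

IO.inspect(draws, label: "per-step draws")
```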

gumbel(key, opts \\ [])

Sample Gumbel random values with given shape and float dtype.

Options

  • :shape - the shape of the output tensor containing the random samples. Defaults to {}

  • :type - the floating-point output type. Defaults to {:f, 32}

Examples

iex> {result, _key} = Nx.Random.gumbel(Nx.Random.key(1))
iex> result
#Nx.Tensor<
  f32
  -0.7294610142707825
>

iex> {result, _key} = Nx.Random.gumbel(Nx.Random.key(1), shape: {2, 3})
iex> result
#Nx.Tensor<
  f32[2][3]
  [
    [0.6247938275337219, -0.21740718185901642, 0.7678327560424805],
    [0.7778404355049133, 4.0895304679870605, 0.3029090166091919]
  ]
>

gumbel_split(key, opts \\ [])

Same as gumbel/2, but assumes the key has been split beforehand.
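
As an illustration of gumbel_split/2, the sketch below applies the Gumbel-max trick, a standard technique that is not part of this module: adding standard Gumbel noise to log-probabilities and taking the argmax draws each index with its probability. The probabilities, seed, and sample count are arbitrary, and Nx 0.7.3 is assumed to be installed via Mix.install.

```elixir
# Script mode: fetch Nx (in a Mix project, list :nx in deps instead).
Mix.install([{:nx, "~> 0.7.3"}])

p = Nx.tensor([0.1, 0.7, 0.2])
keys = Nx.Random.split(Nx.Random.key(5))

# gumbel_split takes an already-split key and returns only the noise tensor.
g = Nx.Random.gumbel_split(keys[1], shape: {1000, 3})

# Gumbel-max: argmax over log(p) + noise, one categorical draw per row.
picks = Nx.argmax(Nx.add(Nx.log(p), g), axis: 1)

counts = for i <- 0..2, do: picks |> Nx.equal(i) |> Nx.sum() |> Nx.to_number()
IO.inspect(counts, label: "counts for indices 0, 1, 2")
```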

key(seed)

Create a pseudo-random number generator (PRNG) key given an integer seed.

Examples

iex> Nx.Random.key(12)
#Nx.Tensor<
  u32[2]
  [0, 12]
>

iex> Nx.Random.key(999999999999)
#Nx.Tensor<
  u32[2]
  [232, 3567587327]
>
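
In both examples above, the two u32 words equal the high and low 32 bits of the seed (12 fits in the low word, so the high word is 0). The sketch below checks that observation for 999_999_999_999; treat it as a reading of the documented outputs, not a specification of key/1 (negative seeds, for instance, are not covered). It assumes Nx 0.7.3 via Mix.install.

```elixir
# Script mode: fetch Nx (in a Mix project, list :nx in deps instead).
Mix.install([{:nx, "~> 0.7.3"}])
import Bitwise

seed = 999_999_999_999

# Split the integer seed into its high and low 32-bit words.
high = seed >>> 32
low = seed &&& 0xFFFFFFFF

IO.inspect({high, low}, label: "high/low words")
IO.inspect(Nx.Random.key(seed))
```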

multivariate_normal(key, mean, covariance, opts \\ [])

Returns a sample from a multivariate normal distribution with given mean and covariance (matrix). The function assumes that the covariance is a positive semi-definite matrix. Otherwise, the result will not be normally distributed.

Options

  • :type - a float type for the returned tensor

  • :shape - batch shape of the returned tensor, i.e. the prefix of the result shape excluding the last axis

  • :names - the names of the returned tensor

  • :method - a decomposition method used for the covariance. Must be one of :svd, :eigh, or :cholesky. Defaults to :cholesky. For singular covariance matrices, use :svd or :eigh.

Examples

iex> key = Nx.Random.key(12)
iex> {multivariate_normal, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0]), Nx.tensor([[1]]))
iex> multivariate_normal
#Nx.Tensor<
  f32[1]
  [0.735927939414978]
>

iex> key = Nx.Random.key(12)
iex> {multivariate_normal, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0, 0]), Nx.tensor([[1, 0], [0, 1]]))
iex> multivariate_normal
#Nx.Tensor<
  f32[2]
  [-1.3425945043563843, -0.40812060236930847]
>

iex> key = Nx.Random.key(12)
iex> {multivariate_normal, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0]), Nx.tensor([[1]]), shape: {3, 2}, type: :f16)
iex> multivariate_normal
#Nx.Tensor<
  f16[3][2][1]
  [
    [
      [0.326904296875],
      [0.2176513671875]
    ],
    [
      [0.316650390625],
      [0.1109619140625]
    ],
    [
      [0.53955078125],
      [-0.8857421875]
    ]
  ]
>

iex> key = Nx.Random.key(12)
iex> {multivariate_normal, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0, 0]), Nx.tensor([[1, 0], [0, 1]]), shape: {3, 2})
iex> multivariate_normal
#Nx.Tensor<
  f32[3][2][2]
  [
    [
      [0.9891449809074402, 1.0795185565948486],
      [-0.9467806220054626, 1.47813880443573]
    ],
    [
      [2.2095863819122314, -1.529456377029419],
      [-0.7933920621871948, 1.121195673942566]
    ],
    [
      [0.10976295918226242, -0.9959557056427002],
      [0.4754556119441986, 1.1413804292678833]
    ]
  ]
>

multivariate_normal_split(key, mean, covariance, opts \\ [])

Same as multivariate_normal/4 but assumes the key has already been split.
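
A sketch of multivariate_normal_split/4 with a correlated, positive-definite covariance, assuming Nx 0.7.3 via Mix.install. As with the other *_split functions, the key comes from Nx.Random.split/2 and only the sample tensor is returned. With shape: {1000} as the batch shape, the result has shape {1000, 2}, and the empirical mean should land near the requested mean.

```elixir
# Script mode: fetch Nx (in a Mix project, list :nx in deps instead).
Mix.install([{:nx, "~> 0.7.3"}])

keys = Nx.Random.split(Nx.Random.key(12))

mean = Nx.tensor([1.0, -1.0])
# Symmetric and positive-definite, so the default :cholesky method applies.
covariance = Nx.tensor([[2.0, 0.5], [0.5, 1.0]])

# shape: {1000} is the batch shape; the trailing axis of size 2 comes from the mean.
samples = Nx.Random.multivariate_normal_split(keys[1], mean, covariance, shape: {1000})

IO.inspect(Nx.mean(samples, axes: [0]), label: "empirical mean")
```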

normal(key, opts \\ [])

Shortcut for normal(key, 0.0, 1.0, opts).
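
Because normal/2 is a shortcut for normal(key, 0.0, 1.0, opts), both calls below should draw identical values from the same key. A small sketch, assuming Nx 0.7.3 via Mix.install:

```elixir
# Script mode: fetch Nx (in a Mix project, list :nx in deps instead).
Mix.install([{:nx, "~> 0.7.3"}])

key = Nx.Random.key(42)

# Same key, same options: the shortcut and the explicit form should agree.
{short, _} = Nx.Random.normal(key, shape: {4})
{full, _} = Nx.Random.normal(key, 0.0, 1.0, shape: {4})

IO.inspect(Nx.to_flat_list(short), label: "normal/2")
IO.inspect(Nx.to_flat_list(full), label: "normal/4")
```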

normal(key, mean, standard_deviation, opts \\ [])

Returns a sample from a normal distribution with the given mean and standard_deviation.

Options

  • :type - a float or complex type for the returned tensor

  • :shape - shape of the returned tensor

  • :names - the names of the returned tensor

Examples

iex> key = Nx.Random.key(42)
iex> {normal, _new_key} = Nx.Random.normal(key)
iex> normal
#Nx.Tensor<
  f32
  1.3694695234298706
>

iex> key = Nx.Random.key(42)
iex> {normal, _new_key} = Nx.Random.normal(key, 0, 1, shape: {3, 2}, type: :f16)
iex> normal
#Nx.Tensor<
  f16[3][2]
  [
    [-0.32568359375, -0.77197265625],
    [0.39208984375, 0.5341796875],
    [0.270751953125, -2.080078125]
  ]
>

iex> key = Nx.Random.key(42)
iex> {normal, _new_key} = Nx.Random.normal(key, 0, 1, shape: {2, 2}, type: :c64)
iex> normal
#Nx.Tensor<
  c64[2][2]
  [
    [-0.7632761001586914+0.8661127686500549i, -0.14282889664173126-0.7384796142578125i],
    [0.678461492061615+0.4118310809135437i, -2.269538402557373-0.3689095079898834i]
  ]
>

iex> key = Nx.Random.key(1337)
iex> {normal, _new_key} = Nx.Random.normal(key, 10, 5, shape: {1_000})
iex> Nx.mean(normal)
#Nx.Tensor<
  f32
  9.70022201538086
>
iex> Nx.standard_deviation(normal)
#Nx.Tensor<
  f32
  5.051416397094727
>

normal_split(key, mean, standard_deviation, opts \\ [])

Same as normal/4 but assumes the key has already been split.

randint(key, min_val, max_val, opts \\ [])

Sample uniform random integer values in the semi-open interval [min_val, max_val).

Options

  • :type - the integer type for the returned tensor

  • :shape - shape of the returned tensor

  • :names - the names of the returned tensor

Examples

iex> key = Nx.Random.key(1701)
iex> {randint, _new_key} = Nx.Random.randint(key, 1, 100)
iex> randint
#Nx.Tensor<
  s64
  66
>

iex> key = Nx.Random.key(1701)
iex> {randint, _new_key} = Nx.Random.randint(key, 1, 100, shape: {3, 2}, type: :u32)
iex> randint
#Nx.Tensor<
  u32[3][2]
  [
    [9, 20],
    [19, 6],
    [71, 15]
  ]
>

randint_split(key, min_val, max_val, opts \\ [])

Same as randint/4 but assumes the key has already been split.
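
Since max_val is exclusive, simulating a six-sided die needs the interval [1, 7). A sketch with randint_split/4, assuming Nx 0.7.3 via Mix.install; the seed and number of rolls are arbitrary:

```elixir
# Script mode: fetch Nx (in a Mix project, list :nx in deps instead).
Mix.install([{:nx, "~> 0.7.3"}])

keys = Nx.Random.split(Nx.Random.key(1701))

# max_val is exclusive, so [1, 7) yields die faces 1 through 6.
rolls = Nx.Random.randint_split(keys[1], 1, 7, shape: {600})

IO.inspect(rolls |> Nx.to_flat_list() |> Enum.frequencies(), label: "face counts")
```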

shuffle(key, tensor, opts \\ [])

Randomly shuffles tensor elements along an axis.

Options

  • :axis - the axis along which to shuffle. Defaults to 0

  • :independent - a boolean that indicates whether the permutations are independent along the given axis. Defaults to false

Examples

iex> key = Nx.Random.key(42)
iex> {shuffled, _new_key} = Nx.Random.shuffle(key, Nx.iota({3, 4}, axis: 0))
iex> shuffled
#Nx.Tensor<
  s64[3][4]
  [
    [2, 2, 2, 2],
    [0, 0, 0, 0],
    [1, 1, 1, 1]
  ]
>

iex> key = Nx.Random.key(10)
iex> {shuffled, _new_key} = Nx.Random.shuffle(key, Nx.iota({3, 4}, axis: 1), independent: true, axis: 1)
iex> shuffled
#Nx.Tensor<
  s64[3][4]
  [
    [2, 1, 3, 0],
    [3, 0, 1, 2],
    [2, 3, 0, 1]
  ]
>
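
Since shuffle/3 only permutes elements, sorting each independently shuffled row recovers the original rows. The sketch below, assuming Nx 0.7.3 via Mix.install, reuses the second example's key and input:

```elixir
# Script mode: fetch Nx (in a Mix project, list :nx in deps instead).
Mix.install([{:nx, "~> 0.7.3"}])

key = Nx.Random.key(10)
t = Nx.iota({3, 4}, axis: 1)

# independent: true permutes each row along axis 1 with its own permutation.
{shuffled, _key} = Nx.Random.shuffle(key, t, independent: true, axis: 1)

# Sorting every row undoes the permutation.
restored = Nx.sort(shuffled, axis: 1)

IO.inspect(shuffled)
```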

split(key, opts \\ [])

Splits a PRNG key into :parts new keys (2 by default) by adding a leading axis.

Examples

iex> key = Nx.Random.key(1701)
iex> Nx.Random.split(key)
#Nx.Tensor<
  u32[2][2]
  [
    [56197195, 1801093307],
    [961309823, 1704866707]
  ]
>

iex> key = Nx.Random.key(999999999999)
iex> Nx.Random.split(key, parts: 4)
#Nx.Tensor<
  u32[4][2]
  [
    [3959978897, 4079927650],
    [3769699049, 3585271160],
    [3182829676, 333122445],
    [3185556048, 1258545461]
  ]
>

threefry2x32_20_concat(xs, ks)

uniform(key, opts \\ [])

Shortcut for uniform(key, 0.0, 1.0, opts).

uniform(key, min_val, max_val, opts \\ [])

Sample uniform float values in the semi-open interval [min_val, max_val).

Options

  • :type - a float or complex type for the returned tensor

  • :shape - shape of the returned tensor

  • :names - the names of the returned tensor

Examples

iex> key = Nx.Random.key(1701)
iex> {uniform, _new_key} = Nx.Random.uniform(key)
iex> uniform
#Nx.Tensor<
  f32
  0.9728643894195557
>

iex> key = Nx.Random.key(1701)
iex> {uniform, _new_key} = Nx.Random.uniform(key, shape: {3, 2}, type: :f16)
iex> uniform
#Nx.Tensor<
  f16[3][2]
  [
    [0.75390625, 0.6484375],
    [0.7294921875, 0.21484375],
    [0.09765625, 0.0693359375]
  ]
>

iex> key = Nx.Random.key(1701)
iex> {uniform, _new_key} = Nx.Random.uniform(key, shape: {2, 2}, type: :c64)
iex> uniform
#Nx.Tensor<
  c64[2][2]
  [
    [0.18404805660247803+0.6546461582183838i, 0.5525915622711182+0.11568140983581543i],
    [0.6074584722518921+0.8104375600814819i, 0.247686505317688+0.21975469589233398i]
  ]
>

uniform_split(key, min_value, max_value, opts \\ [])

Same as uniform/4 but assumes the key has already been split.