Nx.Random (Nx v0.4.1)

Pseudo-random number generators.

Unlike the stateful pseudo-random number generators (PRNGs) that users of most programming languages and numerical libraries may be accustomed to, Nx random functions require an explicit PRNG key to be passed as the first argument. That key is an Nx.Tensor composed of 2 unsigned 32-bit integers, usually generated by the Nx.Random.key/1 function:

iex> Nx.Random.key(12)
#Nx.Tensor<
  u32[2]
  [0, 12]
>

This key can then be used in any of Nx’s random number generation routines:

iex> key = Nx.Random.key(12)
iex> {uniform, _new_key} = Nx.Random.uniform(key)
iex> uniform
#Nx.Tensor<
  f32
  0.7691127061843872
>

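To generate another random number, pass the returned new_key to the next call. Calling a function again with the same key returns the same result, so reusing a key reproduces the sequence instead of advancing it.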
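
Putting the two steps together, here is a minimal sketch (assuming Elixir 1.12+ for Mix.install and network access to fetch Nx 0.4.x) that threads the key through successive calls and shows that reusing the original key reproduces its result:

```elixir
# Assumes Elixir >= 1.12 (Mix.install) and network access to Hex.
Mix.install([{:nx, "~> 0.4.1"}])

key = Nx.Random.key(12)

# Each call returns {sample, new_key}; thread new_key into the next call.
{first, key} = Nx.Random.uniform(key)
{second, _key} = Nx.Random.uniform(key)

# Starting again from the same seed reproduces the first sample exactly.
{again, _} = Nx.Random.uniform(Nx.Random.key(12))

IO.inspect(Nx.to_number(first), label: "first")
IO.inspect(Nx.to_number(second), label: "second")
IO.inspect(Nx.to_number(again) == Nx.to_number(first), label: "reproducible?")
```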

Most functions in this module also have a *_split variant (normal_split/4, randint_split/4 and uniform_split/4), which assumes the key has already been split.
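
A minimal sketch of the *_split workflow, assuming Nx 0.4.x installed via Mix.install: the key is split once with Nx.Random.split/2 and each subkey is handed to a *_split function. Based on the "already been split" wording, this sketch also assumes that the *_split functions return only the sampled tensor rather than a {tensor, new_key} tuple; check the return value against your Nx version.

```elixir
# Assumes Elixir >= 1.12 (Mix.install) and network access to Hex.
Mix.install([{:nx, "~> 0.4.1"}])

key = Nx.Random.key(1701)

# Split once up front: one subkey (one row) per sampling call.
keys = Nx.Random.split(key, parts: 3)

# Assumption: *_split functions consume the given subkey and return
# just the sampled tensor.
u = Nx.Random.uniform_split(keys[0], 0.0, 1.0, shape: {2})
n = Nx.Random.normal_split(keys[1], 0.0, 1.0, shape: {2})
r = Nx.Random.randint_split(keys[2], 0, 10, shape: {2})

IO.inspect({u, n, r})
```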

Design and Context

In short, Nx's PRNGs are based on the Threefry counter-based PRNG combined with a functional, array-oriented splitting model. Among other requirements, Nx's PRNG aims to:

  1. Ensure reproducibility

  2. Parallelize well, both in terms of vectorization (generating array values) and multi-replica, multi-core computation. In particular, it should not impose sequencing constraints between random function calls.
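
Both goals follow from the counter-based design, which can be illustrated in plain Elixir. The module below is a hypothetical, from-scratch sketch of the Threefry-2x32 block function as specified by the Random123 library, not Nx's own code. The round count of 20 is an assumption inferred from the name threefry2x32_20_concat. Each output block depends only on the key and the counter, so any element of a stream can be recomputed (reproducibility) or computed independently of the others (parallelism).

```elixir
defmodule Threefry2x32 do
  @moduledoc """
  Hypothetical sketch of the Threefry-2x32 block function with 20 rounds,
  following the Random123 specification. Not Nx's implementation.
  """
  import Bitwise

  @mask 0xFFFFFFFF
  # Rotation constants: rounds 0-3 use the first four, rounds 4-7 the last
  # four, and the pattern repeats every 8 rounds.
  @rotations {13, 15, 26, 6, 17, 29, 16, 24}
  # Key-schedule parity constant for 32-bit words.
  @parity 0x1BD11BDA

  @doc "Hashes the counter block {x0, x1} under the key {k0, k1}."
  def hash({k0, k1}, {x0, x1}) do
    ks = {k0, k1, bxor(bxor(@parity, k0), k1)}
    x0 = add(x0, elem(ks, 0))
    x1 = add(x1, elem(ks, 1))

    Enum.reduce(0..19, {x0, x1}, fn round, {x0, x1} ->
      {x0, x1} = mix(x0, x1, elem(@rotations, rem(round, 8)))

      # Key injection after every 4 rounds (injection number s = 1..5).
      if rem(round, 4) == 3 do
        s = div(round + 1, 4)
        {add(x0, elem(ks, rem(s, 3))), add(x1, add(elem(ks, rem(s + 1, 3)), s))}
      else
        {x0, x1}
      end
    end)
  end

  # One round: add, rotate, xor, all modulo 2^32.
  defp mix(x0, x1, r) do
    x0 = add(x0, x1)
    {x0, bxor(rotl(x1, r), x0)}
  end

  defp add(a, b), do: band(a + b, @mask)
  defp rotl(x, r), do: band(bor(bsl(x, r), bsr(x, 32 - r)), @mask)
end

# Counter mode: block i depends only on (key, i), so the blocks can be
# computed in any order, or in parallel.
stream = for i <- 0..3, do: Threefry2x32.hash({0, 42}, {0, i})
IO.inspect(stream)
```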

Summary

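Functions

fold_in(key, data)
Folds in new data to a PRNG key.

key(seed)
Creates a pseudo-random number generator (PRNG) key given an integer seed.

normal(key, opts \\ [])
Shortcut for normal(key, 0.0, 1.0, opts).

normal(key, mean, standard_deviation, opts \\ [])
Samples from a normal distribution with the given mean and standard_deviation.

normal_split(key, mean, standard_deviation, opts \\ [])
Same as normal/4 but assumes the key has already been split.

randint(key, min_val, max_val, opts \\ [])
Samples uniform random integer values in [min_val, max_val).

randint_split(key, min_val, max_val, opts \\ [])
Same as randint/4 but assumes the key has already been split.

split(key, opts \\ [])
Splits a PRNG key into :parts new keys (2 by default) by adding a leading axis.

uniform(key, opts \\ [])
Shortcut for uniform(key, 0.0, 1.0, opts).

uniform(key, min_val, max_val, opts \\ [])
Samples uniform float values in [min_val, max_val).

uniform_split(key, min_value, max_value, opts \\ [])
Same as uniform/4 but assumes the key has already been split.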

Functions

fold_in(key, data)

Folds in new data to a PRNG key.

Examples

iex> key = Nx.Random.key(42)
iex> Nx.Random.fold_in(key, 99)
#Nx.Tensor<
  u32[2]
  [2015327502, 1351855566]
>

iex> key = Nx.Random.key(42)
iex> Nx.Random.fold_in(key, 1234)
#Nx.Tensor<
  u32[2]
  [1356445167, 2917756949]
>

iex> key = Nx.Random.key(42)
iex> Nx.Random.fold_in(key, Nx.tensor([[1, 99], [1234, 13]]))
#Nx.Tensor<
  u32[2][2][2]
  [
    [
      [64467757, 2916123636],
      [2015327502, 1351855566]
    ],
    [
      [1356445167, 2917756949],
      [3514951389, 229662949]
    ]
  ]
>
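
fold_in/2 is handy when a key should be derived from an index instead of being threaded through calls. The minimal sketch below assumes Nx 0.4.x via Mix.install; the per-step scenario is hypothetical. It derives one key per step from a single base key, so each step's randomness can be recomputed from the step number alone.

```elixir
# Assumes Elixir >= 1.12 (Mix.install) and network access to Hex.
Mix.install([{:nx, "~> 0.4.1"}])

base_key = Nx.Random.key(42)

# Hypothetical use case: one key per step, derived from the step index.
step_keys = for step <- 0..2, do: Nx.Random.fold_in(base_key, step)

samples =
  for k <- step_keys do
    {sample, _new_key} = Nx.Random.uniform(k)
    Nx.to_number(sample)
  end

IO.inspect(samples, label: "one sample per step")
```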

key(seed)

Creates a pseudo-random number generator (PRNG) key given an integer seed.

Examples

iex> Nx.Random.key(12)
#Nx.Tensor<
  u32[2]
  [0, 12]
>

iex> Nx.Random.key(999999999999)
#Nx.Tensor<
  u32[2]
  [232, 3567587327]
>
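
The two examples above are consistent with the key holding the seed's high and low 32-bit words: 999999999999 = 232 · 2^32 + 3567587327. The pure-Elixir helper below reproduces that arithmetic. SeedSketch.key_words/1 is a hypothetical name, and the sketch only accepts seeds in [0, 2^64), since these examples do not show how other seeds are handled.

```elixir
defmodule SeedSketch do
  import Bitwise

  # Seeds are assumed to fit in 64 bits (two 32-bit words).
  @limit bsl(1, 64)

  # Hypothetical helper: returns {high_word, low_word} of a 64-bit seed,
  # the layout shown by the Nx.Random.key/1 examples.
  def key_words(seed) when is_integer(seed) and seed >= 0 and seed < @limit do
    {bsr(seed, 32), band(seed, 0xFFFFFFFF)}
  end
end

IO.inspect(SeedSketch.key_words(12))
# => {0, 12}
IO.inspect(SeedSketch.key_words(999_999_999_999))
# => {232, 3567587327}
```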

normal(key, opts \\ [])

Shortcut for normal(key, 0.0, 1.0, opts).

normal(key, mean, standard_deviation, opts \\ [])

Samples from a normal distribution with the given mean and standard_deviation, returning a {tensor, new_key} tuple.

Options

  • :type - a float or complex type for the returned tensor

  • :shape - shape of the returned tensor

  • :names - the names of the returned tensor

Examples

iex> key = Nx.Random.key(42)
iex> {normal, _new_key} = Nx.Random.normal(key)
iex> normal
#Nx.Tensor<
  f32
  1.3694695234298706
>

iex> key = Nx.Random.key(42)
iex> {normal, _new_key} = Nx.Random.normal(key, 0, 1, shape: {3, 2}, type: :f16)
iex> normal
#Nx.Tensor<
  f16[3][2]
  [
    [-0.32568359375, -0.77197265625],
    [0.39208984375, 0.5341796875],
    [0.270751953125, -2.080078125]
  ]
>

iex> key = Nx.Random.key(42)
iex> {normal, _new_key} = Nx.Random.normal(key, 0, 1, shape: {2, 2}, type: :c64)
iex> normal
#Nx.Tensor<
  c64[2][2]
  [
    [-0.7632761001586914+0.8661127686500549i, -0.14282889664173126-0.7384796142578125i],
    [0.678461492061615+0.4118310809135437i, -2.269538402557373-0.3689095079898834i]
  ]
>

iex> key = Nx.Random.key(1337)
iex> {normal, _new_key} = Nx.Random.normal(key, 10, 5, shape: {1_000})
iex> Nx.mean(normal)
#Nx.Tensor<
  f32
  9.70022201538086
>
iex> Nx.standard_deviation(normal)
#Nx.Tensor<
  f32
  5.051416397094727
>
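
The options listed above can be combined in a single call. Here is a minimal sketch, assuming Nx 0.4.x via Mix.install, that draws a named 2x3 batch from a normal distribution with mean 5.0 and standard deviation 2.0. The axis names :rows and :cols are arbitrary.

```elixir
# Assumes Elixir >= 1.12 (Mix.install) and network access to Hex.
Mix.install([{:nx, "~> 0.4.1"}])

key = Nx.Random.key(42)

# Shape and names are passed through opts; the type defaults to f32.
{batch, _new_key} =
  Nx.Random.normal(key, 5.0, 2.0, shape: {2, 3}, names: [:rows, :cols])

IO.inspect(batch)
```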

normal_split(key, mean, standard_deviation, opts \\ [])

Same as normal/4 but assumes the key has already been split.

randint(key, min_val, max_val, opts \\ [])

Samples uniform random integer values in [min_val, max_val), returning a {tensor, new_key} tuple.

Options

  • :type - the integer type for the returned tensor
  • :shape - shape of the returned tensor
  • :names - the names of the returned tensor

Examples

iex> key = Nx.Random.key(1701)
iex> {randint, _new_key} = Nx.Random.randint(key, 1, 100)
iex> randint
#Nx.Tensor<
  s64
  66
>

iex> key = Nx.Random.key(1701)
iex> {randint, _new_key} = Nx.Random.randint(key, 1, 100, shape: {3, 2}, type: :u32)
iex> randint
#Nx.Tensor<
  u32[3][2]
  [
    [9, 20],
    [19, 6],
    [71, 15]
  ]
>
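
Because max_val is exclusive, simulating a six-sided die requires max_val = 7. Here is a minimal sketch, assuming Nx 0.4.x via Mix.install:

```elixir
# Assumes Elixir >= 1.12 (Mix.install) and network access to Hex.
Mix.install([{:nx, "~> 0.4.1"}])

key = Nx.Random.key(1701)

# Ten die rolls: values in [1, 7), i.e. 1 through 6; type defaults to s64.
{rolls, _new_key} = Nx.Random.randint(key, 1, 7, shape: {10})

IO.inspect(Nx.to_flat_list(rolls), label: "rolls")
```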

randint_split(key, min_val, max_val, opts \\ [])

Same as randint/4 but assumes the key has already been split.

split(key, opts \\ [])

Splits a PRNG key into :parts new keys (2 by default) by adding a leading axis.

Examples

iex> key = Nx.Random.key(1701)
iex> Nx.Random.split(key)
#Nx.Tensor<
  u32[2][2]
  [
    [56197195, 1801093307],
    [961309823, 1704866707]
  ]
>

iex> key = Nx.Random.key(999999999999)
iex> Nx.Random.split(key, parts: 4)
#Nx.Tensor<
  u32[4][2]
  [
    [3959978897, 4079927650],
    [3769699049, 3585271160],
    [3182829676, 333122445],
    [3185556048, 1258545461]
  ]
>
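
A typical use of split/2 is to give each parallel consumer its own key. Here is a minimal sketch, assuming Nx 0.4.x via Mix.install. The four "workers" are hypothetical; each one simply takes one row of the split result as its key.

```elixir
# Assumes Elixir >= 1.12 (Mix.install) and network access to Hex.
Mix.install([{:nx, "~> 0.4.1"}])

key = Nx.Random.key(1701)

# One subkey per hypothetical worker: row i of the split is worker i's key.
workers = 4
keys = Nx.Random.split(key, parts: workers)

draws =
  for i <- 0..(workers - 1) do
    {draw, _new_key} = Nx.Random.uniform(keys[i], shape: {3})
    draw
  end

IO.inspect(draws)
```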

threefry2x32_20_concat(xs, ks)

uniform(key, opts \\ [])

Shortcut for uniform(key, 0.0, 1.0, opts).

uniform(key, min_val, max_val, opts \\ [])

Samples uniform float values in [min_val, max_val), returning a {tensor, new_key} tuple.

Options

  • :type - a float or complex type for the returned tensor

  • :shape - shape of the returned tensor

  • :names - the names of the returned tensor

Examples

iex> key = Nx.Random.key(1701)
iex> {uniform, _new_key} = Nx.Random.uniform(key)
iex> uniform
#Nx.Tensor<
  f32
  0.9728643894195557
>

iex> key = Nx.Random.key(1701)
iex> {uniform, _new_key} = Nx.Random.uniform(key, shape: {3, 2}, type: :f16)
iex> uniform
#Nx.Tensor<
  f16[3][2]
  [
    [0.75390625, 0.6484375],
    [0.7294921875, 0.21484375],
    [0.09765625, 0.0693359375]
  ]
>

iex> key = Nx.Random.key(1701)
iex> {uniform, _new_key} = Nx.Random.uniform(key, shape: {2, 2}, type: :c64)
iex> uniform
#Nx.Tensor<
  c64[2][2]
  [
    [0.18404805660247803+0.6546461582183838i, 0.5525915622711182+0.11568140983581543i],
    [0.6074584722518921+0.8104375600814819i, 0.247686505317688+0.21975469589233398i]
  ]
>
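
Here is a minimal sketch, assuming Nx 0.4.x via Mix.install, that samples from a custom interval and checks the documented shortcut: with the same key, uniform(key, opts) should match uniform(key, 0.0, 1.0, opts).

```elixir
# Assumes Elixir >= 1.12 (Mix.install) and network access to Hex.
Mix.install([{:nx, "~> 0.4.1"}])

key = Nx.Random.key(1701)

# Samples in [-1.0, 1.0): min_val is inclusive, max_val is exclusive.
{u, _new_key} = Nx.Random.uniform(key, -1.0, 1.0, shape: {1_000})

# uniform/2 is documented as a shortcut for uniform(key, 0.0, 1.0, opts),
# so the same key should give identical tensors.
{a, _} = Nx.Random.uniform(key, shape: {3})
{b, _} = Nx.Random.uniform(key, 0.0, 1.0, shape: {3})

IO.inspect(Nx.to_number(Nx.reduce_min(u)), label: "min")
IO.inspect(Nx.to_number(Nx.reduce_max(u)), label: "max")
IO.inspect(Nx.to_flat_list(a) == Nx.to_flat_list(b), label: "shortcut matches?")
```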

uniform_split(key, min_value, max_value, opts \\ [])

Same as uniform/4 but assumes the key has already been split.