# `Nx.Random`

Pseudo-random number generators.

Unlike the stateful pseudo-random number generators (PRNGs)
that users of most programming languages and numerical libraries
may be accustomed to, Nx random functions require an explicit
PRNG key to be passed as the first argument (see [below](#module-the-key-to-understanding-nx-random-keys) for more info). That key is
an `Nx.Tensor` composed of 2 unsigned 32-bit integers, usually
generated by the `Nx.Random.key/1` function:

    iex> Nx.Random.key(12)
    #Nx.Tensor<
      u32[2]
      [0, 12]
    >

To seed from the current time instead (results then differ from run to run):

    iex> Nx.Random.key(System.os_time())

This key can then be used in any of Nx’s random number generation
routines:

    iex> key = Nx.Random.key(12)
    iex> {uniform, _new_key} = Nx.Random.uniform(key)
    iex> uniform
    #Nx.Tensor<
      f32
      0.7691127
    >

To generate another random number, pass the returned new key (bound
to `_new_key` above) to the next call. Passing the same key again
returns the same number.
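
A minimal sketch of threading the key through repeated calls, assuming Nx is installed with `Mix.install/1`. The `Enum.map_reduce/3` loop is just one illustrative way to carry the key along; it is not part of the `Nx.Random` API:

```elixir
Mix.install([{:nx, "~> 0.11"}])

key = Nx.Random.key(12)

# Each call returns {sample, new_key}; the new key is fed into the next call.
{samples, _last_key} =
  Enum.map_reduce(1..3, key, fn _step, key ->
    {value, new_key} = Nx.Random.uniform(key)
    {Nx.to_number(value), new_key}
  end)

# Reusing the same key reproduces the same sample.
{first, _} = Nx.Random.uniform(key)
{again, _} = Nx.Random.uniform(key)
same_sample = Nx.to_number(first) == Nx.to_number(again)
```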

Several functions in this module also have a `*_split` variant, which
assumes the key has already been split beforehand and returns only
the sampled tensor, without a new key.
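
A sketch of splitting a key once and handing the parts to `*_split` variants, assuming (as their "assumes the key has already been split" docs suggest) that these variants take the key, the distribution parameters and options, and return only the sampled tensor:

```elixir
Mix.install([{:nx, "~> 0.11"}])

key = Nx.Random.key(42)

# Split once up front: a {2, 2} tensor holding two independent keys.
keys = Nx.Random.split(key)

# Each *_split call consumes an already-split key and returns only the sample.
uniform_sample = Nx.Random.uniform_split(keys[0], 0.0, 1.0, shape: {3})
normal_sample = Nx.Random.normal_split(keys[1], 0.0, 1.0, shape: {3})
```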

## Design and Context

In short, Nx's PRNGs are based on a Threefry counter PRNG
associated with a functional, array-oriented splitting model.
Among other requirements, Nx's PRNG aims to:

1. Ensure reproducibility

2. Parallelize well, both in terms of vectorization
   (generating array values) and multi-replica, multi-core
   computation. In particular, it should not impose sequencing
   constraints between random function calls.
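
To illustrate the lack of sequencing constraints, here is a sketch that samples with four keys split from one seed, once sequentially and once concurrently in reverse order with `Task.async_stream/2`. Since each result depends only on its key, both runs should produce identical values:

```elixir
Mix.install([{:nx, "~> 0.11"}])

# Four independent keys derived from one seed.
keys = Nx.Random.split(Nx.Random.key(7), parts: 4)

sample = fn i ->
  keys[i]
  |> Nx.Random.normal_split(0.0, 1.0, shape: {2})
  |> Nx.to_flat_list()
end

# Sequential, in index order.
sequential = Enum.map(0..3, sample)

# Concurrent, submitted in reverse order; async_stream keeps input order,
# so reversing the results lines them up with `sequential`.
concurrent =
  [3, 2, 1, 0]
  |> Task.async_stream(sample)
  |> Enum.map(fn {:ok, values} -> values end)
  |> Enum.reverse()
```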

> ### The key to understanding Nx.Random keys {: .info}
>
> Most Elixir users are used to not having to keep
> track of the PRNG state while their code executes.
>
> While implicit state works fine on the CPU, keeping track of
> the `Nx.Random` key can be thought of as a way to isolate
> multiple GPU users, much like the PRNG state on different
> BEAM nodes is isolated. Each key is updated in its own
> isolated sequence of calls, so processes that use the same key
> get the same results, rather than results that depend on how
> their calls interleave, as would happen with a shared stateful PRNG.
>
> The fact that the key is a parameter for the functions also
> helps with the caching and operator fusion of the computational
> graphs. Because the PRNG functions themselves are stateless,
> compilers can take advantage of this to further improve execution times.

# `choice/3`

Generates random samples from a tensor.

## Options

  * `:samples` - the number of samples to take. Defaults to `1`.

  * `:axis` - the axis along which to take samples.
    If `nil` (the default), the tensor is flattened beforehand.

  * `:replace` - a boolean that specifies if samples will
    be taken with or without replacement. Defaults to `true`.

## Examples

    iex> k = Nx.Random.key(1)
    iex> t = Nx.iota({4, 3})
    iex> {result, _key} = Nx.Random.choice(k, t, samples: 4, axis: 0)
    iex> result
    #Nx.Tensor<
      s32[4][3]
      [
        [6, 7, 8],
        [9, 10, 11],
        [6, 7, 8],
        [0, 1, 2]
      ]
    >
    iex> {result, _key} = Nx.Random.choice(k, t, samples: 4, axis: 0, replace: false)
    iex> result
    #Nx.Tensor<
      s32[4][3]
      [
        [3, 4, 5],
        [9, 10, 11],
        [6, 7, 8],
        [0, 1, 2]
      ]
    >

If no axis is specified, the tensor is flattened:

    iex> k = Nx.Random.key(2)
    iex> t = Nx.iota({3, 2})
    iex> {result, _key} = Nx.Random.choice(k, t)
    iex> result
    #Nx.Tensor<
      s32[1]
      [4]
    >
    iex> {result, _key} = Nx.Random.choice(k, t, samples: 6, replace: false)
    iex> result
    #Nx.Tensor<
      s32[6]
      [2, 0, 4, 5, 1, 3]
    >
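
Combining `:axis` and `replace: false`, a sketch of drawing a mini-batch of distinct rows (the data and batch size are illustrative):

```elixir
Mix.install([{:nx, "~> 0.11"}])

key = Nx.Random.key(3)
# Ten rows: [0, 1], [2, 3], ..., [18, 19].
data = Nx.iota({10, 2})

# Draw 4 distinct rows by sampling along axis 0 without replacement.
{batch, _next_key} = Nx.Random.choice(key, data, samples: 4, axis: 0, replace: false)
rows = Nx.to_list(batch)
```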

# `choice/4`

Generates random samples from a tensor with specified probabilities.

The probabilities tensor must have the same size as the axis along
which the samples are being taken. If no axis is given, the size
must be equal to the input tensor's size.

## Options

  * `:samples` - the number of samples to take. Defaults to `1`.

  * `:axis` - the axis along which to take samples.
    If `nil` (the default), the tensor is flattened beforehand.

  * `:replace` - a boolean that specifies if samples will
    be taken with or without replacement. Defaults to `true`.

## Examples

    iex> k = Nx.Random.key(1)
    iex> t = Nx.iota({4, 3})
    iex> p = Nx.tensor([0.1, 0.7, 0.2])
    iex> {result, _key} = Nx.Random.choice(k, t, p, samples: 3, axis: 1)
    iex> result
    #Nx.Tensor<
      s32[4][3]
      [
        [1, 0, 1],
        [4, 3, 4],
        [7, 6, 7],
        [10, 9, 10]
      ]
    >
    iex> {result, _key} = Nx.Random.choice(k, t, p, samples: 3, axis: 1, replace: false)
    iex> result
    #Nx.Tensor<
      s32[4][3]
      [
        [1, 2, 0],
        [4, 5, 3],
        [7, 8, 6],
        [10, 11, 9]
      ]
    >

If no axis is specified, the tensor is flattened.
Notice that when sampling with replacement, the entries with
larger probabilities occur more often, while without replacement
they tend to be drawn first.

    iex> k = Nx.Random.key(2)
    iex> t = Nx.iota({2, 3})
    iex> p = Nx.tensor([0.01, 0.1, 0.19, 0.6, 0.05, 0.05])
    iex> {result, _key} = Nx.Random.choice(k, t, p)
    iex> result
    #Nx.Tensor<
      s32[1]
      [3]
    >
    iex> {result, _key} = Nx.Random.choice(k, t, p, samples: 6)
    iex> result
    #Nx.Tensor<
      s32[6]
      [3, 3, 3, 0, 3, 3]
    >
    iex> {result, _key} = Nx.Random.choice(k, t, p, samples: 6, replace: false)
    iex> result
    #Nx.Tensor<
      s32[6]
      [3, 1, 2, 5, 4, 0]
    >
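
A quick empirical check, as a sketch: with replacement and enough samples, the observed frequency of each value should approach its probability. The sample count and the tolerance used in the check are illustrative choices, not guarantees:

```elixir
Mix.install([{:nx, "~> 0.11"}])

key = Nx.Random.key(0)
values = Nx.tensor([0, 1, 2])
p = Nx.tensor([0.1, 0.7, 0.2])

{draws, _next_key} = Nx.Random.choice(key, values, p, samples: 5_000)

# Empirical frequency of each value.
freqs =
  for v <- 0..2 do
    draws |> Nx.equal(v) |> Nx.mean() |> Nx.to_number()
  end
```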

# `fold_in`

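Folds new data into a PRNG key, returning a new key. When the data
is a tensor, one key is returned per element, so the result has the
data's shape followed by a trailing axis of size 2.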

## Examples

    iex> key = Nx.Random.key(42)
    iex> Nx.Random.fold_in(key, 99)
    #Nx.Tensor<
      u32[2]
      [2015327502, 1351855566]
    >

    iex> key = Nx.Random.key(42)
    iex> Nx.Random.fold_in(key, 1234)
    #Nx.Tensor<
      u32[2]
      [1356445167, 2917756949]
    >

    iex> key = Nx.Random.key(42)
    iex> Nx.Random.fold_in(key, Nx.tensor([[1, 99], [1234, 13]]))
    #Nx.Tensor<
      u32[2][2][2]
      [
        [
          [64467757, 2916123636],
          [2015327502, 1351855566]
        ],
        [
          [1356445167, 2917756949],
          [3514951389, 229662949]
        ]
      ]
    >
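
A common pattern, sketched here with a hypothetical "one key per worker" framing, is deriving several keys from a single base key by folding in an index. Folding in the same data always yields the same key:

```elixir
Mix.install([{:nx, "~> 0.11"}])

base_key = Nx.Random.key(42)

# One key per (hypothetical) worker, derived by folding in its index.
worker_keys = Enum.map(0..3, fn i -> Nx.Random.fold_in(base_key, i) end)

# Deterministic: folding in 99 reproduces the key from the first example above.
key_99 = base_key |> Nx.Random.fold_in(99) |> Nx.to_flat_list()
```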

# `gumbel`

Sample Gumbel random values with given shape and float dtype.

## Options

  * `:shape` - the shape of the output tensor containing the
    random samples. Defaults to `{}`

  * `:type` - the floating-point output type. Defaults to `{:f, 32}`

## Examples

    iex> {result, _key} = Nx.Random.gumbel(Nx.Random.key(1))
    iex> result
    #Nx.Tensor<
      f32
      -0.729461
    >

    iex> {result, _key} = Nx.Random.gumbel(Nx.Random.key(1), shape: {2, 3})
    iex> result
    #Nx.Tensor<
      f32[2][3]
      [
        [0.6247938, -0.21740718, 0.76783276],
        [0.77784044, 4.0895305, 0.30290902]
      ]
    >
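
One use of Gumbel samples is the Gumbel-max trick: adding independent standard Gumbel noise to log-probabilities and taking the argmax draws from the corresponding categorical distribution. The sketch below assumes `gumbel/2` samples the standard Gumbel(0, 1) distribution; the probabilities and sample count are illustrative:

```elixir
Mix.install([{:nx, "~> 0.11"}])

key = Nx.Random.key(5)
p = Nx.tensor([0.1, 0.7, 0.2])

# Gumbel-max trick: argmax(log(p) + g), with g ~ Gumbel(0, 1),
# is distributed according to the categorical distribution p.
{g, _next_key} = Nx.Random.gumbel(key, shape: {5_000, 3})
draws = Nx.argmax(Nx.add(Nx.log(p), g), axis: 1)

freq_of_1 = draws |> Nx.equal(1) |> Nx.mean() |> Nx.to_number()
```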

# `gumbel_split`

Same as `gumbel/2`, but assumes the key has been split beforehand.

# `key`

Create a pseudo-random number generator (PRNG) key given an integer seed.

## Examples

    iex> Nx.Random.key(12)
    #Nx.Tensor<
      u32[2]
      [0, 12]
    >

    iex> Nx.Random.key(999999999999)
    #Nx.Tensor<
      u32[2]
      [232, 3567587327]
    >
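
The two examples above are consistent with the key storing the upper and lower 32 bits of the seed (232 * 2^32 + 3567587327 = 999999999999). A small check of that reading for the second example, using `Bitwise`:

```elixir
Mix.install([{:nx, "~> 0.11"}])
import Bitwise

seed = 999_999_999_999
[high, low] = seed |> Nx.Random.key() |> Nx.to_flat_list()

# Compare against the upper and lower 32 bits of the seed.
matches_halves = {high, low} == {seed >>> 32, seed &&& 0xFFFFFFFF}
```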

# `multivariate_normal`

Returns a sample from a multivariate normal distribution with the given `mean` vector and `covariance` matrix.
The function assumes that the covariance is a positive semi-definite matrix;
otherwise, the result will not be normally distributed.

## Options

  * `:type` - a float type for the returned tensor

  * `:shape` - batch shape of the returned tensor, i.e. the prefix of the result shape excluding the last axis

  * `:names` - the names of the returned tensor

  * `:method` - the decomposition method used for the covariance. Must be one of `:svd`, `:eigh`, or `:cholesky`.
    Defaults to `:cholesky`. For singular covariance matrices, use `:svd` or `:eigh`.

## Examples

    iex> key = Nx.Random.key(12)
    iex> {multivariate_normal, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0]), Nx.tensor([[1]]))
    iex> multivariate_normal
    #Nx.Tensor<
      f32[1]
      [0.73592794]
    >

    iex> key = Nx.Random.key(12)
    iex> {multivariate_normal, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0, 0]), Nx.tensor([[1, 0], [0, 1]]))
    iex> multivariate_normal
    #Nx.Tensor<
      f32[2]
      [-1.3425945, -0.4081206]
    >

    iex> key = Nx.Random.key(12)
    iex> {multivariate_normal, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0]), Nx.tensor([[1]]), shape: {3, 2}, type: :f16)
    iex> multivariate_normal
    #Nx.Tensor<
      f16[3][2][1]
      [
        [
          [0.327],
          [0.2177]
        ],
        [
          [0.3167],
          [0.11096]
        ],
        [
          [0.5396],
          [-0.8857]
        ]
      ]
    >

    iex> key = Nx.Random.key(12)
    iex> {multivariate_normal, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0, 0]), Nx.tensor([[1, 0], [0, 1]]), shape: {3, 2})
    iex> multivariate_normal
    #Nx.Tensor<
      f32[3][2][2]
      [
        [
          [0.989145, 1.0795186],
          [-0.9467806, 1.4781388]
        ],
        [
          [2.2095864, -1.5294564],
          [-0.79339206, 1.1211957]
        ],
        [
          [0.10976296, -0.9959557],
          [0.4754556, 1.1413804]
        ]
      ]
    >
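
A sketch that draws many samples from a correlated 2-D normal and compares the sample mean and sample covariance with the inputs; the mean, covariance, and sample count are illustrative:

```elixir
Mix.install([{:nx, "~> 0.11"}])

key = Nx.Random.key(0)
mean = Nx.tensor([1.0, -1.0])
cov = Nx.tensor([[1.0, 0.8], [0.8, 1.0]])

# Batch shape {5_000} gives 5_000 draws of a 2-dimensional vector: {5_000, 2}.
{x, _next_key} = Nx.Random.multivariate_normal(key, mean, cov, shape: {5_000})

# Sample mean and unbiased sample covariance (divide by n - 1).
sample_mean = Nx.mean(x, axes: [0])
centered = Nx.subtract(x, sample_mean)
sample_cov = centered |> Nx.transpose() |> Nx.dot(centered) |> Nx.divide(4_999)
```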

# `multivariate_normal_split`

Same as `multivariate_normal/4` but assumes the key has already been split.

# `normal/2`

Shortcut for `normal(key, 0.0, 1.0, opts)`.

# `normal/4`

Samples from a normal distribution with the given `mean` and `standard_deviation`.

## Options

  * `:type` - a float or complex type for the returned tensor

  * `:shape` - shape of the returned tensor

  * `:names` - the names of the returned tensor

## Examples

    iex> key = Nx.Random.key(42)
    iex> {normal, _new_key} = Nx.Random.normal(key)
    iex> normal
    #Nx.Tensor<
      f32
      1.3694695
    >

    iex> key = Nx.Random.key(42)
    iex> {normal, _new_key} = Nx.Random.normal(key, 0, 1, shape: {3, 2}, type: :f16)
    iex> normal
    #Nx.Tensor<
      f16[3][2]
      [
        [-0.3257, -0.772],
        [0.392, 0.534],
        [0.2708, -2.08]
      ]
    >

    iex> key = Nx.Random.key(42)
    iex> {normal, _new_key} = Nx.Random.normal(key, 0, 1, shape: {2, 2}, type: :c64)
    iex> normal
    #Nx.Tensor<
      c64[2][2]
      [
        [-0.7632761+0.86611277i, -0.1428289-0.7384796i],
        [0.6784615+0.41183108i, -2.2695384-0.3689095i]
      ]
    >

    iex> key = Nx.Random.key(1337)
    iex> {normal, _new_key} = Nx.Random.normal(key, 10, 5, shape: {1_000})
    iex> Nx.mean(normal)
    #Nx.Tensor<
      f32
      9.700222
    >
    iex> Nx.standard_deviation(normal)
    #Nx.Tensor<
      f32
      5.0514164
    >

# `normal_split`

Same as `normal/4` but assumes the key has already been split.

# `randint`

Sample uniform random integer values in the semi-open interval `[min_value, max_value)`.

## Options

  * `:type` - the integer type for the returned tensor
  * `:shape` - shape of the returned tensor
  * `:names` - the names of the returned tensor

## Examples

    iex> key = Nx.Random.key(1701)
    iex> {randint, _new_key} = Nx.Random.randint(key, 1, 100)
    iex> randint
    #Nx.Tensor<
      s32
      91
    >

    iex> key = Nx.Random.key(1701)
    iex> {randint, _new_key} = Nx.Random.randint(key, 1, 100, shape: {3, 2}, type: :u32)
    iex> randint
    #Nx.Tensor<
      u32[3][2]
      [
        [9, 20],
        [19, 6],
        [71, 15]
      ]
    >
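
Because the upper bound is exclusive, simulating a six-sided die needs `max_value = 7`. A short sketch:

```elixir
Mix.install([{:nx, "~> 0.11"}])

key = Nx.Random.key(1701)

# The interval is [1, 7), so the possible values are 1..6.
{rolls, _next_key} = Nx.Random.randint(key, 1, 7, shape: {1_000})
faces = rolls |> Nx.to_flat_list() |> Enum.uniq() |> Enum.sort()
```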

# `randint_split`

Same as `randint/4` but assumes the key has already been split.

# `shuffle`

Randomly shuffles tensor elements along an axis.

## Options

  * `:axis` - the axis along which to shuffle. Defaults to `0`

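  * `:independent` - a boolean that indicates whether each slice along
    the given axis is shuffled with its own permutation (`true`) or a
    single permutation of the axis is applied to the whole tensor (`false`).
    Defaults to `false`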

## Examples

    iex> key = Nx.Random.key(42)
    iex> {shuffled, _new_key} = Nx.Random.shuffle(key, Nx.iota({3, 4}, axis: 0))
    iex> shuffled
    #Nx.Tensor<
      s32[3][4]
      [
        [2, 2, 2, 2],
        [0, 0, 0, 0],
        [1, 1, 1, 1]
      ]
    >

    iex> key = Nx.Random.key(10)
    iex> {shuffled, _new_key} = Nx.Random.shuffle(key, Nx.iota({3, 4}, axis: 1), independent: true, axis: 1)
    iex> shuffled
    #Nx.Tensor<
      s32[3][4]
      [
        [2, 1, 3, 0],
        [3, 0, 1, 2],
        [2, 3, 0, 1]
      ]
    >
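
A sketch that checks the independent shuffle from the second example: sorting each row recovers the original tensor, so each row is a permutation of its original values:

```elixir
Mix.install([{:nx, "~> 0.11"}])

key = Nx.Random.key(10)
t = Nx.iota({3, 4}, axis: 1)

{shuffled, _next_key} = Nx.Random.shuffle(key, t, axis: 1, independent: true)

# 1 if sorting every row gives back the original rows, 0 otherwise.
all_equal =
  shuffled
  |> Nx.sort(axis: 1)
  |> Nx.equal(t)
  |> Nx.all()
  |> Nx.to_number()

is_permutation = all_equal == 1
```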

# `split`

Splits a PRNG key into `:parts` new keys (`2` by default) by adding a leading axis.

## Examples

    iex> key = Nx.Random.key(1701)
    iex> Nx.Random.split(key)
    #Nx.Tensor<
      u32[2][2]
      [
        [56197195, 1801093307],
        [961309823, 1704866707]
      ]
    >

    iex> key = Nx.Random.key(1701)
    iex> Nx.Random.split(key, parts: 4)
    #Nx.Tensor<
      u32[4][2]
      [
        [4000152724, 2030591747],
        [2287976877, 2598630646],
        [2426625787, 580268518],
        [3136765380, 433355682]
      ]
    >
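
A sketch of a typical use of `:parts`: one split yields an independent key per consumer, here a list of hypothetical layer weight shapes, each initialized with `normal_split/4`:

```elixir
Mix.install([{:nx, "~> 0.11"}])

# Hypothetical layer weight shapes for a small network.
layer_shapes = [{4, 8}, {8, 8}, {8, 2}]

# One split produces one independent key per layer: a {3, 2} tensor.
keys = Nx.Random.split(Nx.Random.key(1701), parts: length(layer_shapes))

weights =
  layer_shapes
  |> Enum.with_index()
  |> Enum.map(fn {shape, i} ->
    Nx.Random.normal_split(keys[i], 0.0, 0.1, shape: shape)
  end)
```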

# `threefry2x32_20_concat`

# `uniform/2`

Shortcut for `uniform(key, 0.0, 1.0, opts)`.

# `uniform/4`

Sample uniform float values in the semi-open interval `[min_val, max_val)`.

## Options

  * `:type` - a float or complex type for the returned tensor

  * `:shape` - shape of the returned tensor

  * `:names` - the names of the returned tensor

## Examples

    iex> key = Nx.Random.key(1701)
    iex> {uniform, _new_key} = Nx.Random.uniform(key)
    iex> uniform
    #Nx.Tensor<
      f32
      0.9728644
    >

    iex> key = Nx.Random.key(1701)
    iex> {uniform, _new_key} = Nx.Random.uniform(key, shape: {3, 2}, type: :f16)
    iex> uniform
    #Nx.Tensor<
      f16[3][2]
      [
        [0.754, 0.6484],
        [0.7295, 0.2148],
        [0.09766, 0.06934]
      ]
    >

    iex> key = Nx.Random.key(1701)
    iex> {uniform, _new_key} = Nx.Random.uniform(key, shape: {2, 2}, type: :c64)
    iex> uniform
    #Nx.Tensor<
      c64[2][2]
      [
        [0.18404806+0.65464616i, 0.55259156+0.11568141i],
        [0.6074585+0.81043756i, 0.2476865+0.2197547i]
      ]
    >
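
`uniform` can also serve as a building block for other distributions. As an illustration (inverse-transform sampling, which is not a function of this module), an exponential distribution with rate λ can be sampled as `-log(1 - u) / λ`; the rate and sample count below are illustrative:

```elixir
Mix.install([{:nx, "~> 0.11"}])

key = Nx.Random.key(9)
rate = 2.0

{u, _next_key} = Nx.Random.uniform(key, shape: {5_000})

# Inverse-transform sampling: if u ~ U[0, 1), then -log(1 - u) / rate is
# exponentially distributed. log1p(-u) equals log(1 - u) and stays finite
# because u < 1.
samples = u |> Nx.negate() |> Nx.log1p() |> Nx.negate() |> Nx.divide(rate)

# The mean of an exponential distribution is 1 / rate.
sample_mean = samples |> Nx.mean() |> Nx.to_number()
```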

# `uniform_split`

Same as `uniform/4` but assumes the key has already been split.

