# `EMLXAxon.Qwen3.Sampler`
[🔗](https://github.com/elixir-nx/emlx/blob/v0.3.0/emlx_axon/lib/emlx_axon/qwen3/sampler.ex#L1)

Three sampling strategies for autoregressive token generation.

All functions accept `logits` of shape `{1, vocab_size}` and return a
scalar integer tensor (the sampled token id).

- `greedy/1`     — deterministic argmax, fastest
- `top_p_cpu/3`  — faithful port of bobby_posts: logits → CPU → sort → sample
- `top_p_gpu/2`  — `defn`-compiled top-p on the GPU; avoids the host transfer

# `greedy`

Greedy decoding: return the token with the highest logit as a scalar tensor.
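As a reference point, greedy decoding reduces to an argmax over the logits. A NumPy sketch (illustrative only; the module's Nx version operates on `{1, vocab_size}` tensors and returns a scalar tensor):

```python
import numpy as np

def greedy_sample(logits):
    """Greedy decoding sketch: the sampled token is simply the argmax."""
    return int(np.argmax(logits))
```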

# `top_p_cpu`

CPU top-p (nucleus) sampler — faithful port of bobby_posts.

Transfers the full vocabulary logits to the BEAM for sort + sample.
This matches the A0 baseline timing (expected ~42 ms overhead per token).
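The sort-then-sample procedure this port follows can be sketched in NumPy (function and parameter names here are illustrative, not the module's API):

```python
import numpy as np

def top_p_sample(logits, p=0.9, temperature=0.95, rng=None):
    """Nucleus (top-p) sampling sketch: keep the smallest set of tokens
    whose cumulative probability reaches p, then sample from that set."""
    rng = rng or np.random.default_rng()
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(scaled - np.max(scaled))      # stable softmax
    probs /= probs.sum()
    order = np.argsort(-probs)                   # sort descending by probability
    sorted_probs = probs[order]
    cum = np.cumsum(sorted_probs)
    cutoff = np.searchsorted(cum, p) + 1         # smallest nucleus covering mass p
    nucleus = order[:cutoff]
    nucleus_probs = sorted_probs[:cutoff] / cum[cutoff - 1]  # renormalise
    return int(rng.choice(nucleus, p=nucleus_probs))
```

Note that the sort and cumulative sum run over the full vocabulary, which is why doing this on the BEAM adds per-token host-transfer overhead.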

# `top_p_gpu`

GPU temperature sampler compiled with `defn`.

Uses the Gumbel-max trick: `argmax(logits/temp + gumbel_noise)` draws a
sample distributed according to `softmax(logits/temp)`. This is mathematically
equivalent to categorical sampling but needs no vocabulary-wide sort, which
avoids the MLX argsort kernel limitation at vocab_size = 151936.

Note: this implements temperature sampling without the nucleus (top-p) cutoff.
Top-p filtering requires an argsort of a 151k-element tensor, which hits an
unsupported MLX Metal kernel at this vocab size. Temperature sampling provides
comparable randomisation and exercises the GPU sampler path for benchmarking.
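The Gumbel-max equivalence can be checked numerically. A NumPy sketch, independent of the Elixir/Nx implementation (names are illustrative):

```python
import numpy as np

def gumbel_max_sample(logits, temperature=0.95, rng=None):
    """Gumbel-max trick: argmax(logits/T + g) with g ~ Gumbel(0, 1) is
    distributed as Categorical(softmax(logits/T)) -- no sort required."""
    rng = rng or np.random.default_rng()
    u = rng.uniform(size=np.shape(logits))
    g = -np.log(-np.log(u))                      # standard Gumbel noise
    return int(np.argmax(np.asarray(logits) / temperature + g))

# Empirical check: sample frequencies should approach softmax(logits/T).
rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.0])
counts = np.bincount(
    [gumbel_max_sample(logits, temperature=1.0, rng=rng) for _ in range(20_000)],
    minlength=3,
)
target = np.exp(logits) / np.exp(logits).sum()
```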

`key` must be a `Nx.Random.key/1` value passed as a positional argument
(not through opts) so `defn` can trace it correctly.

## Options
- `:temperature` — float, default 0.95

---

*Consult [api-reference.md](api-reference.md) for a complete listing.*
