Three sampling strategies for autoregressive token generation.
All functions accept logits of shape {1, vocab_size} and return a
scalar integer tensor (the sampled token id).
- greedy/1 — deterministic argmax; fastest
- top_p_cpu/3 — faithful port of bobby_posts: logits → CPU → sort → sample
- top_p_gpu/2 — defn-compiled GPU temperature sampler (Gumbel-max); avoids the host transfer
Summary
Functions
greedy/1 — Greedy decoding: return the token with the highest logit as a scalar tensor.
top_p_cpu/3 — CPU top-p (nucleus) sampler — faithful port of bobby_posts.
top_p_gpu/2 — GPU temperature sampler compiled with defn.
Functions
greedy/1

Greedy decoding: return the token with the highest logit as a scalar tensor.
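As a language-agnostic sketch (NumPy here; the module itself operates on Nx tensors), greedy decoding is a single argmax over the vocabulary axis:

```python
import numpy as np

def greedy(logits):
    """Pick the highest-logit token id from a {1, vocab_size} array."""
    return int(np.argmax(logits, axis=-1)[0])

print(greedy(np.array([[0.1, 2.5, -1.0]])))  # prints 1: index of the largest logit
```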
top_p_cpu/3

CPU top-p (nucleus) sampler — faithful port of bobby_posts.
Transfers the full vocabulary logits to the BEAM for sort + sample. This matches the A0 baseline timing (expected ~42 ms overhead per token).
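For reference, the sort → cumulative-probability → sample pipeline being ported can be sketched in NumPy (the function name and signature here are illustrative, not the module's API):

```python
import numpy as np

def top_p_sample(logits, p=0.9, temperature=1.0, rng=None):
    """Nucleus sampling: keep the smallest set of tokens whose
    probability mass reaches p, then sample from that set."""
    if rng is None:
        rng = np.random.default_rng()
    scaled = logits / temperature
    order = np.argsort(scaled)[::-1]            # sort descending (the CPU-side sort)
    probs = np.exp(scaled[order] - scaled[order].max())
    probs /= probs.sum()
    keep = np.searchsorted(np.cumsum(probs), p) + 1  # smallest nucleus with mass >= p
    nucleus = probs[:keep] / probs[:keep].sum()      # renormalise inside the nucleus
    return int(order[rng.choice(keep, p=nucleus)])

# With one dominant logit and p = 0.5 the nucleus holds only that token.
print(top_p_sample(np.array([0.0, 10.0, 0.0]), p=0.5))  # prints 1
```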
top_p_gpu/2

GPU temperature sampler compiled with defn.
Uses the Gumbel-max trick: argmax(logits/temp + Gumbel_noise) draws a
sample proportional to softmax(logits/temp). This is mathematically
equivalent to categorical sampling without any vocabulary-wide sort, which
avoids the MLX argsort kernel limitation at vocab_size=151936.
Note: this implements temperature sampling without a nucleus (top-p) cutoff. Top-p filtering requires an argsort over a 151k-element tensor, which hits an unsupported MLX Metal kernel at this vocab size. Temperature sampling provides comparable randomisation and still exercises the GPU sampler path for benchmarking.
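The Gumbel-max equivalence can be checked empirically. The sketch below (NumPy, standing in for the Nx defn version) draws many argmax(logits/temp + Gumbel) samples and compares the empirical frequencies with softmax(logits/temp):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
logits = np.array([1.0, 2.0, 3.0])
temperature = 0.95
n = 200_000

# Gumbel-max trick: add standard Gumbel(0, 1) noise, take the argmax.
noise = rng.gumbel(size=(n, logits.size))
samples = np.argmax(logits / temperature + noise, axis=1)

empirical = np.bincount(samples, minlength=logits.size) / n
# Max deviation from softmax(logits/temp) is small (on the order of 1e-3 here).
print(np.abs(empirical - softmax(logits / temperature)).max())
```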
key must be an Nx.Random.key/1 value passed as a positional argument
(not through opts) so defn can trace it correctly.
Options
- :temperature — float, default 0.95