EMLXAxon.Qwen3.Sampler (emlx_axon v0.3.0)


Three sampling strategies for autoregressive token generation.

All functions accept logits of shape {1, vocab_size} and return a scalar integer tensor (the sampled token id).

  • greedy/1 — deterministic argmax, fastest
  • top_p_cpu/3 — faithful port of bobby_posts: logits → CPU → sort → sample
  • top_p_gpu/2 — defn-compiled temperature sampler on the GPU; avoids the host transfer

Summary

Functions

greedy(logits)

Greedy decoding: return the token with the highest logit as a scalar tensor.

top_p_cpu(logits, temperature \\ 0.95, top_p \\ 0.9)

CPU top-p (nucleus) sampler — faithful port of bobby_posts.

top_p_gpu(logits, key, opts \\ [])

GPU temperature sampler compiled with defn.
Functions

greedy(logits)

Greedy decoding: return the token with the highest logit as a scalar tensor.

top_p_cpu(logits, temperature \\ 0.95, top_p \\ 0.9)

CPU top-p (nucleus) sampler — faithful port of bobby_posts.

Transfers the full vocabulary logits to the BEAM for sort + sample. This matches the A0 baseline timing (expected ~42 ms overhead per token).
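As an illustration of the sort-and-sample steps, here is a plain-Elixir sketch over lists (the `NucleusSketch` module name is hypothetical; the real `top_p_cpu/3` operates on Nx tensors and transfers them to the host first):

```elixir
defmodule NucleusSketch do
  # Temperature-scaled softmax over a plain list of logits.
  def softmax(logits, temperature) do
    scaled = Enum.map(logits, &(&1 / temperature))
    max = Enum.max(scaled)
    exps = Enum.map(scaled, &:math.exp(&1 - max))
    sum = Enum.sum(exps)
    Enum.map(exps, &(&1 / sum))
  end

  # Keep the smallest set of highest-probability tokens whose cumulative
  # probability reaches top_p, then sample from that renormalised set.
  def sample(logits, temperature, top_p) do
    probs = softmax(logits, temperature)

    sorted =
      probs
      |> Enum.with_index()
      |> Enum.sort_by(fn {p, _i} -> -p end)

    {nucleus, _cum} =
      Enum.reduce_while(sorted, {[], 0.0}, fn {p, i}, {acc, cum} ->
        cum = cum + p
        acc = [{p, i} | acc]
        if cum >= top_p, do: {:halt, {acc, cum}}, else: {:cont, {acc, cum}}
      end)

    total = nucleus |> Enum.map(&elem(&1, 0)) |> Enum.sum()
    r = :rand.uniform() * total

    {_, token} =
      Enum.reduce_while(nucleus, {r, nil}, fn {p, i}, {rem, _} ->
        if rem <= p, do: {:halt, {0.0, i}}, else: {:cont, {rem - p, i}}
      end)

    token
  end
end
```

With a sharply peaked distribution the nucleus collapses to a single token, so the result is deterministic, e.g. `NucleusSketch.sample([1.0, 2.0, 10.0, 0.5], 0.95, 0.9)` returns `2`.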

top_p_gpu(logits, key, opts \\ [])

GPU temperature sampler compiled with defn.

Uses the Gumbel-max trick: argmax(logits/temp + Gumbel_noise) draws a sample proportional to softmax(logits/temp). This is mathematically equivalent to categorical sampling without any vocabulary-wide sort, which avoids the MLX argsort kernel limitation at vocab_size=151936.

Note: this implements temperature sampling without a nucleus (top-p) cutoff. Top-p filtering would require an argsort over a 151k-element tensor, which hits an unsupported MLX Metal kernel at this vocab size. Temperature sampling provides comparable randomisation while still exercising the GPU sampler path for benchmarking.

key must be a Nx.Random.key/1 value passed as a positional argument (not through opts) so defn can trace it correctly.
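The trick can be demonstrated in plain Elixir over lists (the `GumbelSketch` module is a hypothetical illustration; the real sampler runs inside defn with Nx.Random keys). Each call perturbs the scaled logits with independent Gumbel(0, 1) noise and takes the argmax, which yields one categorical draw from softmax(logits/temp) with no sort:

```elixir
defmodule GumbelSketch do
  # Temperature-scaled softmax, for comparing against empirical frequencies.
  def softmax(logits, temperature) do
    scaled = Enum.map(logits, &(&1 / temperature))
    max = Enum.max(scaled)
    exps = Enum.map(scaled, &:math.exp(&1 - max))
    sum = Enum.sum(exps)
    Enum.map(exps, &(&1 / sum))
  end

  # argmax(logits/temp + g) with g ~ Gumbel(0, 1), where
  # g = -log(-log(u)) for u uniform on (0, 1).
  def sample(logits, temperature) do
    logits
    |> Enum.map(fn l ->
      u = :rand.uniform_real()
      l / temperature - :math.log(-:math.log(u))
    end)
    |> Enum.with_index()
    |> Enum.max_by(fn {v, _i} -> v end)
    |> elem(1)
  end
end
```

Drawing many samples and tallying token frequencies recovers the softmax probabilities, which is the equivalence the GPU sampler relies on.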

Options

  • :temperature — float, default 0.95