EMLXAxon.Qwen3.Generate (emlx_axon v0.3.0)

Autoregressive token generation loop.

`EMLX.eval` is called once per token, at the sampler boundary, so the full 28-layer forward pass runs as a single lazy MLX graph before any CPU synchronisation.
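
As a toy illustration of why a single evaluation point matters, the sketch below models lazy ops as zero-arity closures: composing 28 "layers" only builds a deferred computation, and work happens once when it is forced. The module `LazyGraphSketch` and its functions are hypothetical stand-ins, not the EMLX or MLX API.

```elixir
defmodule LazyGraphSketch do
  # Hypothetical toy model: zero-arity closures stand in for lazy MLX ops.
  # This is not the EMLX API.

  # A "layer" wraps the previous deferred value; nothing is computed yet.
  def layer(prev, i), do: fn -> prev.() + i end

  # Chains n_layers deferred layers over the input and returns a single
  # thunk representing the whole forward pass.
  def forward(x, n_layers) do
    Enum.reduce(1..n_layers, fn -> x end, fn i, acc -> layer(acc, i) end)
  end

  # The one synchronisation point, analogous to EMLX.eval at the sampler.
  def eval(thunk), do: thunk.()
end
```

For example, `LazyGraphSketch.forward(0, 28) |> LazyGraphSketch.eval()` returns `406` (the sum of 1..28), computed only at the `eval/1` call.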

Usage

{:ok, state}     = Loader.load(model_path)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-0.6B"})

encoded   = Bumblebee.apply_tokenizer(tokenizer, "Hello")
input_ids = encoded["input_ids"]

{tokens, %{timing: timing}} = Generate.generate(input_ids, state,
  max_new_tokens: 100,
  sampler: :greedy
)

Functions

generate(input_ids, state, opts \\ [])

Run the autoregressive generation loop.
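
The overall shape of such a loop (prefill once, then one greedy decode step per remaining token, with timings recorded around each step) can be sketched in plain Elixir. This is a hypothetical `GenerateSketch` module, not the implementation of `generate/3`: `forward_fn` is an assumed stand-in for the model, there is no KV cache (each step re-feeds the whole sequence), and it assumes Elixir 1.13+ for `Keyword.validate!/2`.

```elixir
defmodule GenerateSketch do
  # Hypothetical sketch of an autoregressive loop over plain lists.
  # `forward_fn` stands in for the model: given the full token sequence it
  # returns a list of logits for the next position. No KV cache is used,
  # so every step re-feeds the whole sequence.

  def generate(input_ids, forward_fn, opts \\ []) do
    opts = Keyword.validate!(opts, max_new_tokens: 100)
    max_new = Keyword.fetch!(opts, :max_new_tokens)
    start = System.monotonic_time()

    # Prefill: one forward pass over the prompt yields the first new token.
    first = argmax(forward_fn.(input_ids))
    prefill_ms = elapsed_ms(start)

    # Decode: feed each greedy pick back in, one step per remaining token.
    {rest, per_token_ms} = decode(input_ids ++ [first], [], [], forward_fn, max_new - 1)
    tokens = Enum.take([first | rest], max_new)

    timing = %{prefill_ms: prefill_ms, per_token_ms: per_token_ms, total_ms: elapsed_ms(start)}
    {tokens, %{timing: timing}}
  end

  defp decode(_seq, acc, times, _forward_fn, remaining) when remaining <= 0,
    do: {Enum.reverse(acc), Enum.reverse(times)}

  defp decode(seq, acc, times, forward_fn, remaining) do
    t0 = System.monotonic_time()
    next = argmax(forward_fn.(seq))
    decode(seq ++ [next], [next | acc], [elapsed_ms(t0) | times], forward_fn, remaining - 1)
  end

  # Greedy sampling: index of the largest logit (first one on ties).
  defp argmax(logits) do
    logits
    |> Enum.with_index()
    |> Enum.max_by(fn {logit, _i} -> logit end)
    |> elem(1)
  end

  # Elapsed wall time since t0, as float milliseconds.
  defp elapsed_ms(t0) do
    System.convert_time_unit(System.monotonic_time() - t0, :native, :microsecond) / 1000
  end
end
```

With a toy model that always predicts "last token + 1 (mod 10)", `GenerateSketch.generate([3], fwd, max_new_tokens: 5)` yields tokens `[4, 5, 6, 7, 8]` and four decode timing samples, since the first token comes from prefill.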

Options

  • :max_new_tokens — maximum number of new tokens to generate (default 100)
  • :max_len — KV cache allocation size (default 2048); ignored when :kv_cache is given
  • :kv_cache — pre-allocated KV cache from `Model.init_kv_cache/2`; if provided,
                    `Model.init_kv_cache/2` is skipped. The cache is used as-is, so callers
                    are responsible for ensuring it is clean (stale K/V beyond `current_len`
                    is never read because `Model.forward/4` slices to the valid prefix).
  • :sampler — :greedy | :top_p_cpu | :top_p_gpu (default :greedy)
  • :temperature — float, passed to samplers that use it (default 0.95)
  • :top_p — float, passed to nucleus samplers (default 0.9)
  • :rng_key — key from `Nx.Random.key/1`, used by :top_p_gpu (split each step via
                    `Nx.Random.split/2`, which avoids a host-time seed and a host-to-device
                    transfer per token)
  • :profile_timing — when true (default), record `:per_token_ms` decode samples via
                    `System.monotonic_time/1` at each step; set to `false` to skip that
                    overhead (prefill and total wall time are still measured)
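
To make the `:temperature` and `:top_p` options concrete, here is a minimal sketch of CPU nucleus (top-p) sampling over a plain list of logits. It is a hypothetical `TopPSketch` module using Erlang's `:rand` with explicit state in place of Nx tensors and `Nx.Random`; it shows the general technique, not the code behind `:top_p_cpu` or `:top_p_gpu`.

```elixir
defmodule TopPSketch do
  # Hypothetical CPU nucleus (top-p) sampler over a plain list of logits.
  # The explicit :rand state is threaded through each call, loosely analogous
  # to splitting an Nx.Random key once per step.

  # Returns {token_id, new_rand_state}.
  def sample(logits, rand_state, temperature \\ 0.95, top_p \\ 0.9) do
    # Temperature-scaled softmax; subtracting the max keeps :math.exp stable.
    scaled = Enum.map(logits, &(&1 / temperature))
    max_logit = Enum.max(scaled)
    exps = Enum.map(scaled, &:math.exp(&1 - max_logit))
    total = Enum.sum(exps)

    # Token ids paired with probabilities, most likely first.
    sorted =
      exps
      |> Enum.with_index()
      |> Enum.map(fn {e, i} -> {i, e / total} end)
      |> Enum.sort_by(fn {_i, p} -> p end, :desc)

    # Smallest prefix whose cumulative probability reaches top_p.
    nucleus = take_nucleus(sorted, top_p, 0.0, [])
    mass = nucleus |> Enum.map(&elem(&1, 1)) |> Enum.sum()

    # Draw a point in [0, mass) and locate it in the nucleus, which is
    # equivalent to sampling from the renormalised nucleus distribution.
    {u, rand_state} = :rand.uniform_s(rand_state)
    {pick(nucleus, u * mass), rand_state}
  end

  defp take_nucleus([], _top_p, _acc, kept), do: Enum.reverse(kept)

  defp take_nucleus([{i, p} | rest], top_p, acc, kept) do
    acc = acc + p
    kept = [{i, p} | kept]
    if acc >= top_p, do: Enum.reverse(kept), else: take_nucleus(rest, top_p, acc, kept)
  end

  # The single-element clause absorbs floating-point rounding at the top end.
  defp pick([{i, _p}], _target), do: i
  defp pick([{i, p} | rest], target), do: if(target < p, do: i, else: pick(rest, target - p))
end
```

For instance, with logits `[2.0, 1.9, -5.0]`, `temperature: 1.0` and `top_p: 0.9`, the first two tokens already cover about 99.95% of the probability, so token 2 is never drawn.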

Returns

{generated_token_ids, %{timing: timing_map}} where timing_map has:

  • :prefill_ms — time to first token (prompt prefill), in milliseconds
  • :per_token_ms — list of per-token decode times in milliseconds; the median approximates steady-state decode latency
  • :total_ms — wall-clock time for the whole call, in milliseconds
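
Since the median of `:per_token_ms` approximates steady-state decode latency, a small helper can turn the timing map into a median step time and an implied tokens-per-second figure. `TimingSummary` is a hypothetical helper written against the map shape documented above, not part of this module.

```elixir
defmodule TimingSummary do
  # Hypothetical helper that summarises a timing map shaped like the one
  # documented above: %{prefill_ms: _, per_token_ms: [_], total_ms: _}.

  def summarize(%{prefill_ms: prefill, per_token_ms: []}),
    do: %{prefill_ms: prefill, median_ms: nil, tokens_per_s: nil}

  def summarize(%{prefill_ms: prefill, per_token_ms: samples}) do
    median = median(samples)
    # Steady-state decode throughput implied by the median step time.
    tokens_per_s = if median > 0, do: 1000 / median, else: nil
    %{prefill_ms: prefill, median_ms: median, tokens_per_s: tokens_per_s}
  end

  defp median(xs) do
    sorted = Enum.sort(xs)
    n = length(sorted)
    mid = div(n, 2)

    if rem(n, 2) == 1 do
      Enum.at(sorted, mid)
    else
      (Enum.at(sorted, mid - 1) + Enum.at(sorted, mid)) / 2
    end
  end
end
```

For example, `per_token_ms: [30.0, 10.0, 20.0]` gives `median_ms: 20.0` and `tokens_per_s: 50.0`. An empty sample list yields `nil` for both rather than raising.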