EMLXAxon.TextGeneration (emlx_axon v0.3.0)

An Nx.Serving-compatible wrapper around the native Qwen3 quantized model.

Bypasses the Axon graph entirely — the 28-layer forward pass runs as a single mlx::eval per token (via EMLXAxon.Qwen3.Generate), avoiding the 28 separate Metal command buffer submissions that the Bumblebee + Axon path incurs.

Only tokenization comes from upstream Bumblebee; no Bumblebee model function or Axon graph is involved in the decode forward pass.

Usage

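# Expand "~" once: Elixir file APIs do not expand it automatically, and reusing
# one path keeps the tokenizer and the weights on the same checkpoint directory.
checkpoint = Path.expand("~/models/Qwen3-0.6B-MLX-4bit")

{:ok, tokenizer} = Bumblebee.load_tokenizer({:local, checkpoint})

serving = EMLXAxon.TextGeneration.from_mlx4bit(
  checkpoint,
  tokenizer,
  max_new_tokens: 100,
  sampler: :greedy
)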

result = Nx.Serving.run(serving, "Write a short story about a robot who learns to love.")
IO.puts(result.results |> hd() |> Map.fetch!(:generated_text))

Summary

Functions

from_mlx4bit(checkpoint_path, tokenizer, opts \\ [])

Convenience: load %State{} from an MLX-4bit checkpoint directory and build a serving.

serving(tokenizer, state, opts \\ [])

Builds an Nx.Serving wrapping the native Qwen3 quantized model.

Functions

from_mlx4bit(checkpoint_path, tokenizer, opts \\ [])

@spec from_mlx4bit(Path.t(), Bumblebee.Tokenizer.t(), keyword()) :: Nx.Serving.t()

Convenience: load %State{} from an MLX-4bit checkpoint directory and build a serving.

The tokenizer is expected to come from the same directory (same tokenizer.json). Loading both from the same directory avoids chat-template / BOS-token divergence.

serving(tokenizer, state, opts \\ [])

Builds an Nx.Serving wrapping the native Qwen3 quantized model.

Accepts the same text-string input format as Bumblebee.Text.generation/4: a plain binary or %{text: binary()}. Returns %{results: [%{generated_text: binary(), num_tokens: pos_integer()}]} for a single input and a list of those maps for a batch input.
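Because the return value is a single map for one input and a list of maps for a batch, calling code may want to normalize both shapes. The following is a minimal sketch of one way to do that; `ResultHelper` and `texts/1` are hypothetical names used for illustration and are not part of EMLXAxon.

```elixir
defmodule ResultHelper do
  # Hypothetical helper, not part of EMLXAxon.
  # Accepts either one result map (single input) or a list of result maps
  # (batch input) and returns a flat list of {generated_text, num_tokens}.

  # Single input: wrap the map so both shapes share one code path.
  def texts(%{results: _} = single), do: texts([single])

  # Batch input: walk every output map and every entry in its :results list.
  def texts(outputs) when is_list(outputs) do
    for %{results: results} <- outputs,
        %{generated_text: text, num_tokens: n} <- results,
        do: {text, n}
  end
end
```

Assuming a serving built as in Usage, the output of `Nx.Serving.run(serving, prompt)` could be piped straight into `ResultHelper.texts/1` whether `prompt` is one string or a batch.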

Options

  • :max_new_tokens — max tokens to generate per request (default 100)
  • :max_len — KV cache preallocated token budget (default 2048)
  • :sampler — :greedy | :top_p_cpu | :top_p_gpu (default :greedy)
  • :profile_timing — forwarded to Generate.generate/3; when false, skips the per-token System.monotonic_time call in the decode loop (default true)
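To illustrate how the documented defaults combine with caller overrides, here is a small sketch that builds an option list before it is handed to the serving. `GenerationOpts` and `build/1` are hypothetical names, and this is not how the library itself validates options; only the keys, defaults, and sampler names come from the list above.

```elixir
defmodule GenerationOpts do
  # Hypothetical helper, not part of EMLXAxon. Defaults and sampler names
  # are copied from the documented option list.
  @defaults [max_new_tokens: 100, max_len: 2048, sampler: :greedy, profile_timing: true]
  @samplers [:greedy, :top_p_cpu, :top_p_gpu]

  def build(overrides \\ []) do
    # Keyword.validate!/2 fills in defaults and raises ArgumentError
    # for any key outside the documented set.
    opts = Keyword.validate!(overrides, @defaults)

    if opts[:sampler] not in @samplers do
      raise ArgumentError, "unknown sampler: #{inspect(opts[:sampler])}"
    end

    opts
  end
end
```

For example, `GenerationOpts.build(sampler: :top_p_gpu, profile_timing: false)` keeps the default token limits while switching the sampler and disabling per-token timing; the resulting list could then be passed as the `opts` argument of `from_mlx4bit/3` or `serving/3`.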