EMLXAxon.MLX4BitParams (emlx_axon v0.3.0)


Loads Qwen3 weights from an MLX-4bit safetensors checkpoint into Bumblebee's Axon params format: BF16 tensors under Bumblebee parameter names, with linear kernels in Bumblebee's {in, out} layout.

How it works

MLX-4bit checkpoints store each linear weight as a packed u32 tensor of shape {out, in/8}. MLX keeps weights in {out, in} layout and packs eight 4-bit values along the in dimension into each u32. Loading proceeds in two steps:

  1. Dequantize each weight to BF16, recovering the logical shape {out, in}.
  2. Transpose to {in, out}, Bumblebee's Axon convention, in which Nx.dot/6 contracts the kernel's first axis (the in-features).
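As a rough illustration of both steps, the sketch below works on plain Elixir lists rather than Nx tensors and uses floats instead of BF16. It assumes MLX's affine 4-bit scheme: eight values packed into each u32 with the lowest nibble first, and one scale and bias per group of group_size input features, so that w = scale * q + bias. The module MLX4BitSketch and its functions are hypothetical and not part of EMLXAxon; the real loader operates on tensors read from the safetensors file.

```elixir
defmodule MLX4BitSketch do
  @moduledoc """
  Hypothetical, list-based illustration of MLX 4-bit unpacking,
  dequantization, and the {out, in} -> {in, out} transpose.
  Not the EMLXAxon implementation.
  """
  import Bitwise

  # Split one u32 into its eight 4-bit values, lowest nibble first
  # (assumed MLX packing order).
  def unpack_u32(word) do
    for j <- 0..7, do: word >>> (4 * j) &&& 0xF
  end

  # Expand one packed row of length in/8 into a row of length in.
  def unpack_row(packed_row), do: Enum.flat_map(packed_row, &unpack_u32/1)

  # Affine dequantization (assumed scheme): w = scale * q + bias,
  # with one scale/bias pair per group of `group_size` input features.
  def dequantize_row(q_row, scales, biases, group_size) do
    q_row
    |> Enum.chunk_every(group_size)
    |> Enum.zip(Enum.zip(scales, biases))
    |> Enum.flat_map(fn {group, {s, b}} -> Enum.map(group, &(s * &1 + b)) end)
  end

  # Step 2: turn an {out, in} matrix into an {in, out} matrix.
  def transpose(rows), do: rows |> Enum.zip() |> Enum.map(&Tuple.to_list/1)

  # Step 1 applied row by row (each row is one output feature), then step 2.
  def load_kernel(packed, scales, biases, group_size) do
    packed
    |> Enum.zip(Enum.zip(scales, biases))
    |> Enum.map(fn {packed_row, {row_scales, row_biases}} ->
      packed_row
      |> unpack_row()
      |> dequantize_row(row_scales, row_biases, group_size)
    end)
    |> transpose()
  end
end
```

For instance, MLX4BitSketch.unpack_u32(0x76543210) yields [0, 1, 2, 3, 4, 5, 6, 7]. Dequantizing before transposing keeps each scale and bias aligned with the in-dimension groups it was computed over.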

After calling load/1, pass the result through EMLXAxon.QuantizeParams.quantize/1 to re-apply 4-bit quantization in Bumblebee's convention so quantized_matmul dispatch is active at serving time.

Usage

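{:ok, model_info} = Bumblebee.load_model({:local, mlx_path})
# Assumes the tokenizer and generation config files sit in the same checkpoint directory.
{:ok, tokenizer} = Bumblebee.load_tokenizer({:local, mlx_path})
{:ok, gen_cfg} = Bumblebee.load_generation_config({:local, mlx_path})
# Dequantize and transpose the MLX weights, then re-quantize in Bumblebee's convention.
params = EMLXAxon.MLX4BitParams.load(mlx_path)
params = EMLXAxon.QuantizeParams.quantize(params)
serving = Bumblebee.Text.generation(
  %{model_info | params: params}, tokenizer, gen_cfg, ...
)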


Functions

load(mlx_path)

@spec load(Path.t()) :: Axon.ModelState.t()

Load Qwen3 params from an MLX-4bit checkpoint directory.

Returns an %Axon.ModelState{} of BF16 tensors keyed by Bumblebee parameter names. All linear kernels are transposed from MLX's {out, in} layout to Bumblebee's {in, out} layout.
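To see why the transpose preserves each layer's output, the hypothetical sketch below (plain lists standing in for tensors; DenseLayoutSketch is not part of EMLXAxon) computes a linear layer both ways. The MLX-style version takes the dot product of x with each row of an {out, in} weight; the Axon-style version contracts x against the first axis of an {in, out} kernel, which is the contraction Axon performs with Nx.dot/6. Both agree when the kernel is the transpose of the weight.

```elixir
defmodule DenseLayoutSketch do
  @moduledoc """
  Hypothetical illustration: plain lists stand in for Nx tensors.
  """

  # MLX-style linear: weight is {out, in}; each output is x dotted with one weight row.
  def linear_out_in(x, weight), do: Enum.map(weight, &dot(x, &1))

  # Axon/Bumblebee-style dense: kernel is {in, out}; its first axis
  # (in-features) is contracted against x.
  def dense_in_out(x, kernel) do
    kernel
    |> Enum.zip(x)
    # Scale row i of the kernel by x_i ...
    |> Enum.map(fn {row, xi} -> Enum.map(row, &(&1 * xi)) end)
    # ... then sum over i, column by column.
    |> Enum.zip_with(&Enum.sum/1)
  end

  # {out, in} -> {in, out}
  def transpose(rows), do: rows |> Enum.zip() |> Enum.map(&Tuple.to_list/1)

  defp dot(a, b) do
    a |> Enum.zip(b) |> Enum.map(fn {p, q} -> p * q end) |> Enum.sum()
  end
end
```

With weight [[1, 2, 3], [4, 5, 6]] and x [1, 2, 3], linear_out_in/2 on the weight and dense_in_out/2 on its transpose both return [14, 32].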