Post-load param quantization for Bumblebee models.
Traverses a Bumblebee params map and quantizes eligible 2-D weight tensors to
4-bit so that Nx.dot dispatches to EMLX.quantized_matmul via the backend's
transparent dispatch (A6-1 of the emlx#108 investigation).
Usage
{:ok, model_info} = Bumblebee.load_model(source, backend: {EMLX.Backend, device: :gpu})
model_info = %{model_info | params: EMLXAxon.QuantizeParams.quantize(model_info.params)}
model_info = %{model_info | model: EMLXAxon.rewrite(model_info.model)}
Eligibility
A tensor is quantized if ALL of the following hold:
- rank is 2
- first dimension (in_features) is divisible by group_size (default 64)
- first dimension < skip_vocab_threshold (default 100_000); this skips embed_tokens / lm_head
- both dimensions ≥ 2 * group_size
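The eligibility rules can be read as a single predicate on a tensor's shape. The following is a minimal sketch of that reading, not the module's actual implementation. `EligibilitySketch` and `eligible?/2` are hypothetical names, and passing `group_size` and `skip_vocab_threshold` as keyword options is an assumption. It works on plain shape tuples (what `Nx.shape/1` returns), so it runs without Nx or EMLX installed.

```elixir
defmodule EligibilitySketch do
  # Hypothetical helper, not part of EMLXAxon. Defaults follow the
  # values documented in the Eligibility list.
  @default_group_size 64
  @default_skip_vocab_threshold 100_000

  def eligible?(shape, opts \\ []) when is_tuple(shape) do
    group_size = Keyword.get(opts, :group_size, @default_group_size)
    threshold = Keyword.get(opts, :skip_vocab_threshold, @default_skip_vocab_threshold)

    case shape do
      # Rank 2: the first dimension is treated as in_features.
      {in_features, out_features} ->
        rem(in_features, group_size) == 0 and
          in_features < threshold and
          in_features >= 2 * group_size and
          out_features >= 2 * group_size

      # Any other rank (biases, norm scales, conv kernels) is skipped.
      _other ->
        false
    end
  end
end

IO.inspect(EligibilitySketch.eligible?({4096, 4096}))    # prints true
IO.inspect(EligibilitySketch.eligible?({128_256, 4096})) # prints false (vocab-sized first dimension)
```

With a real tensor, the call would presumably look like `EligibilitySketch.eligible?(Nx.shape(tensor))`.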
Summary
Functions
Traverse params and quantize all eligible weight tensors.
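As an illustration of what such a traversal involves, the sketch below walks a nested params map and applies a function to every leaf. It assumes the params are a plain nested map of layer name to parameter name to tensor. `TraverseSketch`, `map_leaves/2`, and the layer names are hypothetical, and shape tuples stand in for tensors so the example runs without Nx. The real function may be structured differently.

```elixir
defmodule TraverseSketch do
  # Hypothetical helper, not EMLXAxon's implementation.
  # Recurse into plain maps only. Structs (such as tensors) are maps too,
  # so they are excluded by the guard and handled as leaves.
  def map_leaves(params, fun) when is_map(params) and not is_struct(params) do
    Map.new(params, fn {name, value} -> {name, map_leaves(value, fun)} end)
  end

  # Anything that is not a plain map is a leaf and is passed to `fun`.
  def map_leaves(leaf, fun), do: fun.(leaf)
end

# Illustrative layer names; shape tuples stand in for real tensors.
params = %{
  "decoder.blocks.0.ffn" => %{"kernel" => {4096, 11008}, "bias" => {11008}},
  "norm" => %{"scale" => {4096}}
}

# Tag rank-2 leaves as quantized and leave everything else untouched.
tagged =
  TraverseSketch.map_leaves(params, fn
    {_in, _out} = shape -> {:quantized, shape}
    other -> other
  end)

IO.inspect(tagged)
```

In a real pass, the leaf function would check eligibility and return either the quantized tensor or the original one, so ineligible parameters pass through unchanged.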