EMLXAxon.QuantizeParams (emlx_axon v0.3.0)


Post-load param quantization for Bumblebee models.

Traverses a Bumblebee params map and quantizes eligible 2-D weight tensors to low-bit integers (4-bit by default; see the :bits option). Nx.dot on those tensors then dispatches to EMLX.quantized_matmul through the backend's transparent dispatch (A6-1 of the emlx#108 investigation).
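To make "bits" and "group size" concrete, here is a minimal plain-Elixir sketch of group-wise affine (min/max) quantization. It assumes a generic scheme in which each group shares one scale and one bias. It is an illustration only, not EMLX's or MLX's actual kernel, and it does not pack the integers into words. The module name GroupQuantSketch is hypothetical.

```elixir
defmodule GroupQuantSketch do
  # Hypothetical illustration of affine (min/max) group quantization.
  # Each run of `group_size` consecutive weights shares one scale and one
  # bias; every weight becomes an unsigned integer in 0..(2^bits - 1).
  # No bit packing is done here.

  def quantize(weights, group_size, bits) do
    max_q = Integer.pow(2, bits) - 1

    weights
    |> Enum.chunk_every(group_size)
    |> Enum.map(fn group ->
      {lo, hi} = Enum.min_max(group)
      # A constant group has hi == lo; use scale 1.0 to avoid dividing by zero.
      scale = if hi == lo, do: 1.0, else: (hi - lo) / max_q
      q = Enum.map(group, fn w -> round((w - lo) / scale) end)
      %{q: q, scale: scale, bias: lo}
    end)
  end

  # Reconstruct approximate weights: w ≈ q * scale + bias.
  def dequantize(groups) do
    Enum.flat_map(groups, fn %{q: q, scale: scale, bias: bias} ->
      Enum.map(q, fn qi -> qi * scale + bias end)
    end)
  end
end
```

For example, GroupQuantSketch.quantize([0.0, 0.5, 1.0, 1.5], 4, 4) yields one group with integer codes [0, 5, 10, 15], a scale of about 0.1 and a bias of 0.0. With 4 bits each weight needs only 16 levels plus a per-group scale and bias, which is where the memory saving comes from.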

Usage

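# source is a Bumblebee repository spec, e.g. {:hf, "org/model"}
{:ok, model_info} = Bumblebee.load_model(source, backend: {EMLX.Backend, device: :gpu})
# Quantize eligible 2-D weights after loading
model_info = %{model_info | params: EMLXAxon.QuantizeParams.quantize(model_info.params)}
# Rewrite the Axon model graph with EMLXAxon.rewrite/1
model_info = %{model_info | model: EMLXAxon.rewrite(model_info.model)}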

Eligibility

A tensor is quantized if ALL of the following hold:

  • rank is 2
  • first dimension (in_features) is divisible by group_size (default 64)
  • first dimension < skip_vocab_threshold (default 100_000) — skips embed_tokens / lm_head
  • both dimensions ≥ 2 * group_size
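The rules above can be restated as a small predicate. This is a minimal sketch that works on shape tuples (the form Nx.shape/1 returns) so that it runs without Nx or EMLX. The module EligibilitySketch is hypothetical and not part of emlx_axon. The threshold check follows the "<" comparison stated in the list.

```elixir
defmodule EligibilitySketch do
  # Hypothetical restatement of the documented eligibility rules, operating
  # on shape tuples instead of real tensors.
  @defaults [group_size: 64, skip_vocab_threshold: 100_000]

  def eligible?(shape, opts \\ []) do
    opts = Keyword.merge(@defaults, opts)
    group_size = opts[:group_size]
    threshold = opts[:skip_vocab_threshold]

    case shape do
      {in_features, out_features} ->
        rem(in_features, group_size) == 0 and
          in_features < threshold and
          in_features >= 2 * group_size and
          out_features >= 2 * group_size

      # Any rank other than 2 (biases, norm scales, scalars) is left untouched.
      _ ->
        false
    end
  end
end
```

For instance, a {4096, 4096} projection kernel qualifies, while a {151_936, 4096} embedding is skipped by the vocabulary threshold. A {192, 4096} kernel qualifies at the default group size of 64 but not at 128, because 192 is not divisible by 128.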

Functions

quantize(params, opts \\ [])

@spec quantize(
  map(),
  keyword()
) :: map()

Traverse params and quantize all eligible weight tensors.

Options

  • :bits (default 4): quantization bit-width, either 4 or 8.
  • :group_size (default 64): quantization group size; tensors whose in_features is not divisible by it are skipped rather than quantized.
  • :skip_vocab_threshold (default 100_000): tensors whose first dimension is at or above this value are skipped.
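A minimal sketch of how the traversal and these options might fit together. It assumes params is a nested map of layer name to parameter name to tensor, and it uses shape tuples as stand-ins for tensors so that it runs without Nx or EMLX. Several names and behaviours here are hypothetical:

- the module TraverseSketch;
- the tagged {:quantized, shape, bits, group_size} leaf, which stands in for a real quantized tensor;
- the ArgumentError raised on an unsupported :bits value;
- the layer names in the usage data.

```elixir
defmodule TraverseSketch do
  # Hypothetical sketch: params are nested maps whose leaves are shape tuples
  # standing in for tensors. Eligible leaves become
  # {:quantized, shape, bits, group_size}; everything else is returned as-is.
  @defaults [bits: 4, group_size: 64, skip_vocab_threshold: 100_000]

  def quantize(params, opts \\ []) when is_map(params) do
    opts = Keyword.merge(@defaults, opts)

    # Assumption: unsupported bit-widths are rejected up front.
    if opts[:bits] not in [4, 8] do
      raise ArgumentError, ":bits must be 4 or 8, got: #{inspect(opts[:bits])}"
    end

    Map.new(params, fn {key, value} -> {key, quantize_value(value, opts)} end)
  end

  # Nested maps (one per layer) are walked recursively.
  defp quantize_value(value, opts) when is_map(value), do: quantize(value, opts)

  defp quantize_value(shape, opts) when is_tuple(shape) do
    if eligible?(shape, opts) do
      {:quantized, shape, opts[:bits], opts[:group_size]}
    else
      shape
    end
  end

  defp quantize_value(other, _opts), do: other

  # Same rules as the Eligibility list: rank 2, divisible first dimension,
  # below the vocabulary threshold, both dimensions >= 2 * group_size.
  defp eligible?({in_features, out_features}, opts) do
    g = opts[:group_size]

    rem(in_features, g) == 0 and in_features < opts[:skip_vocab_threshold] and
      in_features >= 2 * g and out_features >= 2 * g
  end

  defp eligible?(_shape, _opts), do: false
end
```

Applied to %{"decoder.blocks.0.ffn.intermediate" => %{"kernel" => {4096, 11008}, "bias" => {11008}}}, only the kernel is tagged; the rank-1 bias passes through unchanged. Because the result keeps the input's structure, it can be put straight back into model_info.params, matching the map() -> map() spec.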