Affine group-wise int2/int4/int8 quantization for Apple Silicon inference.
Quantized weights are represented as annotated `Nx.Tensor` values: the tensor carries the original logical shape and type (e.g. `{:s, 4}` for 4-bit), while the `EMLX.Backend` struct stores the packed `uint32` data and an `EMLX.Quantization.Config` with scales, biases, `group_size`, and `bits`.

`Nx.dot` automatically dispatches to `mx::quantized_matmul` when it detects a quantized operand on `EMLX.Backend`, so no call-site changes are needed.
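To illustrate (a sketch reusing the `from_dense/2` call from the usage example below), standard `Nx` introspection reports the logical dense tensor rather than the packed storage:

```elixir
weight =
  Nx.iota({512, 4096}, type: :f32)
  |> Nx.backend_transfer({EMLX.Backend, device: :gpu})

qw = EMLX.Quantization.from_dense(weight)

Nx.shape(qw)  #=> {512, 4096} - logical (dense) shape
Nx.type(qw)   #=> {:s, 4}     - logical 4-bit type; storage is packed uint32
```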
## Basic usage
```elixir
# Quantize a dense weight
weight = Nx.iota({512, 4096}, type: :f32)
weight = Nx.backend_transfer(weight, {EMLX.Backend, device: :gpu})
qw = EMLX.Quantization.from_dense(weight)

# Standard Nx.dot dispatches to mx::quantized_matmul automatically
input = Nx.iota({1, 8, 4096}, type: :f32)
result = Nx.dot(input, [2], qw, [1])
```

## Inside defn
`quantize/2`, `dequantize/1`, and `quantized_matmul/2` are all `deftransform` functions backed by `Nx.runtime_call`, so they are safe to call inside `Nx.Defn.jit`-traced forward passes:
```elixir
defn my_layer(x, qw) do
  dense = EMLX.Quantization.dequantize(qw)
  Nx.dot(x, [2], dense, [1])
end
```

`Nx.dot` with a quantized tensor also dispatches transparently via `EMLX.Backend.dot/7` at execution time, so an explicit `dequantize` is only needed when you want the dense weight for a non-dot operation.
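For example, a layer can rely on that transparent dispatch and pass the quantized weight straight to `Nx.dot` (a minimal sketch of the pattern just described):

```elixir
defn my_layer_direct(x, qw) do
  # No dequantize step: at execution time EMLX.Backend.dot/7 detects the
  # quantized operand and routes to mx::quantized_matmul.
  Nx.dot(x, [2], qw, [1])
end
```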
## See also

  * `EMLX.Quantization.Config` - internal metadata struct
## Functions
Dequantize a quantized `Nx.Tensor` via `Nx.runtime_call`.

At execution time the callback receives the real tensor, unpacks the `quantization_config`, and calls `mx::dequantize`. Safe to call inside `Nx.Defn.jit`-traced forward passes.

The output has the same shape as the input (the quantized tensor's logical shape equals the dense shape).
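A short sketch, reusing `qw` from the usage example above:

```elixir
dense = EMLX.Quantization.dequantize(qw)
Nx.shape(dense)  #=> {512, 4096} - same as the quantized tensor's logical shape
```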
Quantize a dense 2-D tensor via `Nx.runtime_call`.

At execution time the callback receives the real tensor and runs `mx::quantize`. Returns an annotated quantized `Nx.Tensor` with the same logical shape as the input and type `{:s, N}`. Safe to call inside `Nx.Defn.jit`-traced forward passes.

Options:

  * `:type` - Nx storage type: `{:s, 2}`, `{:s, 4}` (default), or `{:s, 8}`.
  * `:group_size` - 32, 64, or 128 (default 64). Must evenly divide the last dimension of `tensor`.
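For instance (a sketch, assuming this is the `quantize/2` named in the overview; the option values come from the list above):

```elixir
weight =
  Nx.iota({512, 4096}, type: :f32)
  |> Nx.backend_transfer({EMLX.Backend, device: :gpu})

# 8-bit quantization with 32-element groups (32 evenly divides 4096)
qw8 = EMLX.Quantization.quantize(weight, type: {:s, 8}, group_size: 32)
Nx.type(qw8)  #=> {:s, 8}
```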
Returns true if the tensor has quantization metadata on its backend.
Run a quantized matmul via `Nx.runtime_call`.

At execution time the callback unpacks the `quantization_config` from `qw` and calls `mx::quantized_matmul`. Safe to call inside `Nx.Defn.jit`-traced forward passes.

Output shape is `{batch_dims..., out_features}`, where `out_features` is the first dimension of `qw`. Output type is `:bf16`.
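A hedged sketch using the shapes from the usage example (the `(input, qw)` argument order is an assumption; it is not spelled out here):

```elixir
# qw is the {512, 4096} quantized weight; input is {1, 8, 4096}
# argument order (input, qw) assumed
out = EMLX.Quantization.quantized_matmul(input, qw)

Nx.shape(out)  #=> {1, 8, 512} - {batch_dims..., out_features}
Nx.type(out)   #=> {:bf, 16}
```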
Construct an annotated quantized `Nx.Tensor` from pre-computed device refs.

Use this when you already have packed weights from a checkpoint. For quantizing a dense tensor from scratch, prefer `from_dense/2`.

Options:

  * `:type` - Nx storage type: `{:s, 2}`, `{:s, 4}` (default), or `{:s, 8}`.
  * `:group_size` - quantization group size (default 64).