EMLX.Quantization (emlx v0.3.0)


Affine group-wise int2/int4/int8 quantization for Apple Silicon inference.

Quantized weights are represented as annotated Nx.Tensor values: the tensor carries the original logical shape and type (e.g. {:s, 4} for 4-bit), while the EMLX.Backend struct stores the packed uint32 data and an EMLX.Quantization.Config with scales, biases, group_size, and bits.

Nx.dot automatically dispatches to mx::quantized_matmul when it detects a quantized operand on EMLX.Backend, so no call-site changes are needed.

Basic usage

# Quantize a dense weight
weight = Nx.iota({512, 4096}, type: :f32)
weight = Nx.backend_transfer(weight, {EMLX.Backend, device: :gpu})
qw = EMLX.Quantization.quantize(weight)

# Standard Nx.dot dispatches to mx::quantized_matmul automatically
input = Nx.iota({1, 8, 4096}, type: :f32)
result = Nx.dot(input, [2], qw, [1])
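
With the defaults shown above ({:s, 4} storage, group size 64), the annotation preserves the dense view of the weight:

Nx.shape(qw)     #=> {512, 4096}
Nx.type(qw)      #=> {:s, 4}
Nx.shape(result) #=> {1, 8, 512}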

Inside defn

quantize/2, dequantize/1, and quantized_matmul/2 are all deftransform functions backed by Nx.runtime_call, so they are safe to call inside Nx.Defn.jit-traced forward passes:

defn my_layer(x, qw) do
  dense = EMLX.Quantization.dequantize(qw)
  Nx.dot(x, [2], dense, [1])
end

Nx.dot with a quantized tensor also dispatches transparently via EMLX.Backend.dot/7 at execution time, so an explicit dequantize is only needed when you want the dense weight for a non-dot operation.
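
When the weight is only used for a dot product, it can be passed through untouched; a minimal sketch of the transparent path:

defn quantized_layer(x, qw) do
  # Dispatches to mx::quantized_matmul via EMLX.Backend.dot/7 at execution time
  Nx.dot(x, [2], qw, [1])
end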

See also

  • EMLX.Quantization.Config — internal metadata struct

Summary

Functions

dequantize(qw)
Dequantize a quantized Nx.Tensor via Nx.runtime_call.

quantize(tensor, opts \\ [])
Quantize a dense 2-D tensor via Nx.runtime_call.

quantized?(arg1)
Returns true if the tensor has quantization metadata on its backend.

quantized_matmul(activation, qw)
Run a quantized matmul via Nx.runtime_call.

quantized_tensor(weight_ref, scales_ref, biases_ref, original_shape, opts \\ [])
Construct an annotated quantized Nx.Tensor from pre-computed device refs.

Functions

dequantize(qw)

Dequantize a quantized Nx.Tensor via Nx.runtime_call.

At execution time the callback receives the real tensor, unpacks quantization_config, and calls mx::dequantize. Safe to call inside Nx.Defn.jit-traced forward passes.

The output has the same shape as the input (the quantized tensor's logical shape equals the dense shape).
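
A short usage sketch, reusing qw from the basic-usage example above:

dense = EMLX.Quantization.dequantize(qw)
Nx.shape(dense) #=> {512, 4096}, the logical shape of qw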

quantize(tensor, opts \\ [])

Quantize a dense 2-D tensor via Nx.runtime_call.

At execution time the callback receives the real tensor and runs mx::quantize. Returns an annotated quantized Nx.Tensor with the same logical shape as the input and type {:s, N}. Safe to call inside Nx.Defn.jit-traced forward passes.

Options

  • :type — Nx storage type: {:s, 2}, {:s, 4} (default), or {:s, 8}.
  • :group_size — 32, 64, or 128 (default 64). Must evenly divide the last dimension of tensor.
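
A sketch using non-default options (8-bit storage, groups of 32, both taken from the lists above):

weight =
  Nx.iota({512, 4096}, type: :f32)
  |> Nx.backend_transfer({EMLX.Backend, device: :gpu})

qw8 = EMLX.Quantization.quantize(weight, type: {:s, 8}, group_size: 32)
Nx.type(qw8) #=> {:s, 8}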

quantized?(arg1)

@spec quantized?(term()) :: boolean()

Returns true if the tensor has quantization metadata on its backend.
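
For example, with weight and qw8 from the quantize/2 sketch above:

EMLX.Quantization.quantized?(qw8)    #=> true
EMLX.Quantization.quantized?(weight) #=> false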

quantized_matmul(activation, qw)

Run a quantized matmul via Nx.runtime_call.

At execution time the callback unpacks quantization_config from qw and calls mx::quantized_matmul. Safe to call inside Nx.Defn.jit-traced forward passes.

Output shape is {batch_dims..., out_features}, where out_features is the first dimension of qw. Output type is {:bf, 16}.
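
Using input and qw from the basic-usage example, the direct call is:

out = EMLX.Quantization.quantized_matmul(input, qw)
Nx.shape(out) #=> {1, 8, 512}
Nx.type(out)  #=> {:bf, 16}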

quantized_tensor(weight_ref, scales_ref, biases_ref, original_shape, opts \\ [])

@spec quantized_tensor(term(), term(), term(), tuple(), keyword()) :: Nx.Tensor.t()

Construct an annotated quantized Nx.Tensor from pre-computed device refs.

Use this when you already have packed weights from a checkpoint. For quantizing a dense tensor from scratch, prefer quantize/2.

Options

  • :type — Nx storage type: {:s, 2}, {:s, 4} (default), or {:s, 8}.
  • :group_size — quantization group size (default 64).
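
A hedged sketch, assuming a hypothetical load_ref/1 helper that returns packed device refs from your checkpoint format (load_ref is illustrative, not part of EMLX):

# Hypothetical loader; replace with however your checkpoint exposes device refs
weight_ref = load_ref("model/layer0/wq.packed")
scales_ref = load_ref("model/layer0/wq.scales")
biases_ref = load_ref("model/layer0/wq.biases")

qw =
  EMLX.Quantization.quantized_tensor(weight_ref, scales_ref, biases_ref, {512, 4096},
    type: {:s, 4},
    group_size: 64
  )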