viva_tensor/quant/awq

AWQ (Activation-aware Weight Quantization)

Reference: Lin et al. (2024), “AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration”, MLSys 2024 (Best Paper Award). https://arxiv.org/abs/2306.00978

— The Key Insight (worth repeating) — Only ~1% of weights are “salient” - and they matter 10x more than the rest. But here’s the twist: you identify them by looking at ACTIVATIONS, not weights. High activation magnitude = that channel matters = protect those weights.

— The Genius — Don’t modify the quantization algorithm. Modify the weights BEFORE quantizing. Scale salient channels UP by s, then scale activations DOWN by 1/s. Mathematically equivalent: WX = (sW)(X/s). But now the important weights retain more precision after quantization.
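The equivalence is easy to check numerically. A minimal Python sketch (not the Gleam API; the names `w`, `x`, `s` are hypothetical):

```python
# AWQ identity: W @ x == (W * diag(s)) @ (x / s), up to float rounding.

def matvec(w, x):
    """Plain matrix-vector product over nested lists."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

w = [[1.0, 2.0], [3.0, 4.0]]   # weights, shape [out=2, in=2]
x = [10.0, 0.1]                # activations: input channel 0 is "salient"
s = [4.0, 1.0]                 # protect channel 0 by scaling its weights up

w_scaled = [[wij * sj for wij, sj in zip(row, s)] for row in w]  # W' = W * diag(s)
x_scaled = [xj / sj for xj, sj in zip(x, s)]                     # X' = X / s

# Output is unchanged; only the weight values (what gets quantized) moved.
assert all(abs(a - b) < 1e-9 for a, b in zip(matvec(w, x), matvec(w_scaled, x_scaled)))
```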

— Compression Math — Same as NF4/INT4: 32/4 = 8x theoretical, ~7.7x effective. The magic is in the QUALITY, not the ratio. AWQ achieves NF4-level compression with FP16-level accuracy.
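The ~7.7x effective figure is consistent with storing one FP16 scale per group of 128 INT4 weights (an assumption about the storage layout; exact overhead depends on the format):

```python
# Back-of-envelope check of the effective compression ratio.
bits_per_weight = 4
group_size = 128     # matches the typical group_size below
scale_bits = 16      # assumed: one FP16 scale per group

effective_bits = bits_per_weight + scale_bits / group_size  # 4.125 bits/weight
ratio = 32 / effective_bits
print(round(ratio, 2))  # → 7.76
```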

— Why AWQ Won MLSys 2024 —

  1. Simple insight, huge impact
  2. Zero runtime overhead (transform is pre-computed)
  3. Works with ANY quantization method (INT4, NF4, whatever)
  4. State-of-the-art on LLaMA, OPT, BLOOM benchmarks

Implementation based on: MIT-HAN Lab + AutoAWQ

Types

AWQ configuration

pub type AWQConfig {
  AWQConfig(
    bits: Int,
    group_size: Int,
    alpha: Float,
    zero_point: Bool,
  )
}

Constructors

  • AWQConfig(
      bits: Int,
      group_size: Int,
      alpha: Float,
      zero_point: Bool,
    )

    Arguments

    bits

    Quantization bits (4 is standard, 3 is aggressive)

    group_size

Group size for per-group scaling (128 is typical). Smaller = more accurate, larger = more compressed.

    alpha

Alpha exponent for scaling: scale = activation_stat ^ alpha. 0.5 is empirically optimal (the square root of the activation magnitude). Higher alpha = more aggressive protection of salient channels.

    zero_point

Use zero-point (asymmetric quantization). Opinion: skip it. The cache-miss overhead isn’t worth the accuracy gain.
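For the symmetric path (zero_point = False), per-group quantization can be sketched as follows in Python (hypothetical helpers, not this module's API):

```python
# Symmetric per-group INT4 quantization sketch: one scale per group, codes in [-7, 7].

def quantize_group(ws):
    """Quantize one group of floats with a single scale (no zero-point)."""
    max_abs = max(abs(w) for w in ws) or 1.0  # guard against an all-zero group
    scale = max_abs / 7.0                     # map the largest magnitude to +/-7
    q = [round(w / scale) for w in ws]
    return q, scale

def dequantize_group(q, scale):
    return [qi * scale for qi in q]

group = [0.6, -1.4, 0.07, 1.0]
q, scale = quantize_group(group)   # q == [3, -7, 0, 5]
approx = dequantize_group(q, scale)

# Symmetric rounding bounds the per-weight error by scale/2.
assert all(abs(a - b) <= scale / 2 + 1e-12 for a, b in zip(group, approx))
```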

Computed AWQ scales for weight transformation

pub type AWQScales {
  AWQScales(
    weight_scales: List(Float),
    activation_stats: List(Float),
    alpha: Float,
  )
}

Constructors

  • AWQScales(
      weight_scales: List(Float),
      activation_stats: List(Float),
      alpha: Float,
    )

    Arguments

    weight_scales

    Per-channel scale factors: multiply weights by these before quantizing

    activation_stats

    Original activation statistics (for debugging/analysis)

    alpha

    Alpha used in computation

AWQ-quantized tensor

pub type AWQTensor {
  AWQTensor(
    quantized_weights: List(Int),
    awq_scales: AWQScales,
    quant_scales: List(Float),
    zero_points: List(Int),
    shape: List(Int),
    memory_bytes: Int,
  )
}

Constructors

  • AWQTensor(
      quantized_weights: List(Int),
      awq_scales: AWQScales,
      quant_scales: List(Float),
      zero_points: List(Int),
      shape: List(Int),
      memory_bytes: Int,
    )

    Arguments

    quantized_weights

    Quantized weights (INT4 values)

    awq_scales

    AWQ channel scales (the secret sauce)

    quant_scales

    Per-group quantization scales

    zero_points

    Zero-points if asymmetric (usually empty)

    shape

    Original shape

    memory_bytes

    Memory in bytes

Values

pub fn apply_activation_transform(
  activations: List(Float),
  scales: AWQScales,
) -> List(Float)

Apply the inverse transformation to activations: X’ = X * diag(1/s). This compensates for the weight scaling at runtime. Note: in production, fuse this into the previous layer’s output.

pub fn apply_weight_transform(
  weights: List(List(Float)),
  scales: AWQScales,
) -> List(List(Float))

Apply the equivalent transformation to weights: W’ = W * diag(s). This scales salient channels UP before quantization.

pub fn benchmark_awq() -> Nil
pub fn collect_activation_stats(
  activations_batch: List(List(Float)),
) -> List(Float)

Collect activation statistics from calibration data

This is THE critical step. Bad calibration = bad quantization. Use 128-512 samples from your actual inference distribution. More samples = more stable statistics, diminishing returns after 256.

Returns: mean absolute activation per channel
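A language-agnostic Python sketch of this statistic (assuming each calibration sample is one activation vector with one float per channel):

```python
# Mean absolute activation per channel over a calibration batch.

def collect_stats(batch):
    """batch: list of activation vectors, all of the same length (channels)."""
    n_channels = len(batch[0])
    return [
        sum(abs(sample[c]) for sample in batch) / len(batch)
        for c in range(n_channels)
    ]

batch = [[1.0, -0.5], [-3.0, 1.5]]
print(collect_stats(batch))  # → [2.0, 1.0]
```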

pub fn compute_awq_scales(
  activation_stats: List(Float),
  alpha: Float,
) -> AWQScales

Compute AWQ scales from activation statistics

Formula: scale[i] = activation_stat[i] ^ alpha

Why alpha = 0.5 (sqrt)?

  • Too low (0.1): Not enough protection for salient channels
  • Too high (0.9): Over-protection, wastes precision on outliers
  • 0.5: Empirically optimal across LLaMA, OPT, BLOOM
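Numerically, alpha = 0.5 means a channel with 16x the mean |activation| gets a 4x larger scale — a Python sketch of the documented formula (note: AutoAWQ additionally normalizes the scales; omitted here):

```python
# scale[i] = activation_stat[i] ^ alpha, with the empirically optimal alpha = 0.5.
stats = [4.0, 0.25, 1.0]   # hypothetical per-channel mean |activation|
alpha = 0.5
scales = [st ** alpha for st in stats]
print(scales)  # → [2.0, 0.5, 1.0]
```

The square root compresses the dynamic range: salient channels are protected without starving the rest of precision.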

pub fn default_config() -> AWQConfig

Default AWQ config (matches AutoAWQ defaults)

pub fn dequantize_awq(awq: AWQTensor) -> tensor.Tensor

Dequantize an AWQ tensor back to FP32. Note: this must also undo the AWQ weight transform.
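A sketch of the order of operations (hypothetical values; a real implementation iterates over groups and the full tensor):

```python
# Full dequantize must undo BOTH the INT4 quantization scale and the AWQ channel scale.
q = [6, -3]                 # INT4 codes for two channels of one weight
quant_scale = 0.5           # per-group quantization scale
awq_scales = [4.0, 1.0]     # per-channel AWQ scales applied before quantization

# W = (q * quant_scale) / diag(s): divide each channel by its AWQ scale.
w = [qi * quant_scale / s for qi, s in zip(q, awq_scales)]
print(w)  # → [0.75, -1.5]
```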

pub fn identify_salient_channels(
  activation_stats: List(Float),
  top_percent: Float,
) -> List(Int)

Identify the most salient channels (top-k by activation magnitude)

Key insight: Only ~1% of channels are truly salient. But they contribute ~10% of the output magnitude. Protecting them is the key to AWQ’s success.
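Top-k selection by activation magnitude can be sketched as (hypothetical helper, not this module's API):

```python
# Return the indices of the top `top_percent` channels by activation statistic.

def salient_channels(stats, top_percent):
    k = max(1, int(len(stats) * top_percent))  # always keep at least one channel
    ranked = sorted(range(len(stats)), key=lambda i: stats[i], reverse=True)
    return sorted(ranked[:k])                  # ascending channel indices

stats = [0.1, 9.0, 0.2, 8.0, 0.3, 0.1, 0.2, 0.1, 0.1, 0.1]
print(salient_channels(stats, 0.2))  # → [1, 3]
```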

pub fn main() -> Nil
pub fn quantize_awq(
  weights: tensor.Tensor,
  calibration_data: List(List(Float)),
  config: AWQConfig,
) -> AWQTensor

Complete AWQ quantization pipeline

Steps:

  1. Collect activation statistics (calibration)
  2. Compute per-channel AWQ scales
  3. Transform weights (scale up salient channels)
  4. Quantize transformed weights

At inference:

  • Use quantized weights directly
  • Apply inverse activation transform (fused into previous layer)
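The steps above can be sketched end to end in Python (all names hypothetical; single quantization group per row, symmetric INT4, alpha = 0.5):

```python
# End-to-end AWQ sketch: calibrate -> scale -> transform -> quantize.

def awq_quantize_channelwise(weights, calib, alpha=0.5):
    """weights: [out][in]; calib: list of activation vectors (one float per input channel)."""
    n_in = len(weights[0])
    # 1. Calibration: mean |activation| per input channel.
    stats = [sum(abs(x[c]) for x in calib) / len(calib) for c in range(n_in)]
    # 2. Per-channel AWQ scales: stat ^ alpha.
    s = [st ** alpha for st in stats]
    # 3. Transform weights: W' = W * diag(s).
    wt = [[wij * sj for wij, sj in zip(row, s)] for row in weights]
    # 4. Symmetric INT4 quantization, one scale per output row.
    rows = []
    for row in wt:
        scale = max(abs(w) for w in row) / 7.0
        rows.append(([round(w / scale) for w in row], scale))
    return rows, s

w = [[0.5, -0.8]]
calib = [[4.0, 0.25], [-4.0, -0.25]]
qrows, s = awq_quantize_channelwise(w, calib)
# At inference: dequantize the codes, and feed activations divided by s
# (or fuse the 1/s factor into the previous layer's output).
```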