viva_tensor/quant/awq
AWQ (Activation-aware Weight Quantization)
Reference: Lin et al. (2024) - “AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration” MLSys 2024 BEST PAPER AWARD https://arxiv.org/abs/2306.00978
— The Key Insight (worth repeating) — Only ~1% of weights are “salient”, and they matter 10x more than the rest. But here’s the twist: you identify them by looking at ACTIVATIONS, not weights. High activation magnitude = that channel matters = protect those weights.
— The Genius — Don’t modify the quantization algorithm. Modify the weights BEFORE quantizing. Scale salient channels UP by s, then scale activations DOWN by 1/s. Mathematically equivalent: WX = (sW)(X/s). But now the important weights have more precision after quantization.
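The equivalence can be checked numerically. A minimal Python sketch (values and names are made up for illustration, not the module's API):

```python
# Numeric check of WX = (sW)(X/s). W, X, s are made-up values;
# s holds one scale per input channel.
def matvec(rows, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in rows]

W = [[0.8, -1.2], [0.5, 2.0]]  # weight matrix (rows = output channels)
X = [3.0, -0.5]                # input activations
s = [2.0, 4.0]                 # per-input-channel AWQ scales

Ws = [[w * si for w, si in zip(row, s)] for row in W]  # weights scaled UP
Xs = [x / si for x, si in zip(X, s)]                   # activations scaled DOWN

# The two products agree, so the transform is free at the math level;
# the win is that sW survives rounding to INT4 better on salient channels.
assert all(abs(a - b) < 1e-9 for a, b in zip(matvec(W, X), matvec(Ws, Xs)))
```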
— Compression Math — Same as NF4/INT4: 32/4 = 8x theoretical, ~7.7x effective. The magic is in the QUALITY, not the ratio. AWQ achieves NF4-level compression with FP16-level accuracy.
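One way the ~7.7x figure can arise, assuming one FP16 scale per group of 128 INT4 weights (group size and scale width here are assumptions; the exact overhead depends on the packing format):

```python
# Back-of-envelope for "8x theoretical, ~7.7x effective": the per-group
# quantization scales add metadata on top of the 4-bit payload.
group_size = 128
bits_per_weight = (group_size * 4 + 16) / group_size  # 4.125 bits/weight
effective_ratio = 32.0 / bits_per_weight              # ~7.76x vs FP32
```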
— Why AWQ Won MLSys 2024 —
- Simple insight, huge impact
- Zero runtime overhead (transform is pre-computed)
- Works with ANY quantization method (INT4, NF4, whatever)
- State-of-the-art on LLaMA, OPT, BLOOM benchmarks
Implementation based on: MIT-HAN Lab + AutoAWQ
Types
AWQ configuration
pub type AWQConfig {
  AWQConfig(
    bits: Int,
    group_size: Int,
    alpha: Float,
    zero_point: Bool,
  )
}
Constructors
- AWQConfig(bits: Int, group_size: Int, alpha: Float, zero_point: Bool)
Arguments
- bits: Quantization bits (4 is standard, 3 is aggressive)
- group_size: Group size for per-group scaling (128 is typical). Smaller = more accurate, larger = more compressed.
- alpha: Alpha exponent for scaling: scale = activation_stat ^ alpha. 0.5 is empirically optimal (sqrt of the activation magnitude); higher alpha = more aggressive protection of salient channels.
- zero_point: Use a zero-point (asymmetric quantization). Opinion: skip it; the cache-miss overhead isn’t worth the accuracy gain.
Computed AWQ scales for weight transformation
pub type AWQScales {
  AWQScales(
    weight_scales: List(Float),
    activation_stats: List(Float),
    alpha: Float,
  )
}
Constructors
- AWQScales(weight_scales: List(Float), activation_stats: List(Float), alpha: Float)
Arguments
- weight_scales: Per-channel scale factors; multiply weights by these before quantizing
- activation_stats: Original activation statistics (for debugging/analysis)
- alpha: Alpha used in the computation
AWQ-quantized tensor
pub type AWQTensor {
  AWQTensor(
    quantized_weights: List(Int),
    awq_scales: AWQScales,
    quant_scales: List(Float),
    zero_points: List(Int),
    shape: List(Int),
    memory_bytes: Int,
  )
}
Constructors
- AWQTensor(quantized_weights: List(Int), awq_scales: AWQScales, quant_scales: List(Float), zero_points: List(Int), shape: List(Int), memory_bytes: Int)
Arguments
- quantized_weights: Quantized weights (INT4 values)
- awq_scales: AWQ channel scales (the secret sauce)
- quant_scales: Per-group quantization scales
- zero_points: Zero-points if asymmetric (usually empty)
- shape: Original shape
- memory_bytes: Memory in bytes
Values
pub fn apply_activation_transform(
  activations: List(Float),
  scales: AWQScales,
) -> List(Float)
Apply the inverse transformation to activations: X’ = X * diag(1/s). This compensates for the weight scaling at runtime. Note: in production, fuse this into the previous layer’s output.
pub fn apply_weight_transform(
  weights: List(List(Float)),
  scales: AWQScales,
) -> List(List(Float))
Apply the equivalent transformation to weights: W’ = W * diag(s). This scales salient channels UP before quantization.
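What the two transforms compute, sketched as illustrative Python stand-ins (the Gleam functions operate on List(Float) / List(List(Float)) the same way; `scales` here plays the role of AWQScales.weight_scales):

```python
# Illustrative stand-ins for the transform pair; not the module's API.
def apply_weight_transform(weights, scales):
    # W' = W * diag(s): scale each input channel (column) up
    return [[w * s for w, s in zip(row, scales)] for row in weights]

def apply_activation_transform(activations, scales):
    # X' = X * diag(1/s): compensate at runtime
    return [x / s for x, s in zip(activations, scales)]
```

Scaling weights up and activations down by the same per-channel factors keeps W’X’ = WX exactly, which is why the transform adds zero runtime overhead.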
pub fn benchmark_awq() -> Nil
pub fn collect_activation_stats(
  activations_batch: List(List(Float)),
) -> List(Float)
Collect activation statistics from calibration data
This is THE critical step. Bad calibration = bad quantization. Use 128-512 samples from your actual inference distribution. More samples = more stable statistics, with diminishing returns after ~256.
Returns: mean absolute activation per channel
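The statistic described above, sketched in illustrative Python (not the module's implementation):

```python
# Mean absolute activation per channel over a calibration batch.
def collect_activation_stats(batch):
    n = len(batch)
    n_channels = len(batch[0])
    return [sum(abs(row[c]) for row in batch) / n for c in range(n_channels)]

stats = collect_activation_stats([[1.0, -4.0], [-3.0, 2.0]])
# channel 0: (1 + 3) / 2 = 2.0; channel 1: (4 + 2) / 2 = 3.0
```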
pub fn compute_awq_scales(
  activation_stats: List(Float),
  alpha: Float,
) -> AWQScales
Compute AWQ scales from activation statistics
Formula: scale[i] = activation_stat[i] ^ alpha
Why alpha = 0.5 (sqrt)?
- Too low (0.1): Not enough protection for salient channels
- Too high (0.9): Over-protection, wastes precision on outliers
- 0.5: Empirically optimal across LLaMA, OPT, BLOOM
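A quick numeric illustration of the formula (values are made up):

```python
# scale[i] = activation_stat[i] ** alpha with alpha = 0.5: a channel
# whose activations are 10,000x larger gets only a 100x larger scale,
# so salient channels are protected without drowning out the rest.
alpha = 0.5
stats = [0.01, 1.0, 100.0]            # quiet / typical / salient channel
scales = [s ** alpha for s in stats]  # ~ [0.1, 1.0, 10.0]
```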
pub fn default_config() -> AWQConfig
Default AWQ config (matches AutoAWQ defaults)
pub fn dequantize_awq(awq: AWQTensor) -> tensor.Tensor
Dequantize an AWQ tensor back to FP32. Note: this must also undo the AWQ weight transform.
pub fn identify_salient_channels(
  activation_stats: List(Float),
  top_percent: Float,
) -> List(Int)
Identify the most salient channels (top-k by activation magnitude)
Key insight: Only ~1% of channels are truly salient. But they contribute ~10% of the output magnitude. Protecting them is the key to AWQ’s success.
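Top-k selection by activation magnitude, sketched in illustrative Python (top_percent = 0.01 keeps the ~1% of channels the docs call salient):

```python
# Return the indices of the most salient channels, sorted ascending.
def identify_salient_channels(stats, top_percent):
    k = max(1, int(len(stats) * top_percent))
    by_magnitude = sorted(range(len(stats)), key=lambda i: stats[i], reverse=True)
    return sorted(by_magnitude[:k])

assert identify_salient_channels([0.1, 9.0, 0.2, 7.5], 0.5) == [1, 3]
```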
pub fn quantize_awq(
  weights: tensor.Tensor,
  calibration_data: List(List(Float)),
  config: AWQConfig,
) -> AWQTensor
Complete AWQ quantization pipeline
Steps:
1. Collect activation statistics (calibration)
2. Compute per-channel AWQ scales
3. Transform weights (scale up salient channels)
4. Quantize the transformed weights
At inference:
- Use quantized weights directly
- Apply inverse activation transform (fused into previous layer)
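The full pipeline, sketched in plain Python with simplified per-row symmetric INT4 quantization (real kernels quantize per group of group_size and pack nibbles; every name here is illustrative, not the module's API):

```python
# Four steps of AWQ quantization, simplified: one quant scale per row
# instead of per group, and no bit-packing.
def awq_quantize(weights, calib, alpha=0.5):
    n = len(calib)
    n_cols = len(weights[0])
    # Step 1: calibration stats (mean |activation| per channel)
    stats = [sum(abs(r[c]) for r in calib) / n for c in range(n_cols)]
    # Step 2: per-channel AWQ scales
    awq_scales = [s ** alpha for s in stats]
    # Step 3: scale salient input channels up
    w_t = [[w * s for w, s in zip(row, awq_scales)] for row in weights]
    # Step 4: symmetric INT4 quantization, one scale per row
    q_rows, q_scales = [], []
    for row in w_t:
        qs = max(abs(w) for w in row) / 7.0 or 1.0  # INT4 range [-7, 7]
        q_rows.append([round(w / qs) for w in row])
        q_scales.append(qs)
    return q_rows, q_scales, awq_scales

def awq_dequantize(q_rows, q_scales, awq_scales):
    # Undo the quantization scale AND the AWQ channel scale
    return [[q * qs / s for q, s in zip(row, awq_scales)]
            for row, qs in zip(q_rows, q_scales)]
```

At inference you would not dequantize: keep the INT4 weights and fold the 1/s activation scaling into the previous layer, as noted above.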