Quantization-Aware Training (QAT) — transformer with quantized linear layers.
Extends BitNet's quantization-aware training to support multiple bit widths beyond binary and ternary. All dense layers in the attention and FFN sub-layers use quantized forward passes (with straight-through gradient estimation for backpropagation).
## Quantization Modes

| Mode | Weight Values | Levels | Bits/Weight |
|---|---|---|---|
| `:binary` | {-1, +1} | 2 | 1 |
| `:ternary` | {-1, 0, +1} | 3 | 1.58 |
| `:int4` | 16 absmax-scaled values | 16 | 4 |
| `:int8` | 256 absmax-scaled values | 256 | 8 |
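The `:ternary` and `:int8` rows can be sketched as follows. This is a minimal illustration using Nx; the module and function names are hypothetical and not part of this module's API. `:ternary` uses BitNet b1.58-style absmean scaling, while the integer modes use absmax scaling as the table states:

```elixir
defmodule QuantSketch do
  import Nx.Defn

  # :ternary — scale by the mean absolute weight, round to {-1, 0, +1}
  defn ternary(w) do
    gamma = Nx.mean(Nx.abs(w)) + 1.0e-6
    Nx.clip(Nx.round(w / gamma), -1, 1) * gamma
  end

  # :int8 — absmax scaling onto 256 integer levels, then dequantize
  defn int8(w) do
    scale = 127.0 / (Nx.reduce_max(Nx.abs(w)) + 1.0e-6)
    Nx.clip(Nx.round(w * scale), -128, 127) / scale
  end
end
```

`:int4` follows the same absmax pattern with a range of [-8, 7] and 16 levels.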
## Architecture

```
Input [batch, seq_len, embed_dim]
  |
Quantized blocks: Pre-norm -> QuantLinear(QKV) -> Attention -> Residual
                  Pre-norm -> QuantLinear(FFN) -> Residual
  |
Final norm -> last timestep -> [batch, hidden_size]
```

## Usage
```elixir
model = QAT.build(
  embed_dim: 256,
  hidden_size: 256,
  num_heads: 4,
  num_layers: 4,
  quantize: :int4
)
```

## References
- Wang et al., "BitNet: Scaling 1-Bit Transformers" (2023)
- Jacob et al., "Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference" (2018)
## Types

```elixir
@type build_opt() ::
        {:embed_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:quantize, :binary | :ternary | :int4 | :int8}
        | {:dropout, float()}
        | {:window_size, pos_integer()}
```

Options for `build/1`.
## Functions

Build a QAT model for sequence processing.

### Options

- `:embed_dim` - Size of input embedding per frame (required)
- `:hidden_size` - Internal hidden dimension (default: 256)
- `:num_heads` - Number of attention heads (default: 4)
- `:num_layers` - Number of blocks (default: 4)
- `:quantize` - Quantization mode: `:binary`, `:ternary`, `:int4`, or `:int8` (default: `:ternary`)
- `:dropout` - Dropout rate (default: 0.1)
- `:window_size` - Expected sequence length (default: 60)

### Returns

An Axon model outputting `[batch, hidden_size]` from the last position.
Build a single QAT transformer block with quantized linear layers.
```elixir
@spec output_size(keyword()) :: pos_integer()
```

Get the output size of a QAT model.
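A hypothetical call; it assumes `output_size/1` reads `:hidden_size` (default 256) from the same options passed to `build/1`, matching the `[batch, hidden_size]` output shape described above:

```elixir
opts = [embed_dim: 256, hidden_size: 512, num_heads: 4, num_layers: 4]
size = QAT.output_size(opts)
# Assumed: size is 512, the configured :hidden_size.
```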
```elixir
@spec quant_linear(Axon.t(), pos_integer(), keyword()) :: Axon.t()
```

Build a quantized linear layer with the given quantization mode.

Stores full-precision weights for gradient updates; quantizes them in the forward pass, using straight-through estimation for the backward pass.
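The straight-through estimator can be sketched with Nx's `custom_grad`. This is a hedged illustration, not this module's implementation: `custom_grad`'s exact signature varies across Nx versions, and `ternary/1` stands in for whichever quantizer the layer's mode selects:

```elixir
import Nx.Defn

# Forward: use the quantized weights.
# Backward: treat quantization as the identity, so the gradient
# passes straight through to the full-precision weights.
defn ste_quantize(w) do
  q = ternary(w)  # hypothetical quantizer; see the mode table above
  custom_grad(q, [w], fn g -> [g] end)
end
```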
```elixir
@spec recommended_defaults() :: keyword()
```

Get recommended defaults.
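A hypothetical way to combine `recommended_defaults/0` with `build/1`, assuming the returned keyword list is compatible with `build/1`'s options:

```elixir
# Start from the recommended defaults, override what the input requires.
opts = Keyword.merge(QAT.recommended_defaults(), embed_dim: 128)
model = QAT.build(opts)
```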