BitNet: 1-bit/1.58-bit transformer with binary or ternary weight quantization.
Implements the BitNet architecture from "BitNet: Scaling 1-bit Transformers for Large Language Models" (Wang et al., 2023) and "The Era of 1-bit LLMs" (Ma et al., 2024). BitNet quantizes weights to {-1, 0, +1} in the forward pass while maintaining full-precision weights for gradient updates.
Key Innovation: Quantization-Aware Forward Pass
BitNet uses "BitLinear" layers that replace standard dense layers:
- Full-precision weights are stored for training
- In the forward pass, weights are quantized to binary ({-1, +1}) or ternary ({-1, 0, +1}) values
- Activations are quantized to 8-bit using absmax quantization
- Gradients flow through the quantization via a straight-through estimator
BitLinear(x):
W_quant = quantize_weights(W) # Binary: sign(W), Ternary: round(W/mean(|W|)) clipped to [-1, 1]
x_quant = quantize_activations(x) # absmax to [-128, 127]
output = x_quant @ W_quant^T * scale
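As a concrete sketch, the two quantizers can be written as Nx defn functions with the straight-through estimator implemented via stop_grad. The fake-quantization formulation below (rescaling folded back into each quantizer) and the module name are assumptions about one reasonable implementation, not this module's actual internals:

defmodule BitQuant do
  import Nx.Defn

  # Ternary (absmean) weight quantization. The forward pass sees the
  # quantized values; stop_grad makes the gradient see the identity (STE).
  defn quantize_weights(w) do
    scale = Nx.mean(Nx.abs(w))

    w_q =
      w
      |> Nx.divide(scale)
      |> Nx.round()
      |> Nx.clip(-1, 1)
      |> Nx.multiply(scale)

    w + stop_grad(w_q - w)
  end

  # Absmax activation quantization to the 8-bit range [-128, 127],
  # rescaled back to float. Assumes a nonzero input tensor.
  defn quantize_activations(x) do
    gamma = Nx.reduce_max(Nx.abs(x))

    x_q =
      x
      |> Nx.divide(gamma)
      |> Nx.multiply(127.0)
      |> Nx.round()
      |> Nx.clip(-128, 127)
      |> Nx.multiply(gamma)
      |> Nx.divide(127.0)

    x + stop_grad(x_q - x)
  end
end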
Architecture

Input [batch, seq_len, embed_dim]
|
v
+-----------------------+
| Input Projection |
+-----------------------+
|
v
+-----------------------+
| BitNet Block x N |
| Norm -> BitAttn |
| Norm -> BitFFN |
| (all dense layers |
| use BitLinear) |
+-----------------------+
|
v
[batch, hidden_size] (last timestep)

Quantization Modes
| Mode | Weight Values | Bits per Weight |
|---|---|---|
| Binary | {-1, +1} | 1 bit |
| Ternary | {-1, 0, +1} | 1.58 bits |
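The 1.58 figure is log2(3), the information content of one of three states. For a feel of the difference, a toy comparison in Nx (values arbitrary):

w = Nx.tensor([0.4, -1.2, 0.05, 0.9])

# Binary: just the sign of each weight
Nx.sign(w)
#=> [1.0, -1.0, 1.0, 1.0]

# Ternary (absmean): round(w / mean(|w|)), clipped to [-1, 1].
# Near-zero weights snap to 0, a state binary mode cannot express.
scale = Nx.mean(Nx.abs(w))   # 0.6375
w |> Nx.divide(scale) |> Nx.round() |> Nx.clip(-1, 1)
#=> [1.0, -1.0, 0.0, 1.0]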
Usage
model = BitNet.build(
embed_dim: 287,
hidden_size: 256,
num_heads: 4,
num_layers: 4,
quantize: :ternary
)
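build/1 is expected to return an ordinary Axon graph, so a standard Axon.build/1 round trip works as a smoke test. The snippet below is a sketch under that assumption; the template shape follows the options above plus the default window_size of 60:

{init_fn, predict_fn} = Axon.build(model)

# [batch=1, seq_len=60 (default :window_size), embed_dim=287]
params = init_fn.(Nx.template({1, 60, 287}, :f32), %{})

output = predict_fn.(params, Nx.broadcast(0.0, {1, 60, 287}))
# output shape: {1, 256}, i.e. [batch, hidden_size] from the last timestep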
- "BitNet: Scaling 1-Bit Transformers for Large Language Models"
- "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits"
- arXiv: https://arxiv.org/abs/2310.11453
Summary
Functions
Build a BitLinear layer: dense with quantized weights in the forward pass.
Build a BitNet model for sequence processing.
Build a single BitNet transformer block with quantized linear layers.
Get the output size of a BitNet model.
Get recommended defaults.
Types
@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:quantize, :ternary | :binary}
  | {:seq_len, pos_integer()}
  | {:window_size, pos_integer()}
Options for build/1.
Functions
@spec bitlinear(Axon.t(), pos_integer(), keyword()) :: Axon.t()
Build a BitLinear layer: dense with quantized weights in the forward pass.
Uses Axon.param for full-precision weights and quantizes them during
the forward pass via Axon.layer.
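A minimal sketch of what such a layer can look like, reusing the quantizers sketched above; the module name, the :in_features option, and the initializer are assumptions for illustration, not the library's actual interface:

defmodule MyBitNet do
  # Hypothetical sketch: the input feature size is passed explicitly
  # rather than inferred from the graph.
  def bitlinear(input, units, opts \\ []) do
    in_features = Keyword.fetch!(opts, :in_features)

    # Full-precision master weight, trained as an ordinary Axon parameter.
    weight = Axon.param("weight", {in_features, units}, initializer: :glorot_uniform)

    Axon.layer(
      fn x, w, _opts ->
        # Quantize on the fly each forward pass; the STE helpers are the
        # quantize_weights/1 and quantize_activations/1 sketched earlier.
        Nx.dot(BitQuant.quantize_activations(x), BitQuant.quantize_weights(w))
      end,
      [input, weight],
      name: opts[:name] || "bitlinear",
      op_name: :bitlinear
    )
  end
end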
Build a BitNet model for sequence processing.
Options
- :embed_dim - Size of input embedding per frame (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :num_heads - Number of attention heads (default: 4)
- :num_layers - Number of BitNet blocks (default: 4)
- :dropout - Dropout rate (default: 0.1)
- :window_size - Expected sequence length (default: 60)
- :quantize - Quantization mode: :ternary or :binary (default: :ternary)
Returns
An Axon model that outputs [batch, hidden_size] from the last position.
Build a single BitNet transformer block with quantized linear layers.
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a BitNet model.
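Presumably this mirrors the configured hidden size, since the model output is [batch, hidden_size]; the call below is illustrative, not taken from the module's docs:

BitNet.output_size(hidden_size: 256)
#=> 256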
@spec recommended_defaults() :: keyword()
Get recommended defaults.
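A plausible way to combine the defaults with the one required option (the Keyword.merge usage is illustrative):

opts = Keyword.merge(BitNet.recommended_defaults(), embed_dim: 287)
model = BitNet.build(opts)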