Edifice.Feedforward.BitNet (Edifice v0.2.0)


BitNet: 1-bit/1.58-bit transformer with ternary weight quantization.

Implements the BitNet architecture from "BitNet: Scaling 1-bit Transformers for Large Language Models" (Wang et al., 2023) and "The Era of 1-bit LLMs" (Ma et al., 2024). BitNet quantizes weights to {-1, +1} (binary) or {-1, 0, +1} (ternary) in the forward pass while keeping full-precision weights for gradient updates.

Key Innovation: Quantization-Aware Forward Pass

BitNet uses "BitLinear" layers that replace standard dense layers:

  1. Full-precision weights are stored for training
  2. In the forward pass, weights are quantized to binary ({-1, +1}) or ternary ({-1, 0, +1}) values
  3. Activations are quantized to 8-bit using absmax quantization
  4. Gradients flow through the quantization via straight-through estimator

BitLinear(x):
  W_quant = quantize_weights(W)      # Binary: sign(W); Ternary: clip(round(W / mean(|W|)), -1, 1)
  x_quant = quantize_activations(x)  # absmax quantization to [-128, 127]
  output = x_quant @ W_quant^T * scale
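
As a purely illustrative sketch, the quantizers above can be written directly in Nx. The module name BitQuant and the epsilon constant are assumptions, not part of Edifice; the straight-through estimator uses Nx.stop_grad/1 so quantization acts as the identity in the backward pass.

defmodule BitQuant do
  import Nx.Defn

  @eps 1.0e-6

  # Ternary (1.58-bit): scale by mean |W|, then round and clip to {-1, 0, +1}.
  defn quantize_weights_ternary(w) do
    gamma = Nx.mean(Nx.abs(w))

    w
    |> Nx.divide(gamma + @eps)
    |> Nx.round()
    |> Nx.clip(-1, 1)
  end

  # Binary (1-bit): sign of the weight, {-1, +1}.
  defn quantize_weights_binary(w) do
    Nx.sign(w)
  end

  # Absmax activation quantization: scale into [-128, 127], round, rescale back.
  defn quantize_activations(x) do
    scale = 127.0 / (Nx.reduce_max(Nx.abs(x)) + @eps)

    (x * scale)
    |> Nx.round()
    |> Nx.clip(-128, 127)
    |> Nx.divide(scale)
  end

  # Straight-through estimator: the forward pass uses the quantized weights,
  # while the gradient flows through as if quantization were the identity.
  defn ste(w) do
    w + Nx.stop_grad(quantize_weights_ternary(w) - w)
  end
end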

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+-----------------------+
| Input Projection      |
+-----------------------+
      |
      v
+-----------------------+
| BitNet Block x N      |
|  Norm -> BitAttn      |
|  Norm -> BitFFN       |
|  (all dense layers    |
|   use BitLinear)      |
+-----------------------+
      |
      v
[batch, hidden_size]    (last timestep)

Quantization Modes

Mode      Weight Values    Bits per Weight
Binary    {-1, +1}         1 bit
Ternary   {-1, 0, +1}      1.58 bits
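
The 1.58-bit figure is the information content of a ternary symbol: a weight that takes one of three values carries log2(3) bits, e.g.

:math.log2(3)
# ≈ 1.585 bits per ternary weight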

Usage

alias Edifice.Feedforward.BitNet

model = BitNet.build(
  embed_dim: 287,
  hidden_size: 256,
  num_heads: 4,
  num_layers: 4,
  quantize: :ternary
)
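
Running a forward pass uses standard Axon compilation. The sketch below assumes the batch size and input template shape; Axon.build/1, Nx.template/2, and Nx.broadcast/2 are ordinary Axon/Nx calls:

{init_fn, predict_fn} = Axon.build(model)

input = Nx.broadcast(0.0, {8, 60, 287})                  # [batch, seq_len, embed_dim]
params = init_fn.(Nx.template({8, 60, 287}, :f32), %{})

predict_fn.(params, input) |> Nx.shape()
#=> {8, 256}                                             # [batch, hidden_size]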

References

  • "BitNet: Scaling 1-Bit Transformers for Large Language Models"
  • "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits"
  • arXiv: https://arxiv.org/abs/2310.11453

Summary

Types

build_opt()
  Options for build/1.

Functions

bitlinear(input, output_size, opts \\ [])
  Build a BitLinear layer: dense with quantized weights in the forward pass.

build(opts \\ [])
  Build a BitNet model for sequence processing.

build_bitnet_block(input, opts)
  Build a single BitNet transformer block with quantized linear layers.

output_size(opts \\ [])
  Get the output size of a BitNet model.

Get recommended defaults.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:quantize, :ternary | :binary}
  | {:seq_len, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

bitlinear(input, output_size, opts \\ [])

@spec bitlinear(Axon.t(), pos_integer(), keyword()) :: Axon.t()

Build a BitLinear layer: a dense layer whose weights are quantized in the forward pass.

Uses Axon.param for full-precision weights and quantizes them during the forward pass via Axon.layer.
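The docstring only names the pattern; below is a minimal sketch of how such a layer could be wired with Axon.param and Axon.layer, reusing the BitQuant helpers sketched earlier. The module name MyBitLinear, the kernel shape function, and the lack of output scaling are assumptions for illustration, not Edifice's implementation.

defmodule MyBitLinear do
  # Hypothetical BitLinear wiring; not the Edifice source.
  def layer(input, output_size, opts \\ []) do
    name = opts[:name] || "bitlinear"

    # Full-precision kernel stored as a trainable parameter; its shape is
    # derived from the last axis of the input.
    kernel =
      Axon.param("kernel", fn input_shape ->
        {elem(input_shape, tuple_size(input_shape) - 1), output_size}
      end)

    Axon.layer(
      fn x, w, _opts ->
        # Quantize in the forward pass; gradients pass straight through.
        w_q = BitQuant.ste(w)
        Nx.dot(x, [Nx.rank(x) - 1], w_q, [0])
      end,
      [input, kernel],
      name: name,
      op_name: :bitlinear
    )
  end
end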

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a BitNet model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_heads - Number of attention heads (default: 4)
  • :num_layers - Number of BitNet blocks (default: 4)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Expected sequence length (default: 60)
  • :quantize - Quantization mode: :ternary or :binary (default: :ternary)

Returns

An Axon model that outputs [batch, hidden_size] from the last position.

build_bitnet_block(input, opts)

@spec build_bitnet_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single BitNet transformer block with quantized linear layers.
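
For intuition, the feed-forward half of such a block can be wired as a pre-norm residual path. This is a hedged sketch only: the 4x expansion, the :silu activation, and the MyBitLinear helper from the earlier sketch are assumptions, not Edifice's block.

defmodule BitBlockSketch do
  # Pre-norm residual FFN path; the attention sub-layer is elided.
  def ffn_sub_block(input, hidden_size, opts \\ []) do
    dropout = opts[:dropout] || 0.1

    ffn =
      input
      |> Axon.layer_norm()
      |> MyBitLinear.layer(4 * hidden_size)
      |> Axon.activation(:silu)
      |> MyBitLinear.layer(hidden_size)
      |> Axon.dropout(rate: dropout)

    # Residual connection around the quantized feed-forward path.
    Axon.add(input, ffn)
  end
end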

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a BitNet model.
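
Assuming the output size simply mirrors :hidden_size (the model returns [batch, hidden_size]), a call would look like:

Edifice.Feedforward.BitNet.output_size(hidden_size: 256)
#=> 256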