Edifice.Meta.QAT (Edifice v0.2.0)


Quantization-Aware Training (QAT) — transformer with quantized linear layers.

Extends BitNet's quantization-aware training to support multiple bit widths beyond binary and ternary. All dense layers in the attention and FFN sub-layers use quantized forward passes (with straight-through gradient estimation for backpropagation).

Quantization Modes

Mode       Weight Values       Levels   Bits/Weight
:binary    {-1, +1}            2        1
:ternary   {-1, 0, +1}         3        1.58
:int4      16 absmax-scaled    16       4
:int8      256 absmax-scaled   256      8
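The modes above can be sketched as Nx transforms. This is an illustrative sketch only, not the module's actual implementation; the function names and the ternary threshold are assumptions:

    defmodule QuantSketch do
      import Nx.Defn

      # :binary — sign of each weight, rescaled by the mean magnitude
      defn binary(w), do: Nx.sign(w) * Nx.mean(Nx.abs(w))

      # :ternary — weights near zero snap to 0 (0.5 threshold is illustrative)
      defn ternary(w) do
        scale = Nx.mean(Nx.abs(w))
        Nx.select(Nx.abs(w) < 0.5 * scale, 0.0, Nx.sign(w)) * scale
      end

      # :int4 / :int8 — absmax scaling onto a symmetric integer grid
      defn absmax(w, opts \\ [bits: 8]) do
        levels = 2 ** (opts[:bits] - 1) - 1
        scale = Nx.max(Nx.reduce_max(Nx.abs(w)), 1.0e-8) / levels
        Nx.round(w / scale) * scale
      end
    end

For example, `QuantSketch.absmax(w, bits: 4)` snaps each weight onto one of the 16-level int4 grid points.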

Architecture

Input [batch, seq_len, embed_dim]
      |
Quantized blocks: Pre-norm -> QuantLinear(QKV) -> Attention -> Residual
                   Pre-norm -> QuantLinear(FFN) -> Residual
      |
Final norm -> last timestep -> [batch, hidden_size]

Usage

model = QAT.build(
  embed_dim: 256,
  hidden_size: 256,
  num_heads: 4,
  num_layers: 4,
  quantize: :int4
)
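The model above can then be compiled and run like any other Axon graph. A sketch, assuming a single unnamed input and a sequence length of 60 frames:

    # Compile into init/predict functions.
    {init_fn, predict_fn} = Axon.build(model)

    # Initialize parameters from an input template: [batch, seq_len, embed_dim].
    params = init_fn.(Nx.template({1, 60, 256}, :f32), %{})

    # Forward pass on a dummy batch; output shape is {batch, hidden_size}.
    out = predict_fn.(params, Nx.broadcast(0.0, {1, 60, 256}))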

References

  • Wang et al., "BitNet: Scaling 1-bit Transformers for Large Language Models" (2023)
  • Jacob et al., "Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference" (2018)

Summary

Types

build_opt()
  Options for build/1.

Functions

build(opts \\ [])
  Build a QAT model for sequence processing.

build_qat_block(input, opts)
  Build a single QAT transformer block with quantized linear layers.

output_size(opts \\ [])
  Get the output size of a QAT model.

quant_linear(input, output_size, opts \\ [])
  Build a quantized linear layer with the given quantization mode.

Get recommended defaults.

Types

build_opt()

@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:quantize, :binary | :ternary | :int4 | :int8}
  | {:dropout, float()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a QAT model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_heads - Number of attention heads (default: 4)
  • :num_layers - Number of blocks (default: 4)
  • :quantize - Quantization mode: :binary, :ternary, :int4, or :int8 (default: :ternary)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Expected sequence length (default: 60)

Returns

An Axon model outputting [batch, hidden_size] from the last position.

build_qat_block(input, opts)

@spec build_qat_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single QAT transformer block with quantized linear layers.

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output size of a QAT model.

quant_linear(input, output_size, opts \\ [])

@spec quant_linear(Axon.t(), pos_integer(), keyword()) :: Axon.t()

Build a quantized linear layer with the given quantization mode.

Stores full-precision weights for gradient updates; the forward pass quantizes them, and a straight-through estimator routes gradients past the non-differentiable quantization back to the full-precision weights.
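The straight-through trick can be sketched in Nx by wrapping the quantization error in stop_grad/1, so the forward pass sees quantized weights while the gradient behaves like the identity. A sketch (int8 absmax shown; not the module's actual code):

    defmodule STESketch do
      import Nx.Defn

      # Forward: quantized weights. Backward: the non-differentiable round
      # is hidden inside stop_grad, so grad(&ste_int8/1) is all ones.
      defn ste_int8(w) do
        scale = Nx.max(Nx.reduce_max(Nx.abs(w)), 1.0e-8) / 127
        w_q = Nx.round(w / scale) * scale
        w + stop_grad(w_q - w)
      end
    end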