# `Edifice.Meta.QAT`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/meta/qat.ex#L1)

Quantization-Aware Training (QAT) — transformer with quantized linear layers.

Extends BitNet's quantization-aware training to support multiple bit widths
beyond binary and ternary. All dense layers in the attention and FFN sub-layers
use quantized forward passes (with straight-through gradient estimation for
backpropagation).
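
The straight-through estimator (STE) is easy to express with Nx: evaluate the quantized weights in the forward pass, but let gradients flow as if the quantization were the identity. A minimal sketch using `stop_grad/1` (the `quantize/1` body here is a stand-in for the module's actual quantizers):

```elixir
defmodule STEExample do
  import Nx.Defn

  # Stand-in quantizer: Nx.sign/1 lands on {-1, +1} (0 maps to 0),
  # roughly the :binary mode.
  defn quantize(w), do: Nx.sign(w)

  defn ste_quantize(w) do
    # Forward pass evaluates to quantize(w); in the backward pass,
    # stop_grad/1 zeroes the (quantize(w) - w) term, so the gradient
    # with respect to w is the identity.
    w + stop_grad(quantize(w) - w)
  end
end
```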

## Quantization Modes

| Mode       | Weight Values            | Levels | Bits/Weight |
|------------|--------------------------|--------|-------------|
| `:binary`  | {-1, +1}                 | 2      | 1           |
| `:ternary` | {-1, 0, +1}              | 3      | 1.58        |
| `:int4`    | 16 absmax-scaled levels  | 16     | 4           |
| `:int8`    | 256 absmax-scaled levels | 256    | 8           |
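
For the integer modes, absmax scaling maps weights onto a symmetric integer grid and back. A sketch of what an `:int8`-style fake-quantization step looks like, illustrating the table above rather than the library's exact code (the `:int4` mode follows the same pattern with a 16-level grid):

```elixir
defmodule AbsmaxExample do
  import Nx.Defn

  # Fake-quantize to a symmetric int8 grid: scale by the absolute
  # maximum, round to integers in [-127, 127], then rescale.
  defn fake_quant_int8(w) do
    scale = Nx.reduce_max(Nx.abs(w)) + 1.0e-8

    w
    |> Nx.divide(scale)
    |> Nx.multiply(127)
    |> Nx.round()
    |> Nx.clip(-127, 127)
    |> Nx.divide(127)
    |> Nx.multiply(scale)
  end
end
```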

## Architecture

```
Input [batch, seq_len, embed_dim]
      |
Quantized blocks: Pre-norm -> QuantLinear(QKV) -> Attention -> Residual
                   Pre-norm -> QuantLinear(FFN) -> Residual
      |
Final norm -> last timestep -> [batch, hidden_size]
```

## Usage

    alias Edifice.Meta.QAT

    model = QAT.build(
      embed_dim: 256,
      hidden_size: 256,
      num_heads: 4,
      num_layers: 4,
      quantize: :int4
    )

## References

- Wang et al., "BitNet: Scaling 1-bit Transformers for Large Language Models" (2023)
- Jacob et al., "Quantization and Training of Neural Networks for Efficient
  Integer-Arithmetic-Only Inference" (2018)

# `build_opt`

```elixir
@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:quantize, :binary | :ternary | :int4 | :int8}
  | {:dropout, float()}
  | {:window_size, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a QAT model for sequence processing.

## Options

  - `:embed_dim` - Size of input embedding per frame (required)
  - `:hidden_size` - Internal hidden dimension (default: 256)
  - `:num_heads` - Number of attention heads (default: 4)
  - `:num_layers` - Number of blocks (default: 4)
  - `:quantize` - Quantization mode: `:binary`, `:ternary`, `:int4`, or `:int8` (default: `:ternary`)
  - `:dropout` - Dropout rate (default: 0.1)
  - `:window_size` - Expected sequence length (default: 60)

## Returns

  An Axon model outputting `[batch, hidden_size]` from the last position.
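
A quick smoke test, assuming the model is wired as a single-input Axon graph (the template shape follows the defaults above):

```elixir
model = Edifice.Meta.QAT.build(embed_dim: 256, num_layers: 2, quantize: :int8)

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({2, 60, 256}, :f32), %{})

input = Nx.iota({2, 60, 256}, type: :f32)
out = predict_fn.(params, input)
# Nx.shape(out) => {2, 256}, i.e. [batch, hidden_size]
```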

# `build_qat_block`

```elixir
@spec build_qat_block(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a single QAT transformer block with quantized linear layers.
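
As a rough Axon sketch of the FFN half of such a block (pre-norm, quantized expansion and projection, residual), assuming `quant_linear/3` as documented below with an illustrative `:quantize` option key; this is not the library's implementation:

```elixir
import Edifice.Meta.QAT, only: [quant_linear: 3]

x = Axon.input("input", shape: {nil, 60, 256})

ffn =
  x
  |> Axon.layer_norm()                          # pre-norm
  |> quant_linear(4 * 256, quantize: :ternary)  # quantized expansion
  |> Axon.activation(:gelu)
  |> quant_linear(256, quantize: :ternary)      # quantized projection

block_out = Axon.add(x, ffn)                    # residual connection
```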

# `output_size`

```elixir
@spec output_size(keyword()) :: pos_integer()
```

Get the output size of a QAT model.
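
Presumably this mirrors the `:hidden_size` option, given the `[batch, hidden_size]` output of `build/1` (the return value shown is an assumption):

```elixir
Edifice.Meta.QAT.output_size(hidden_size: 512)
#=> 512
```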

# `quant_linear`

```elixir
@spec quant_linear(Axon.t(), pos_integer(), keyword()) :: Axon.t()
```

Build a quantized linear layer with the given quantization mode.

Stores full-precision weights for gradient updates; quantizes in the
forward pass via straight-through estimation.
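
For example, projecting a sequence to 512 features (the `:quantize` option key is illustrative):

```elixir
x = Axon.input("input", shape: {nil, 60, 256})
y = Edifice.Meta.QAT.quant_linear(x, 512, quantize: :int4)
```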

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Get recommended defaults.
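
A typical pattern, assuming the returned keywords are `build/1` options:

```elixir
opts = Keyword.merge(Edifice.Meta.QAT.recommended_defaults(), embed_dim: 256)
model = Edifice.Meta.QAT.build(opts)
```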

---

*Consult [api-reference.md](api-reference.md) for the complete listing.*
