# `Edifice.Feedforward.BitNet`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/feedforward/bitnet.ex#L1)

BitNet: 1-bit/1.58-bit transformer with ternary weight quantization.

Implements the BitNet architecture from "BitNet: Scaling 1-bit Transformers
for Large Language Models" (Wang et al., 2023) and "The Era of 1-bit LLMs"
(Ma et al., 2024). BitNet quantizes weights to {-1, 0, +1} in the forward
pass while maintaining full-precision weights for gradient updates.

## Key Innovation: Quantization-Aware Forward Pass

BitNet uses "BitLinear" layers that replace standard dense layers:
1. Full-precision weights are stored for training
2. In the forward pass, weights are quantized to binary ({-1, +1}) or
   ternary ({-1, 0, +1}) values
3. Activations are quantized to 8 bits using absmax quantization
4. Gradients flow through the quantization via a straight-through estimator

```
BitLinear(x):
  W_quant = quantize_weights(W)      # binary: sign(W); ternary: clip(round(W / mean(|W|)), -1, 1)
  x_quant = quantize_activations(x)  # absmax scaling to [-128, 127]
  output  = (x_quant @ W_quant^T) * scale   # rescale to undo the quantization
```
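
As a concrete illustration of the pseudocode above, here is a minimal Nx sketch of the two quantizers in ternary mode; `QuantSketch` and its function names are illustrative, not part of this module's API:

```elixir
defmodule QuantSketch do
  import Nx.Defn

  # Ternary weight quantization (Ma et al., 2024): scale by the mean
  # absolute weight, then round and clip into {-1, 0, +1}.
  defn quantize_weights(w) do
    gamma = Nx.mean(Nx.abs(w)) + 1.0e-6
    {w |> Nx.divide(gamma) |> Nx.round() |> Nx.clip(-1, 1), gamma}
  end

  # Absmax activation quantization: rescale so the largest magnitude
  # maps to 127, then round and clip to the signed 8-bit range.
  defn quantize_activations(x) do
    scale = 127.0 / (Nx.reduce_max(Nx.abs(x)) + 1.0e-6)
    {x |> Nx.multiply(scale) |> Nx.round() |> Nx.clip(-128, 127), scale}
  end
end
```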

## Architecture

```
Input [batch, seq_len, embed_dim]
      |
      v
+-----------------------+
| Input Projection      |
+-----------------------+
      |
      v
+-----------------------+
| BitNet Block x N      |
|  Norm -> BitAttn      |
|  Norm -> BitFFN       |
|  (all dense layers    |
|   use BitLinear)      |
+-----------------------+
      |
      v
[batch, hidden_size]    (last timestep)
```

## Quantization Modes

| Mode    | Weight Values | Bits per Weight |
|---------|--------------|-----------------|
| Binary  | {-1, +1}     | 1 bit           |
| Ternary | {-1, 0, +1}  | 1.58 bits       |

A ternary weight carries log2(3) ≈ 1.585 bits of information, hence the "1.58-bit" name.

## Usage

    model = BitNet.build(
      embed_dim: 287,
      hidden_size: 256,
      num_heads: 4,
      num_layers: 4,
      quantize: :ternary
    )
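
A sketch of compiling and running the model with `Axon.build/2`; the batch size of 8 and the zero-filled input are illustrative:

```elixir
{init_fn, predict_fn} = Axon.build(model)

# [batch, seq_len, embed_dim] = {8, 60, 287}; seq_len follows :window_size
template = Nx.template({8, 60, 287}, :f32)
params = init_fn.(template, %{})  # on newer Axon: Axon.ModelState.empty()

input = Nx.broadcast(0.0, {8, 60, 287})
output = predict_fn.(params, input)  # {8, 256} = [batch, hidden_size]
```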

## References

- "BitNet: Scaling 1-Bit Transformers for Large Language Models"
- "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits"
- arXiv: https://arxiv.org/abs/2310.11453

# `build_opt`

```elixir
@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:quantize, :ternary | :binary}
  | {:seq_len, pos_integer()}
  | {:window_size, pos_integer()}
```

Options for `build/1`.

# `bitlinear`

```elixir
@spec bitlinear(Axon.t(), pos_integer(), keyword()) :: Axon.t()
```

Build a BitLinear layer: a dense layer whose weights are quantized in the forward pass.

Uses `Axon.param` for full-precision weights and quantizes them during
the forward pass via `Axon.layer`.
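
A minimal sketch of that pattern in ternary mode; `BitLinearSketch` and everything in it are illustrative, not this module's actual implementation:

```elixir
defmodule BitLinearSketch do
  import Nx.Defn

  def bitlinear(%Axon{} = input, units, opts \\ []) do
    # Full-precision kernel; its shape is derived from the input's last axis.
    kernel =
      Axon.param("kernel", fn input_shape ->
        {elem(input_shape, tuple_size(input_shape) - 1), units}
      end)

    Axon.layer(&bitlinear_impl/3, [input, kernel],
      name: opts[:name],
      op_name: :bitlinear
    )
  end

  defnp bitlinear_impl(x, kernel, _opts \\ []) do
    gamma = Nx.mean(Nx.abs(kernel)) + 1.0e-6
    quantized = kernel |> Nx.divide(gamma) |> Nx.round() |> Nx.clip(-1, 1)
    # Straight-through estimator: the forward pass sees the quantized
    # weights, while stop_grad makes the backward pass see the identity.
    w_q = kernel + stop_grad(quantized - kernel)
    x |> Nx.dot(w_q) |> Nx.multiply(gamma)
  end
end
```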

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a BitNet model for sequence processing.

## Options

  - `:embed_dim` - Size of input embedding per frame (required)
  - `:hidden_size` - Internal hidden dimension (default: 256)
  - `:num_heads` - Number of attention heads (default: 4)
  - `:num_layers` - Number of BitNet blocks (default: 4)
  - `:dropout` - Dropout rate (default: 0.1)
  - `:window_size` - Expected sequence length (default: 60)
  - `:quantize` - Quantization mode: `:ternary` or `:binary` (default: `:ternary`)

## Returns

  An Axon model that outputs [batch, hidden_size] from the last position.

# `build_bitnet_block`

```elixir
@spec build_bitnet_block(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a single BitNet transformer block with quantized linear layers.
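
A plausible shape of such a block, assuming pre-norm residual wiring and the hypothetical `BitLinearSketch.bitlinear/3` from above; the attention sub-layer is elided:

```elixir
defmodule BitNetBlockSketch do
  # Hypothetical pre-norm block; the real block also wires a quantized
  # attention sub-layer (BitAttn) in the same residual pattern.
  def bitnet_block(%Axon{} = input, opts) do
    hidden = Keyword.fetch!(opts, :hidden_size)

    ffn =
      input
      |> Axon.layer_norm()
      |> BitLinearSketch.bitlinear(hidden * 4)  # 4x expansion is an assumption
      |> Axon.activation(:silu)
      |> BitLinearSketch.bitlinear(hidden)

    Axon.add(input, ffn)  # residual connection
  end
end
```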

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of a BitNet model.
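
Since the model emits `[batch, hidden_size]`, this presumably resolves to the configured `:hidden_size`; a hypothetical call:

```elixir
BitNet.output_size(hidden_size: 256)
#=> 256
```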

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Get recommended default hyperparameters for this architecture.
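
The exact values live in the source; assuming they mirror the defaults documented for `build/1`, the return would look like:

```elixir
BitNet.recommended_defaults()
#=> [hidden_size: 256, num_heads: 4, num_layers: 4,
#    dropout: 0.1, window_size: 60, quantize: :ternary]
```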

---

*Consult [api-reference.md](api-reference.md) for the complete listing*
