Byte Latent Transformer (BLT) — byte-level processing via encode-process-decode.
BLT encodes raw byte sequences into latent patches, processes the patches with a latent transformer, and decodes back to byte-level predictions. This removes the need for a fixed tokenizer.
Architecture
Three-component pipeline:
Byte IDs [batch, byte_len]
         |
+----- Encoder ------------------------------------------+
| Embedding(256, byte_dim) + transformer blocks          |
| Strided mean pool (patch_size stride) + project        |
| → [batch, byte_len/patch_size, latent_dim]             |
+--------------------------------------------------------+
         |
+----- Latent Transformer -------------------------------+
| GQA + RoPE + SwiGLU (DecoderOnly-style)                |
| output_mode: :all                                      |
| → [batch, byte_len/patch_size, latent_dim]             |
+--------------------------------------------------------+
         |
+----- Decoder ------------------------------------------+
| Project + upsample (repeat) + transformer blocks       |
| Dense(vocab_size)                                      |
| → [batch, byte_len, vocab_size]                        |
+--------------------------------------------------------+

Returns
A 3-tuple {encoder, latent_transformer, decoder} where each is an
independent Axon model.
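Because the three components are independent Axon models, a forward pass threads the output of each into the next. A minimal sketch with the default configuration, assuming Axon's standard `Axon.build/1` (returning `{init_fn, predict_fn}` pairs); the variable names and shape comments follow the diagram above and are illustrative:

```elixir
# Build the three components with defaults (patch_size 4, latent_dim 256,
# max_byte_len 256), then run bytes -> patches -> hidden -> logits.
{encoder, latent, decoder} = ByteLatentTransformer.build(max_byte_len: 256)

# A dummy batch of byte IDs: [batch = 1, byte_len = 256], values in 0..255.
bytes = Nx.remainder(Nx.iota({1, 256}, type: :s64), 256)

{enc_init, enc_pred} = Axon.build(encoder)
enc_params = enc_init.(bytes, %{})
patches = enc_pred.(enc_params, bytes)    # [1, 64, 256] per the diagram

{lat_init, lat_pred} = Axon.build(latent)
lat_params = lat_init.(patches, %{})
hidden = lat_pred.(lat_params, patches)   # [1, 64, 256]

{dec_init, dec_pred} = Axon.build(decoder)
dec_params = dec_init.(hidden, %{})
logits = dec_pred.(dec_params, hidden)    # [1, 256, 256] byte logits
```

With freshly initialized (untrained) parameters the logits are meaningless; in practice the three parameter maps are trained jointly end-to-end.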
Usage
{encoder, latent, decoder} = ByteLatentTransformer.build(
vocab_size: 256,
patch_size: 4,
latent_dim: 256,
byte_dim: 64,
max_byte_len: 256
)

References
- "Byte Latent Transformer: Patches Scale Better Than Tokens" (Meta, 2024) — https://arxiv.org/abs/2412.09871
Summary
Types
@type build_opt() ::
        {:vocab_size, pos_integer()}
        | {:patch_size, pos_integer()}
        | {:latent_dim, pos_integer()}
        | {:byte_dim, pos_integer()}
        | {:num_encoder_layers, pos_integer()}
        | {:num_latent_layers, pos_integer()}
        | {:num_decoder_layers, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:num_kv_heads, pos_integer()}
        | {:max_byte_len, pos_integer()}
        | {:dropout, float()}
Options for build/1.
Functions
Build a Byte Latent Transformer.
Options
:vocab_size - Byte vocabulary size (default: 256)
:patch_size - Number of bytes per latent patch (default: 4)
:latent_dim - Latent transformer hidden dimension (default: 256)
:byte_dim - Byte-level encoder/decoder hidden dimension (default: 64)
:num_encoder_layers - Encoder transformer layers (default: 2)
:num_latent_layers - Latent transformer layers (default: 4)
:num_decoder_layers - Decoder transformer layers (default: 2)
:num_heads - Attention heads for the latent transformer (default: 4)
:num_kv_heads - KV heads for GQA (default: 2)
:max_byte_len - Maximum byte sequence length (default: 256). Must be divisible by patch_size.
:dropout - Dropout rate (default: 0.1)
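The divisibility constraint on `:max_byte_len` exists because the encoder pools bytes into patches with a fixed `patch_size` stride and the decoder upsamples by the same factor. The shape arithmetic, in plain Elixir with the default values:

```elixir
# Defaults from the options above.
max_byte_len = 256
patch_size = 4

# The constraint: max_byte_len must be an exact multiple of patch_size,
# otherwise the strided mean pool would drop trailing bytes.
true = rem(max_byte_len, patch_size) == 0

# Latent sequence length seen by the latent transformer (byte_len / patch_size).
num_patches = div(max_byte_len, patch_size)
# num_patches == 64

# The decoder repeats each latent patch patch_size times, recovering the
# original byte-level length exactly.
^max_byte_len = num_patches * patch_size
```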
Returns
{encoder, latent_transformer, decoder} — a 3-tuple of Axon models.
@spec output_size(keyword()) :: pos_integer()
Get the output size of the latent transformer.
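A usage sketch for `output_size/1`, assuming it reports the hidden dimension of the latent transformer's output (i.e. the `:latent_dim` option, defaulting to 256):

```elixir
# Convenient when attaching a downstream head to the latent stream
# without rebuilding the model just to inspect its output shape.
dim = ByteLatentTransformer.output_size(latent_dim: 512)
# Assumed to return 512 under the interpretation above.
```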