Edifice.Transformer.ByteLatentTransformer (Edifice v0.2.0)

Byte Latent Transformer (BLT) — byte-level processing via encode-process-decode.

BLT processes raw byte sequences by encoding bytes into latent patches, processing them with a powerful latent transformer, then decoding back to byte-level predictions. This avoids the need for a fixed tokenizer.

Architecture

Three-component pipeline:

Byte IDs [batch, byte_len]
      |
+----- Encoder ------------------------------------------+
|  Embedding(256, byte_dim) + transformer blocks         |
|  Strided mean pool (patch_size stride) + project       |
|  [batch, byte_len/patch_size, latent_dim]              |
+--------------------------------------------------------+
      |
+----- Latent Transformer -------------------------------+
|  GQA + RoPE + SwiGLU (DecoderOnly-style)               |
|  output_mode: :all                                     |
|  [batch, byte_len/patch_size, latent_dim]              |
+--------------------------------------------------------+
      |
+----- Decoder ------------------------------------------+
|  Project + upsample (repeat) + transformer blocks      |
|  Dense(vocab_size)                                     |
|  [batch, byte_len, vocab_size]                         |
+--------------------------------------------------------+

Returns

A 3-tuple {encoder, latent_transformer, decoder} where each is an independent Axon model.

Usage

{encoder, latent, decoder} = ByteLatentTransformer.build(
  vocab_size: 256,
  patch_size: 4,
  latent_dim: 256,
  byte_dim: 64,
  max_byte_len: 256
)
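
At inference time the three models run in sequence, each stage's output feeding the next. A minimal sketch using Axon's `Axon.build/1` (which returns an `{init_fn, predict_fn}` pair); the batch size, input tensor, and intermediate shapes here are illustrative assumptions, not part of the documented API:

```elixir
{encoder, latent, decoder} =
  Edifice.Transformer.ByteLatentTransformer.build(
    vocab_size: 256,
    patch_size: 4,
    latent_dim: 256,
    byte_dim: 64,
    max_byte_len: 256
  )

# Hypothetical input: one sequence of 256 byte IDs.
bytes = Nx.iota({1, 256}, type: :s64)

{enc_init, enc_pred} = Axon.build(encoder)
enc_params = enc_init.(bytes, %{})
patches = enc_pred.(enc_params, bytes)   # expected [1, 64, 256]

# Repeat the same build/init/predict pattern for `latent` and
# `decoder`, passing each stage's output tensor into the next.
```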

References

  • Pagnoni et al., "Byte Latent Transformer: Patches Scale Better Than Tokens" (arXiv:2412.09871)

Summary

Types

Options for build/1.

Functions

Build a Byte Latent Transformer.

Get the output size of the latent transformer.

Types

build_opt()

@type build_opt() ::
  {:vocab_size, pos_integer()}
  | {:patch_size, pos_integer()}
  | {:latent_dim, pos_integer()}
  | {:byte_dim, pos_integer()}
  | {:num_encoder_layers, pos_integer()}
  | {:num_latent_layers, pos_integer()}
  | {:num_decoder_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_kv_heads, pos_integer()}
  | {:max_byte_len, pos_integer()}
  | {:dropout, float()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: {Axon.t(), Axon.t(), Axon.t()}

Build a Byte Latent Transformer.

Options

  • :vocab_size - Byte vocabulary size (default: 256)
  • :patch_size - Number of bytes per latent patch (default: 4)
  • :latent_dim - Latent transformer hidden dimension (default: 256)
  • :byte_dim - Byte-level encoder/decoder hidden dimension (default: 64)
  • :num_encoder_layers - Encoder transformer layers (default: 2)
  • :num_latent_layers - Latent transformer layers (default: 4)
  • :num_decoder_layers - Decoder transformer layers (default: 2)
  • :num_heads - Attention heads for latent transformer (default: 4)
  • :num_kv_heads - KV heads for GQA (default: 2)
  • :max_byte_len - Maximum byte sequence length (default: 256). Must be divisible by patch_size.
  • :dropout - Dropout rate (default: 0.1)
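
Because the encoder mean-pools bytes with stride `patch_size`, `:max_byte_len` must divide evenly into patches; the latent sequence length is the integer quotient. With the defaults:

```elixir
patch_size = 4
max_byte_len = 256

rem(max_byte_len, patch_size) == 0   # must hold for build/1
div(max_byte_len, patch_size)        # => 64 latent patches
```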

Returns

{encoder, latent_transformer, decoder} — a 3-tuple of Axon models.

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output size of the latent transformer.
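
This is useful when stacking another module on top of the latent stage. A hedged example, under the assumption (not confirmed by this page) that the reported size mirrors the `:latent_dim` option:

```elixir
# Assumption: output_size/1 echoes :latent_dim (default 256).
ByteLatentTransformer.output_size(latent_dim: 512)
```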