MAR: Masked Autoregressive Generation.
MAR bridges autoregressive (AR) models and masked prediction models for iterative discrete token generation. Unlike pure AR models that generate tokens left-to-right, MAR predicts all masked positions in parallel and progressively unmasks tokens from most to least confident over K steps.
Architecture
Token Indices [batch, seq_len]
|
v
Token Embedding + Sinusoidal Positional Encoding
|
v
+--------------------------------------+
| Transformer Encoder (bidirectional) |
| N × (LayerNorm + MHA + FFN) |
| No causal mask — all positions see |
| each other, conditioned on unmasked|
+--------------------------------------+
|
v
LayerNorm → Dense → Logits [batch, seq_len, vocab_size]
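A rough Axon sketch of this stack, assuming single-head self-attention for brevity (the real model is multi-head; every name below is illustrative rather than the module's internals, and the sinusoidal positional encoding is elided):

import Nx.Defn

# Scaled dot-product attention over the whole sequence -- no causal mask,
# so every position attends to every other position.
defn self_attention(q, k, v, _opts \\ []) do
  d = Nx.axis_size(q, -1)
  scores = Nx.dot(q, [2], [0], k, [2], [0]) / Nx.sqrt(d)  # [batch, seq, seq]
  weights = stable_softmax(scores)
  Nx.dot(weights, [2], [0], v, [1], [0])                  # [batch, seq, dim]
end

defnp stable_softmax(t) do
  m = Nx.reduce_max(t, axes: [-1], keep_axes: true)
  e = Nx.exp(t - m)
  e / Nx.sum(e, axes: [-1], keep_axes: true)
end

# Pre-norm encoder block: LayerNorm -> attention -> residual, then FFN.
def encoder_block(x, dim) do
  normed = Axon.layer_norm(x)
  qkv = for _ <- 1..3, do: Axon.dense(normed, dim)
  attn = Axon.layer(&self_attention/4, qkv)
  x = Axon.add(x, attn)

  ffn =
    x
    |> Axon.layer_norm()
    |> Axon.dense(4 * dim, activation: :gelu)
    |> Axon.dense(dim)

  Axon.add(x, ffn)
end

def sketch_model(vocab, dim, layers, seq_len) do
  Axon.input("tokens", shape: {nil, seq_len})
  |> Axon.embedding(vocab, dim)
  |> then(fn x -> Enum.reduce(1..layers, x, fn _, acc -> encoder_block(acc, dim) end) end)
  |> Axon.layer_norm()
  |> Axon.dense(vocab)
end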
Training
At each training step:
- Sample mask ratio r from the cosine schedule (see sample_mask_ratio/0)
- Randomly mask r fraction of tokens with [MASK] id
- Forward pass → logits for all positions
- Compute cross-entropy only on masked positions
L = -1/|M| Σ_{i ∈ M} log p(y_i | context)
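In Nx terms, that loss could be written roughly as follows (a sketch, not necessarily mar_loss/3's exact implementation; it assumes integer targets and a 0/1 mask as documented below):

import Nx.Defn

# Cross-entropy averaged over masked positions only.
defn masked_ce(logits, targets, mask) do
  # Numerically stable log-softmax over the vocab axis
  m = Nx.reduce_max(logits, axes: [-1], keep_axes: true)
  lse = Nx.log(Nx.sum(Nx.exp(logits - m), axes: [-1], keep_axes: true)) + m
  log_probs = logits - lse                                  # [batch, seq, vocab]

  # Gather log p(y_i | context) at each position's target id
  nll =
    log_probs
    |> Nx.take_along_axis(Nx.new_axis(targets, -1), axis: -1)
    |> Nx.squeeze(axes: [2])
    |> Nx.negate()                                          # [batch, seq]

  # Mean over masked positions; the Nx.max guards an all-zero mask
  Nx.sum(nll * mask) / Nx.max(Nx.sum(mask), 1)
end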
Inference (Iterative Decoding)
Starts fully masked; unmasking is driven by model confidence (a code sketch follows the steps):
- Initialise: all tokens = [MASK]
- For step k = 1..K:
a. Forward pass — predict logits at all masked positions
b. Compute confidence = max softmax probability per masked token
c. Unmask the n_k tokens with highest confidence (n_k increases each step, so all tokens are revealed by step K)
- Return the fully unmasked sequence
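Outside defn, the loop could be sketched as below (illustrative assumptions: batch of one, greedy argmax instead of temperature sampling, and a simple linear unmasking schedule):

# decode_sketch/5 is a hypothetical helper, not part of this module.
def decode_sketch(model, params, seq_len, steps, mask_id) do
  tokens = Nx.broadcast(Nx.tensor(mask_id, type: :s64), {1, seq_len})

  Enum.reduce(1..steps, tokens, fn step, tokens ->
    logits = Axon.predict(model, params, tokens)            # [1, seq, vocab]
    pred = logits |> Nx.argmax(axis: -1) |> Nx.as_type(:s64)
    conf = Nx.reduce_max(logits, axes: [-1])                # max logit ranks like max prob
    masked? = Nx.equal(tokens, mask_id)

    # Already-revealed positions never compete again
    conf = Nx.select(masked?, conf, Nx.Constants.neg_infinity())

    # Linear schedule: after step k, k/steps of the sequence is revealed
    n_new = ceil(seq_len * step / steps) - (seq_len - Nx.to_number(Nx.sum(masked?)))

    if n_new <= 0 do
      tokens
    else
      {top, _} = Nx.top_k(conf, k: n_new)
      threshold = top[[0, n_new - 1]]
      reveal = Nx.logical_and(masked?, Nx.greater_equal(conf, threshold))
      Nx.select(reveal, pred, tokens)
    end
  end)
end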
Usage
model = MAR.build(
vocab_size: 8192,
embed_dim: 256,
num_layers: 6,
num_heads: 8,
seq_len: 256
)
# Training
loss = MAR.mar_loss(logits, targets, mask)
# Inference
tokens = MAR.iterative_decode(model, params, seq_len: 256, vocab_size: 8192)
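The params above come from compiling the model with Axon; for example (template shape assumed to match seq_len: 256):

{init_fn, _predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 256}, :s64), %{})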
References
- Li et al., "Autoregressive Image Generation without Vector Quantization" (2024)
- https://arxiv.org/abs/2406.11838
Summary
Functions
- build/1 - Build a MAR transformer model.
- iterative_decode/3 - Iterative masked decoding for MAR inference.
- mar_loss/3 - MAR training loss: cross-entropy over masked positions only.
- sample_mask_ratio/0 - Sample a random masking ratio from the cosine schedule used in MAR training.
Types
@type build_opt() ::
  {:vocab_size, pos_integer()}
  | {:embed_dim, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:dropout, float()}
Options for build/1.
Functions
Build a MAR transformer model.
Options
- :vocab_size - Vocabulary size, including the [MASK] token at id 0 (required)
- :embed_dim - Embedding / model hidden dimension (default: 256)
- :num_layers - Number of bidirectional encoder layers (default: 6)
- :num_heads - Number of attention heads (default: 8)
- :seq_len - Sequence length for positional encoding (default: 256)
- :dropout - Dropout rate applied after embedding and in each block (default: 0.1)
Returns
An Axon model taking token indices [batch, seq_len] (integer) and
returning logits [batch, seq_len, vocab_size].
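For example, with every option left at its default (the shapes shown follow from the defaults above):

model = MAR.build(vocab_size: 8192)
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 256}, :s64), %{})
logits = predict_fn.(params, Nx.broadcast(Nx.tensor(0, type: :s64), {1, 256}))
# Nx.shape(logits) => {1, 256, 8192}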
@spec iterative_decode(Axon.t(), term(), keyword()) :: Nx.Tensor.t()
Iterative masked decoding for MAR inference.
Starts with all tokens masked, then unmasks the most-confident tokens
first over num_steps iterations.
Parameters
- model - Axon model from build/1
- params - Initialised model state (from the Axon.build/2 init function)
Options
- :num_steps - Number of decoding iterations K (default: 8)
- :seq_len - Sequence length (required)
- :vocab_size - Vocabulary size (required)
- :mask_token_id - ID used for [MASK] tokens (default: 0)
- :temperature - Softmax temperature (default: 1.0)
Returns
Token indices tensor [1, seq_len].
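For example, a longer decode with mildly sharpened sampling (all options used are documented above):

tokens =
  MAR.iterative_decode(model, params,
    num_steps: 12,
    seq_len: 256,
    vocab_size: 8192,
    temperature: 0.9
  )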
@spec mar_loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
MAR training loss: cross-entropy over masked positions only.
Parameters
- logits - Model output [batch, seq_len, vocab_size]
- targets - Ground-truth token indices [batch, seq_len]
- mask - Binary mask [batch, seq_len]; 1 = masked (predict), 0 = unmasked
Returns
Scalar loss (mean CE over masked tokens).
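A hypothetical call site pairing the loss with a freshly sampled mask (key, mask_id, model, params, and targets are assumed to be in scope):

ratio = MAR.sample_mask_ratio()
{u, _key} = Nx.Random.uniform(key, shape: Nx.shape(targets))
mask = Nx.less(u, ratio)                   # 1 = hidden, must be predicted
inputs = Nx.select(mask, mask_id, targets) # overwrite hidden tokens with [MASK]
logits = Axon.predict(model, params, inputs)
loss = MAR.mar_loss(logits, targets, mask)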
@spec sample_mask_ratio() :: float()
Sample a random masking ratio from the cosine schedule used in MAR training.
Draws u ~ Uniform[0,1] and returns 1 - cos(π·u/2). The density of the
resulting ratio is highest near 0 and falls off toward 1 (mean ≈ 0.36),
so lightly masked batches are more common than heavily masked ones.
Returns
Float in [0.0, 1.0].
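The same draw written out in plain Elixir, for intuition:

u = :rand.uniform()                     # u ~ Uniform[0, 1]
r = 1 - :math.cos(:math.pi() * u / 2)   # in [0.0, 1.0]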