Edifice.Generative.MAR (Edifice v0.2.0)


MAR: Masked Autoregressive Generation.

MAR bridges autoregressive (AR) models and masked prediction models for iterative discrete token generation. Unlike pure AR models that generate tokens left-to-right, MAR predicts all masked positions in parallel and progressively unmasks tokens from most to least confident over K steps.

Architecture

```
Token Indices [batch, seq_len]
      |
      v
Token Embedding + Sinusoidal Positional Encoding
      |
      v
+--------------------------------------+
| Transformer Encoder (bidirectional)  |
|   N × (LayerNorm + MHA + FFN)        |
|   No causal mask: all positions see  |
|   each other, conditioned on the     |
|   unmasked tokens                    |
+--------------------------------------+
      |
      v
LayerNorm → Dense → Logits [batch, seq_len, vocab_size]
```
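The embedding stage adds a sinusoidal positional encoding to each token embedding. The following is a minimal pure-Elixir sketch. It assumes the standard Transformer formula PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)) with interleaved sin/cos channels; Edifice's exact channel layout is not documented here. The module name MARSketch.PosEnc is hypothetical.

```elixir
defmodule MARSketch.PosEnc do
  @moduledoc """
  Hypothetical helper (not Edifice's code): standard sinusoidal
  positional encoding, sin on even channels and cos on odd channels.
  """

  # Returns a list of `embed_dim` floats for position `pos`.
  def encode(pos, embed_dim) do
    for c <- 0..(embed_dim - 1) do
      # Channels 2i and 2i + 1 share the frequency 1 / 10000^(2i / d)
      i = div(c, 2)
      angle = pos / :math.pow(10_000, 2 * i / embed_dim)
      if rem(c, 2) == 0, do: :math.sin(angle), else: :math.cos(angle)
    end
  end

  # One encoding vector per position 0..seq_len-1.
  def table(seq_len, embed_dim) do
    for pos <- 0..(seq_len - 1), do: encode(pos, embed_dim)
  end
end
```

With the build/1 defaults, `MARSketch.PosEnc.table(256, 256)` yields one 256-element vector per position. Each vector is added element-wise to the token embedding at that position.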

Training

At each training step:

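  1. Sample a mask ratio r from the cosine schedule, r = 1 - cos(π·u/2) with u ~ Uniform[0, 1] (see sample_mask_ratio/0)
  2. Replace a random fraction r of the tokens with the [MASK] id
  3. Forward pass → logits for all positions
  4. Compute cross-entropy only on the set M of masked positions:

$$\mathcal{L} = -\frac{1}{|M|} \sum_{i \in M} \log p(y_i \mid \text{context})$$

where y_i is the ground-truth token at position i and the context is the sequence of unmasked tokens.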
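Training steps 1 and 2 can be sketched in plain Elixir for a single sequence. This is an illustration, not Edifice's implementation. The module MARSketch.Masking is hypothetical, and it assumes [MASK] has id 0 as stated for build/1. It also rounds r · seq_len to the nearest integer and always masks at least one token, so the loss in step 4 is never taken over an empty set.

```elixir
defmodule MARSketch.Masking do
  @moduledoc """
  Hypothetical sketch of MAR training steps 1-2: replace a random
  fraction `ratio` of the tokens with the [MASK] id (assumed to be 0).
  """

  @mask_id 0

  # Returns {masked_tokens, mask}, where mask is 1 at masked positions.
  def mask_tokens(tokens, ratio) when tokens != [] and ratio >= 0.0 and ratio <= 1.0 do
    len = length(tokens)
    # Sketch choice: mask at least one token so the masked loss is defined
    n_mask = max(1, round(ratio * len))

    # Pick n_mask distinct positions uniformly at random
    masked_idx =
      tokens
      |> Enum.with_index()
      |> Enum.take_random(n_mask)
      |> MapSet.new(fn {_tok, i} -> i end)

    mask =
      tokens
      |> Enum.with_index()
      |> Enum.map(fn {_tok, i} -> if(MapSet.member?(masked_idx, i), do: 1, else: 0) end)

    masked_tokens =
      Enum.zip_with(tokens, mask, fn tok, m -> if m == 1, do: @mask_id, else: tok end)

    {masked_tokens, mask}
  end
end
```

In a training loop the ratio would come from `MAR.sample_mask_ratio/0`, and the returned mask is the third argument to `MAR.mar_loss/3`.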

Inference (Iterative Decoding)

Starts fully masked; unmasking is driven by model confidence:

  1. Initialise: all tokens = [MASK]
  2. For step k = 1..K:
     a. Forward pass: predict logits at all masked positions
     b. Compute confidence = max softmax probability per masked token
     c. Unmask the n_k tokens with the highest confidence (n_k increases each step so that all tokens are revealed by step K)
  3. Return the fully unmasked sequence
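The loop above needs a concrete reveal schedule, and the description only requires that every token be revealed by step K. The sketch below assumes the cosine schedule popularised by MaskGIT, where a fraction cos(π/2 · k/K) of positions is still masked after step k. It takes per-position predictions as {token, confidence} pairs. The module MARSketch.Decode and its functions are hypothetical, not Edifice's API.

```elixir
defmodule MARSketch.Decode do
  @moduledoc """
  Hypothetical sketch of one confidence-driven unmasking step.
  The cosine reveal schedule is an assumption (MaskGIT-style).
  Assumes predicted tokens never equal the mask id.
  """

  @mask_id 0

  # Total number of positions that should be unmasked after step k of K.
  def revealed_after(k, num_steps, seq_len) do
    still_masked = floor(seq_len * :math.cos(:math.pi() / 2 * k / num_steps))
    seq_len - still_masked
  end

  # tokens: current sequence; preds: one {token, confidence} per position.
  # Fills the most confident masked positions so that the number of
  # revealed positions reaches revealed_after(k, num_steps, seq_len).
  def step(tokens, preds, k, num_steps) do
    seq_len = length(tokens)
    already = Enum.count(tokens, &(&1 != @mask_id))
    n_new = max(revealed_after(k, num_steps, seq_len) - already, 0)

    chosen =
      tokens
      |> Enum.zip(preds)
      |> Enum.with_index()
      |> Enum.filter(fn {{tok, _pred}, _i} -> tok == @mask_id end)
      |> Enum.sort_by(fn {{_tok, {_p, conf}}, _i} -> conf end, :desc)
      |> Enum.take(n_new)
      |> Map.new(fn {{_tok, {p, _conf}}, i} -> {i, p} end)

    # Keep already-revealed tokens; fill only the chosen masked positions
    for {tok, i} <- Enum.with_index(tokens), do: Map.get(chosen, i, tok)
  end
end
```

With the default K = 8 and seq_len = 256, this schedule reveals few tokens in the first steps and more in later ones, so the per-step count n_k grows with k as described above.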

Usage

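```elixir
# Build the model (vocab_size includes the [MASK] token at id 0)
model =
  MAR.build(
    vocab_size: 8192,
    embed_dim: 256,
    num_layers: 6,
    num_heads: 8,
    seq_len: 256
  )

# Training: logits come from a forward pass over the masked input;
# mask is 1 at masked (predicted) positions and 0 elsewhere
loss = MAR.mar_loss(logits, targets, mask)

# Inference: params is the model state produced by the init function
# returned by Axon.build/2
tokens = MAR.iterative_decode(model, params, seq_len: 256, vocab_size: 8192)
```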


Types

build_opt()

@type build_opt() ::
  {:vocab_size, pos_integer()}
  | {:embed_dim, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:dropout, float()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a MAR transformer model.

Options

  • :vocab_size - Vocabulary size, including the [MASK] token at id 0 (required)
  • :embed_dim - Embedding / model hidden dimension (default: 256)
  • :num_layers - Number of bidirectional encoder layers (default: 6)
  • :num_heads - Number of attention heads (default: 8)
  • :seq_len - Sequence length for positional encoding (default: 256)
  • :dropout - Dropout rate applied after embedding and in each block (default: 0.1)

Returns

An Axon model taking token indices [batch, seq_len] (integer) and returning logits [batch, seq_len, vocab_size].

iterative_decode(model, params, opts \\ [])

@spec iterative_decode(Axon.t(), term(), keyword()) :: Nx.Tensor.t()

Iterative masked decoding for MAR inference.

Starts with every position masked, then unmasks tokens over num_steps iterations, most confident first.

Parameters

  • model - Axon model from build/1
  • params - Initialised model state (produced by the init function returned by Axon.build/2)

Options

  • :num_steps - Number of decoding iterations K (default: 8)
  • :seq_len - Sequence length (required)
  • :vocab_size - Vocabulary size (required)
  • :mask_token_id - ID used for [MASK] tokens (default: 0)
  • :temperature - Softmax temperature (default: 1.0)

Returns

Token indices tensor [1, seq_len].
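A softmax temperature conventionally divides the logits before the softmax. Values above 1.0 flatten the distribution and lower the confidence scores, and values below 1.0 sharpen it. The sketch below computes the per-position prediction and confidence for one list of logits. It assumes greedy (argmax) token choice, although Edifice may sample instead. MARSketch.Confidence is a hypothetical name.

```elixir
defmodule MARSketch.Confidence do
  @moduledoc """
  Hypothetical sketch: temperature-scaled softmax over one position's
  logits, returning {argmax_token, max_probability}.
  """

  def predict(logits, temperature \\ 1.0) when temperature > 0 do
    scaled = Enum.map(logits, &(&1 / temperature))
    # Subtract the max before exponentiating for numerical stability
    m = Enum.max(scaled)
    exps = Enum.map(scaled, &:math.exp(&1 - m))
    total = Enum.sum(exps)

    # Greedy choice: the most probable token and its probability
    {best, token} = exps |> Enum.with_index() |> Enum.max_by(fn {e, _i} -> e end)
    {token, best / total}
  end
end
```

Raising the temperature from 1.0 to 2.0 on the same logits leaves the argmax token unchanged but lowers its confidence. This is why temperature also affects which positions get unmasked first.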

mar_loss(logits, targets, mask)

@spec mar_loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

MAR training loss: cross-entropy over masked positions only.

Parameters

  • logits - Model output [batch, seq_len, vocab_size]
  • targets - Ground-truth token indices [batch, seq_len]
  • mask - Binary mask [batch, seq_len]; 1 = masked (predict), 0 = unmasked

Returns

Scalar loss (mean CE over masked tokens).
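The semantics of mar_loss/3 can be illustrated on nested lists. This sketch assumes an integer 0/1 mask and averages over all masked positions in the batch, not per sequence. The guard that returns 0.0 for an empty mask is a choice made for the sketch, not documented behaviour. MARSketch.Loss is hypothetical.

```elixir
defmodule MARSketch.Loss do
  @moduledoc """
  Hypothetical pure-Elixir sketch of masked cross-entropy:
  mean of -log p(target) over positions where mask == 1.
  """

  # logits: [batch][seq_len][vocab]; targets, mask: [batch][seq_len]
  def masked_ce(logits, targets, mask) do
    terms =
      for {{seq_logits, seq_targets}, seq_mask} <- Enum.zip(Enum.zip(logits, targets), mask),
          # The pattern `1` skips unmasked positions (mask == 0)
          {{pos_logits, y}, 1} <- Enum.zip(Enum.zip(seq_logits, seq_targets), seq_mask) do
        -log_softmax_at(pos_logits, y)
      end

    # Sketch choice: no masked positions -> zero loss
    if terms == [], do: 0.0, else: Enum.sum(terms) / length(terms)
  end

  # log p(y) = logit_y - logsumexp(logits), computed stably
  defp log_softmax_at(logits, y) do
    m = Enum.max(logits)
    lse = m + :math.log(Enum.sum(Enum.map(logits, &:math.exp(&1 - m))))
    Enum.at(logits, y) - lse
  end
end
```

For uniform logits over a vocabulary of 4, each masked position contributes log 4. Unmasked positions do not affect the result, however confident or wrong the model is there.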

sample_mask_ratio()

@spec sample_mask_ratio() :: float()

Sample a random masking ratio from the cosine schedule used in MAR training.

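Draws u ~ Uniform[0, 1] and returns r = 1 - cos(π·u/2). Because dr/du = (π/2)·sin(π·u/2) is close to zero for small u, draws concentrate at low mask fractions rather than moderate ones: the mean is 1 - 2/π ≈ 0.36, and about two thirds of draws fall below 0.5.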

Returns

Float in [0.0, 1.0].
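A plain-Elixir version of the documented formula is useful for checking the shape of the distribution. MARSketch.Schedule is a hypothetical name, and the real function may draw its random number differently (for example through Nx).

```elixir
defmodule MARSketch.Schedule do
  @moduledoc """
  Hypothetical sketch of the documented formula r = 1 - cos(pi * u / 2).
  """

  def sample_mask_ratio do
    # :rand.uniform/0 returns a float in [0.0, 1.0)
    u = :rand.uniform()
    1.0 - :math.cos(:math.pi() * u / 2)
  end
end
```

Averaging many draws gives a value close to 1 - 2/π ≈ 0.36, and about two thirds of draws fall below 0.5, so the samples skew toward low mask fractions.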