# `Edifice.Generative.MAR`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/generative/mar.ex#L1)

MAR: Masked Autoregressive Generation.

MAR bridges autoregressive (AR) models and masked prediction models for
iterative discrete token generation. Unlike pure AR models that generate
tokens left-to-right, MAR predicts all masked positions in parallel and
progressively unmasks tokens from most to least confident over K steps.

## Architecture

```
Token Indices [batch, seq_len]
      |
      v
Token Embedding + Sinusoidal Positional Encoding
      |
      v
+--------------------------------------+
| Transformer Encoder (bidirectional)  |
|   N × (LayerNorm + MHA + FFN)        |
|   No causal mask — all positions see |
|   each other, conditioned on unmasked|
+--------------------------------------+
      |
      v
LayerNorm → Dense → Logits [batch, seq_len, vocab_size]
```
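
For reference, the sinusoidal positional encoding can be computed in Nx along these lines (a sketch of the standard concatenated sin/cos formulation; the module's exact variant may differ, e.g. interleaving instead of concatenating):

```elixir
# Standard sinusoidal positional encoding, shape [seq_len, dim].
# Uses the defaults from `build/1` (seq_len: 256, embed_dim: 256).
seq_len = 256
dim = 256

pos = Nx.iota({seq_len, 1})            # positions 0..seq_len-1
i = Nx.iota({1, div(dim, 2)})          # frequency index 0..dim/2-1

# angle[p, i] = p / 10000^(2i/dim)
angles = Nx.divide(pos, Nx.pow(10_000.0, Nx.divide(Nx.multiply(2, i), dim)))

# Concatenate the sin and cos halves along the feature axis.
pe = Nx.concatenate([Nx.sin(angles), Nx.cos(angles)], axis: 1)
```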

## Training

At each training step (a sketch in code follows the loss formula):
1. Sample a mask ratio `r` from the cosine schedule (see `sample_mask_ratio/0`)
2. Randomly mask r fraction of tokens with [MASK] id
3. Forward pass → logits for all positions
4. Compute cross-entropy only on masked positions

```
L = -1/|M| Σ_{i ∈ M} log p(y_i | context)
```
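
A minimal sketch of one training step wired together in Nx. Assumptions: `predict_fn` and `params` come from `Axon.build/2` (shown under `build/1` below), `tokens` is a ground-truth integer tensor `[batch, seq_len]`, and the surrounding training loop is not part of this module's documented API:

```elixir
key = Nx.Random.key(42)

# 1. Sample a mask ratio from the cosine schedule.
r = MAR.sample_mask_ratio()

# 2. Corrupt an `r` fraction of positions with the [MASK] id (0).
{u, _key} = Nx.Random.uniform(key, shape: Nx.shape(tokens))
mask = Nx.less(u, r)                        # 1 = masked, 0 = kept
inputs = Nx.select(mask, Nx.tensor(0), tokens)

# 3. Forward pass over the corrupted sequence.
logits = predict_fn.(params, inputs)        # [batch, seq_len, vocab_size]

# 4. Cross-entropy on masked positions only.
loss = MAR.mar_loss(logits, tokens, mask)
```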

## Inference (Iterative Decoding)

Starts fully masked; unmasking is driven by model confidence (one step is sketched after the list):

1. Initialise: all tokens = [MASK]
2. For step k = 1..K:
   a. Forward pass — predict logits at all masked positions
   b. Compute confidence = max softmax probability per masked token
   c. Unmask the `n_k` tokens with highest confidence
      (`n_k` increases each step so all are revealed by step K)
3. Return the fully unmasked sequence
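
The core of a single unmasking step might look like this in Nx (illustrative only; `iterative_decode/3` below packages the full loop, and `n_k` stands for this step's reveal budget):

```elixir
# Assumes: `tokens` is the current [1, seq_len] partially-masked
# sequence and `predict_fn`/`params` come from `Axon.build/2`.
n_k = 32                                        # reveal budget; grows each
                                                # step so all are out by K
masked? = Nx.equal(tokens, 0)                   # [MASK] id defaults to 0

logits = predict_fn.(params, tokens)
probs = Axon.Activations.softmax(logits, axis: -1)

# Greedy candidate and its confidence at every position.
candidates = Nx.argmax(probs, axis: -1)         # [1, seq_len]
confidence = Nx.reduce_max(probs, axes: [-1])   # [1, seq_len]

# Only still-masked positions compete for unmasking.
confidence = Nx.select(masked?, confidence, Nx.Constants.neg_infinity())

# Threshold = n_k-th highest confidence among masked positions.
{top_vals, _idx} = Nx.top_k(Nx.flatten(confidence), k: n_k)
threshold = Nx.reduce_min(top_vals)

reveal? = Nx.logical_and(masked?, Nx.greater_equal(confidence, threshold))
tokens = Nx.select(reveal?, candidates, tokens)
```

The `:temperatureature` option of `iterative_decode/3` would enter at the softmax; this sketch takes the greedy argmax for brevity.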

## Usage

    model = MAR.build(
      vocab_size: 8192,
      embed_dim: 256,
      num_layers: 6,
      num_heads: 8,
      seq_len: 256
    )

    # Training
    loss = MAR.mar_loss(logits, targets, mask)

    # Inference
    tokens = MAR.iterative_decode(model, params, seq_len: 256, vocab_size: 8192)

## References

- Li et al., "Autoregressive Image Generation without Vector Quantization" (2024). https://arxiv.org/abs/2406.11838

# `build_opt`

```elixir
@type build_opt() ::
  {:vocab_size, pos_integer()}
  | {:embed_dim, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:dropout, float()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a MAR transformer model.

## Options

  - `:vocab_size` - Vocabulary size, including the [MASK] token at id 0 (required)
  - `:embed_dim` - Embedding / model hidden dimension (default: 256)
  - `:num_layers` - Number of bidirectional encoder layers (default: 6)
  - `:num_heads` - Number of attention heads (default: 8)
  - `:seq_len` - Sequence length for positional encoding (default: 256)
  - `:dropout` - Dropout rate applied after embedding and in each block (default: 0.1)

## Returns

  An Axon model taking token indices `[batch, seq_len]` (integer) and
  returning logits `[batch, seq_len, vocab_size]`.
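
A typical way to initialise and run the model (standard `Axon.build/2` workflow; shapes follow the defaults above, and token ids are assumed to be 64-bit integers):

```elixir
model = MAR.build(vocab_size: 8192, seq_len: 256)

# Axon.build/1 returns an init function and a predict function.
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 256}, :s64), %{})

# A fully masked batch of one sequence ([MASK] id defaults to 0).
tokens = Nx.broadcast(Nx.tensor(0), {1, 256})
logits = predict_fn.(params, tokens)            # [1, 256, 8192]
```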

# `iterative_decode`

```elixir
@spec iterative_decode(Axon.t(), term(), keyword()) :: Nx.Tensor.t()
```

Iterative masked decoding for MAR inference.

Starts with all tokens masked, then unmasks the most confident tokens
first over `num_steps` iterations.

## Parameters

  - `model` - Axon model from `build/1`
  - `params` - Initialised model state (from `Axon.build/2` init function)

## Options

  - `:num_steps` - Number of decoding iterations K (default: 8)
  - `:seq_len` - Sequence length (required)
  - `:vocab_size` - Vocabulary size (required)
  - `:mask_token_id` - ID used for [MASK] tokens (default: 0)
  - `:temperature` - Softmax temperature (default: 1.0)

## Returns

  Token indices tensor `[1, seq_len]`.

# `mar_loss`

```elixir
@spec mar_loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

MAR training loss: cross-entropy over masked positions only.

## Parameters

  - `logits` - Model output `[batch, seq_len, vocab_size]`
  - `targets` - Ground-truth token indices `[batch, seq_len]`
  - `mask` - Binary mask `[batch, seq_len]`; 1 = masked (predict), 0 = unmasked

## Returns

  Scalar loss (mean CE over masked tokens).
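
One way to realise this in Nx, consistent with the loss formula above (a sketch; the module's implementation may differ in details such as type casts or numerical-stability tweaks):

```elixir
# Per-position negative log-likelihood of the target token.
log_probs = Axon.Activations.log_softmax(logits, axis: -1)

nll =
  log_probs
  |> Nx.take_along_axis(Nx.new_axis(targets, -1), axis: -1)
  |> Nx.squeeze(axes: [-1])
  |> Nx.negate()                                # [batch, seq_len]

# Average only over masked positions.
mask = Nx.as_type(mask, :f32)
loss = Nx.divide(Nx.sum(Nx.multiply(nll, mask)), Nx.sum(mask))
```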

# `sample_mask_ratio`

```elixir
@spec sample_mask_ratio() :: float()
```

Sample a random masking ratio from the cosine schedule used in MAR training.

Draws `u ~ Uniform[0,1]` and returns `1 - cos(π·u/2)`, which skews the
draw toward lower mask ratios (median ≈ 0.29, mean = `1 - 2/π` ≈ 0.36)
rather than covering `[0, 1]` uniformly.

## Returns

  Float in `[0.0, 1.0]`.
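
In plain Elixir the draw amounts to the following (illustrative; seeding of `:rand` is left to the caller):

```elixir
u = :rand.uniform()                      # u ~ Uniform[0, 1)
r = 1.0 - :math.cos(:math.pi() * u / 2)  # cosine-schedule mask ratio
```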

---

*Consult [api-reference.md](api-reference.md) for the complete listing*
