# `Edifice.Generative.Transfusion`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/generative/transfusion.ex#L1)

Transfusion: Unified Autoregressive Text + Diffusion Image Generation.

<!-- verified: true, date: 2026-02-23 -->

A single transformer model that jointly handles discrete text tokens
(autoregressive next-token prediction) and continuous image patches
(denoising diffusion) in one shared backbone.

## Key Innovation: Mixed Attention Mask

Text tokens and image patches share the same transformer layers, but
attend with different masks:

- **Text positions** (causal): each token sees only preceding tokens
- **Image positions** (bidirectional within image): each patch sees all
  other patches in the same image, plus all preceding text context

Combined rule:
```
mask[i, j] = 1  if  (j ≤ i)                      # causal: all preceding positions
                OR  (image[i] AND image[j])      # bidirectional within the image
```
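The combined rule can be sketched in plain Elixir with comprehensions. This is a toy mirror of the predicate only; the library's `build_mixed_mask/3` returns an `Nx.Tensor` rather than nested lists:

```elixir
# Toy version of the mixed-mask predicate for a 3-token text followed by
# a 2-patch image. Positions 0..2 are text, 3..4 are image.
text_len = 3
image_len = 2
total = text_len + image_len
image? = fn i -> i >= text_len end

mask =
  for i <- 0..(total - 1) do
    for j <- 0..(total - 1) do
      # allow if j precedes-or-equals i (causal), or both are image patches
      j <= i or (image?.(i) and image?.(j))
    end
  end

# Text row 1 sees only positions 0..1; image row 3 sees all text
# plus both image patches (including the later patch at position 4).
```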

## Architecture

```
Inputs: sequence [batch, seq_len, embed_dim]  (text embeddings + image patches)
        modality_mask [batch, seq_len]          (0=text, 1=image)
        timestep [batch]                        (diffusion step for image)
        |
        v
Modality type embedding  (learnable TEXT / IMAGE vectors added to tokens)
        |
        v
Input projection  →  hidden_size
        |
        v
+----------------------------------------------+
|  Transfusion Block  ×  num_layers             |
|                                               |
|  Add time_embed at image positions            |
|  LayerNorm  →  Mixed Attention  →  Residual   |
|  LayerNorm  →  FFN (GELU)       →  Residual   |
+----------------------------------------------+
        |
Final LayerNorm
        |
     ┌──┴──┐
text_head    image_head
[b,s,V]     [b,s,P]
```

## Dual Loss

- **Text tokens**: cross-entropy against next-token targets
- **Image patches**: MSE between predicted and target noise/velocity
- Total: `text_weight * L_CE + image_weight * L_MSE`
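The masked averaging behind this total can be illustrated with plain Elixir floats. This is a sketch of the arithmetic only, not the library's `transfusion_loss/4` (which operates on `Nx` tensors); the per-position loss values are made-up numbers:

```elixir
# Hypothetical per-position losses and masks for a 4-position sequence:
# positions 0..1 are text, 2..3 are image.
ce_per_pos  = [2.0, 1.0, 0.0, 0.0]
mse_per_pos = [0.0, 0.0, 0.5, 0.3]
text_mask   = [1.0, 1.0, 0.0, 0.0]
image_mask  = [0.0, 0.0, 1.0, 1.0]

# Each loss is averaged over only its own positions.
masked_mean = fn losses, mask ->
  num =
    losses
    |> Enum.zip(mask)
    |> Enum.map(fn {l, m} -> l * m end)
    |> Enum.sum()

  num / Enum.sum(mask)
end

text_weight = 1.0
image_weight = 1.0

l_ce  = masked_mean.(ce_per_pos, text_mask)    # mean over the 2 text positions
l_mse = masked_mean.(mse_per_pos, image_mask)  # mean over the 2 image positions
total = text_weight * l_ce + image_weight * l_mse
```

Because each term is normalized by its own mask sum, the balance between the two objectives depends only on the weights, not on how many positions each modality occupies.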

## Usage

    model = Transfusion.build(
      embed_dim: 64,
      hidden_size: 256,
      num_heads: 8,
      num_layers: 6,
      vocab_size: 32_000,
      patch_dim: 64
    )

    # Build the mixed attention mask for a 20-token text + 16-patch image
    mask = Transfusion.build_mixed_mask(20, 16)

    # Compute training loss
    loss = Transfusion.transfusion_loss(text_logits, image_pred, %{
      text_targets:  token_ids,      # [batch, seq_len] integer indices
      image_targets: noise_targets,  # [batch, seq_len, patch_dim]
      text_mask:     text_positions, # [batch, seq_len] float, 1 at text positions
      image_mask:    image_positions # [batch, seq_len] float, 1 at image positions
    })

## References

- Paper: "Transfusion: Predict the Next Token and Diffuse Images with One Multi-Modal Model"
- Authors: Chunting Zhou et al., Meta (2024)
- arXiv: https://arxiv.org/abs/2408.11039

# `build_opt`

```elixir
@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:patch_dim, pos_integer()}
  | {:vocab_size, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a Transfusion model for joint text + image generation.

## Options

  - `:embed_dim` - Input embedding dimension (required)
  - `:hidden_size` - Transformer hidden dimension (default: 256)
  - `:num_heads` - Number of attention heads (default: 8)
  - `:num_layers` - Number of transformer blocks (default: 6)
  - `:vocab_size` - Text vocabulary size for CE head (default: 32_000)
  - `:patch_dim` - Image patch feature dimension for diffusion head (default: 64)
  - `:dropout` - Dropout rate (default: 0.0)

## Returns

  `Axon.container(%{text_logits: [batch, seq, vocab_size], image_pred: [batch, seq, patch_dim]})`

# `build_mixed_mask`

```elixir
@spec build_mixed_mask(pos_integer(), pos_integer(), keyword()) :: Nx.Tensor.t()
```

Build the Transfusion mixed attention mask for a text+image sequence.

Produces a boolean matrix of shape `[text_len + image_len, text_len + image_len]`
where `true` means "allow attention":

- Text queries see all preceding positions (causal).
- Image queries see all other image patches (bidirectional) and all preceding text.

Combined rule: `mask[i, j] = (j ≤ i) OR (image[i] AND image[j])`

## Parameters

  - `text_len` - Number of text token positions
  - `image_len` - Number of image patch positions (appended after text)

## Options

Currently unused; reserved for future per-image-region masks.

## Returns

  Boolean `Nx.Tensor.t()` of shape `[text_len + image_len, text_len + image_len]`.
  `true` = allowed, `false` = masked.

# `output_size`

```elixir
@spec output_size(keyword()) :: pos_integer()
```

Get the output hidden size of a Transfusion model.

# `param_count`

```elixir
@spec param_count(keyword()) :: non_neg_integer()
```

Approximate parameter count for a Transfusion model.
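As a rough sanity check, standard transformer accounting gives a back-of-the-envelope estimate for the defaults above. This is an assumption about how the count breaks down; the actual `param_count/1` may also include biases, layer norms, modality embeddings, and input projections:

```elixir
# Rough parameter estimate under the documented defaults
# (hidden_size: 256, num_layers: 6, vocab_size: 32_000, patch_dim: 64).
h = 256
layers = 6
vocab = 32_000
patch = 64

attn_per_layer = 4 * h * h        # q, k, v, and output projections
ffn_per_layer  = 2 * h * (4 * h)  # up + down projections, assuming 4x expansion
per_layer = attn_per_layer + ffn_per_layer

backbone   = layers * per_layer   # 6 * 786_432
text_head  = h * vocab            # CE head:  256 * 32_000
image_head = h * patch            # MSE head: 256 * 64

approx = backbone + text_head + image_head
# ~12.9M parameters, dominated by the vocabulary projection
```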

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Recommended defaults for a small Transfusion model.

# `transfusion_loss`

```elixir
@spec transfusion_loss(Nx.Tensor.t(), Nx.Tensor.t(), map(), keyword()) ::
  Nx.Tensor.t()
```

Compute the combined Transfusion training loss.

Combines cross-entropy on text positions with MSE on image positions,
each masked and averaged over only the relevant positions.

## Parameters

  - `text_logits` - `[batch, seq_len, vocab_size]` raw logits from text head
  - `image_pred` - `[batch, seq_len, patch_dim]` predicted noise/velocity
  - `targets` - Map with:
      - `:text_targets` — `[batch, seq_len]` integer token IDs (next-token labels)
      - `:image_targets` — `[batch, seq_len, patch_dim]` target noise/velocity per patch (the diffusion regression target)
      - `:text_mask` — `[batch, seq_len]` float 1.0 at text positions, 0.0 elsewhere
      - `:image_mask` — `[batch, seq_len]` float 1.0 at image positions, 0.0 elsewhere

## Options

  - `:text_weight` - Weight for CE text loss (default: 1.0)
  - `:image_weight` - Weight for MSE image loss (default: 1.0)

## Returns

  Scalar loss tensor: `text_weight * L_CE + image_weight * L_MSE`.

---

*Consult [api-reference.md](api-reference.md) for complete listing*
