Edifice.Generative.Transfusion (Edifice v0.2.0)


Transfusion: Unified Autoregressive Text + Diffusion Image Generation.

A single transformer model that jointly handles discrete text tokens (autoregressive next-token prediction) and continuous image patches (denoising diffusion) in one shared backbone.

Key Innovation: Mixed Attention Mask

Text tokens and image patches share the same transformer layers, but attend with different masks:

  • Text positions (causal): each token sees only preceding tokens
  • Image positions (bidirectional within image): each patch sees all other patches in the same image, plus all preceding text context

Combined rule:

mask[i, j] = 1  if  j ≤ i                           # causal for text
             OR  (image[i] AND image[j])            # bidir within image
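The combined rule can be sketched in plain Elixir as a nested comprehension (illustrative only; the library's build_mixed_mask/3 returns an Nx tensor instead of nested lists):

```elixir
defmodule MixedMaskSketch do
  # text_len text positions followed by image_len image patch positions
  def build(text_len, image_len) do
    n = text_len + image_len
    image? = fn pos -> pos >= text_len end

    for i <- 0..(n - 1) do
      for j <- 0..(n - 1) do
        # causal for every query, plus full bidirectional access within the image
        j <= i or (image?.(i) and image?.(j))
      end
    end
  end
end
```

Note that image rows end up all-true for a single trailing image: they causally see all preceding text and bidirectionally see every patch.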

Architecture

Inputs: sequence [batch, seq_len, embed_dim]  (text embeddings + image patches)
        modality_mask [batch, seq_len]          (0=text, 1=image)
        timestep [batch]                        (diffusion step for image)
        |
        v
Modality type embedding  (learnable TEXT / IMAGE vectors added to tokens)
        |
        v
Input projection  →  hidden_size
        |
        v
+----------------------------------------------+
|  Transfusion Block  ×  num_layers             |
|                                               |
|  Add time_embed at image positions            |
|  LayerNorm → Mixed Attention → Residual      |
|  LayerNorm → FFN (GELU) → Residual           |
+----------------------------------------------+
        |
Final LayerNorm
        |
   +----+----+
   |         |
text_head  image_head
[b,s,V]    [b,s,P]

Dual Loss

  • Text tokens: cross-entropy against next-token targets
  • Image patches: MSE between predicted and target noise/velocity
  • Total: text_weight * L_CE + image_weight * L_MSE
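The masked averaging behind the two terms can be sketched in plain Elixir. The per-position CE and MSE values below are hypothetical inputs; the real transfusion_loss/4 operates on Nx tensors and computes them from logits and targets:

```elixir
defmodule DualLossSketch do
  # Combine per-position losses: each term is averaged over only its own
  # modality's positions, then the two means are weighted and summed.
  def combine(ce_per_pos, mse_per_pos, text_mask, image_mask, text_w \\ 1.0, image_w \\ 1.0) do
    text_w * masked_mean(ce_per_pos, text_mask) +
      image_w * masked_mean(mse_per_pos, image_mask)
  end

  # Mean of `values` restricted to positions where `mask` is 1.0
  defp masked_mean(values, mask) do
    total =
      values
      |> Enum.zip(mask)
      |> Enum.map(fn {v, m} -> v * m end)
      |> Enum.sum()

    total / Enum.sum(mask)
  end
end
```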

Usage

model = Transfusion.build(
  embed_dim: 64,
  hidden_size: 256,
  num_heads: 8,
  num_layers: 6,
  vocab_size: 32_000,
  patch_dim: 64
)

# Build the mixed attention mask for a 20-token text + 16-patch image
mask = Transfusion.build_mixed_mask(20, 16)

# Compute training loss
loss = Transfusion.transfusion_loss(text_logits, image_pred, %{
  text_targets:  token_ids,      # [batch, seq_len] integer indices
  image_targets: noise_targets,  # [batch, seq_len, patch_dim]
  text_mask:     text_positions, # [batch, seq_len] float, 1 at text positions
  image_mask:    image_positions # [batch, seq_len] float, 1 at image positions
})

References

  • Paper: "Transfusion: Predict the Next Token and Diffuse Images with One Multi-Modal Model"
  • Authors: Chunting Zhou et al., Meta (2024)
  • arXiv: https://arxiv.org/abs/2408.11039

Summary

Types

Options for build/1.

Functions

Build a Transfusion model for joint text + image generation.

Build the Transfusion mixed attention mask for a text+image sequence.

Get the output hidden size of a Transfusion model.

Approximate parameter count for a Transfusion model.

Recommended defaults for a small Transfusion model.

Compute the combined Transfusion training loss.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:patch_dim, pos_integer()}
  | {:vocab_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Transfusion model for joint text + image generation.

Options

  • :embed_dim - Input embedding dimension (required)
  • :hidden_size - Transformer hidden dimension (default: 256)
  • :num_heads - Number of attention heads (default: 8)
  • :num_layers - Number of transformer blocks (default: 6)
  • :vocab_size - Text vocabulary size for CE head (default: 32_000)
  • :patch_dim - Image patch feature dimension for diffusion head (default: 64)
  • :dropout - Dropout rate (default: 0.0)

Returns

Axon.container(%{text_logits: [batch, seq, vocab_size], image_pred: [batch, seq, patch_dim]})

build_mixed_mask(text_len, image_len, opts \\ [])

@spec build_mixed_mask(pos_integer(), pos_integer(), keyword()) :: Nx.Tensor.t()

Build the Transfusion mixed attention mask for a text+image sequence.

Produces a boolean matrix of shape [text_len + image_len, text_len + image_len] where true means "allow attention":

  • Text queries see all preceding positions (causal).
  • Image queries see all other image patches (bidirectional) and all preceding text.

Combined rule: mask[i, j] = (j ≤ i) OR (image[i] AND image[j])

Parameters

  • text_len - Number of text token positions
  • image_len - Number of image patch positions (appended after text)

Options

Currently unused; reserved for future per-image-region masks.

Returns

Boolean Nx.Tensor.t() of shape [text_len + image_len, text_len + image_len]. true = allowed, false = masked.

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output hidden size of a Transfusion model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Approximate parameter count for a Transfusion model.
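As a rough sanity check, a generic transformer estimate can be computed by hand. This is a back-of-envelope formula (attention projections plus a 4x FFN, ignoring LayerNorm, biases, and time embeddings), not necessarily the exact formula param_count/1 uses:

```elixir
defmodule ParamSketch do
  # Rough transformer parameter estimate for the Transfusion configuration.
  def estimate(hidden, layers, vocab, patch_dim, embed_dim) do
    # per block: 4 attention projections (q/k/v/o) + FFN up/down at 4x width
    per_layer = 4 * hidden * hidden + 8 * hidden * hidden

    per_layer * layers +
      embed_dim * hidden +      # input projection
      hidden * vocab +          # text head
      hidden * patch_dim        # image head
  end
end
```

For the small defaults (hidden_size: 256, num_layers: 6, vocab_size: 32_000, patch_dim: 64, embed_dim: 64) this lands around 13M parameters, dominated by the text head and the blocks.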

transfusion_loss(text_logits, image_pred, targets, opts \\ [])

@spec transfusion_loss(Nx.Tensor.t(), Nx.Tensor.t(), map(), keyword()) ::
  Nx.Tensor.t()

Compute the combined Transfusion training loss.

Combines cross-entropy on text positions with MSE on image positions, each masked and averaged over only the relevant positions.

Parameters

  • text_logits - [batch, seq_len, vocab_size] raw logits from text head
  • image_pred - [batch, seq_len, patch_dim] predicted noise/velocity
  • targets - Map with:
    • :text_targets - [batch, seq_len] integer token IDs (next-token labels)
    • :image_targets - [batch, seq_len, patch_dim] target noise/velocity values
    • :text_mask - [batch, seq_len] float, 1.0 at text positions, 0.0 elsewhere
    • :image_mask - [batch, seq_len] float, 1.0 at image positions, 0.0 elsewhere
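Since text and image positions are disjoint, the two masks are typically complementary and can be derived from a single modality mask (0 = text, 1 = image). A plain-list sketch; Nx code would use tensor ops instead:

```elixir
# Hypothetical modality layout: three text tokens followed by two image patches
modality = [0, 0, 0, 1, 1]

# text_mask is the complement of modality; image_mask is modality as floats
text_mask  = Enum.map(modality, fn m -> 1.0 - m end)
image_mask = Enum.map(modality, fn m -> 1.0 * m end)
```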

Options

  • :text_weight - Weight for CE text loss (default: 1.0)
  • :image_weight - Weight for MSE image loss (default: 1.0)

Returns

Scalar loss tensor: text_weight * L_CE + image_weight * L_MSE.