Transfusion: Unified Autoregressive Text + Diffusion Image Generation.
A single transformer model that jointly handles discrete text tokens (autoregressive next-token prediction) and continuous image patches (denoising diffusion) in one shared backbone.
Key Innovation: Mixed Attention Mask
Text tokens and image patches share the same transformer layers, but attend with different masks:
- Text positions (causal): each token sees only preceding tokens
- Image positions (bidirectional within image): each patch sees all other patches in the same image, plus all preceding text context
Combined rule:
mask[i, j] = (j ≤ i)                    # causal: attend to preceding positions
             OR (image[i] AND image[j]) # bidirectional within the image
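As an illustration, the rule can be built directly with Nx. This is a standalone sketch for the 20-token + 16-patch layout used later in Usage, not the library's internals; build_mixed_mask/3 below is the actual entry point.

# Standalone sketch of the combined rule for 20 text tokens + 16 patches.
text_len = 20
image_len = 16
n = text_len + image_len

idx = Nx.iota({n})
j = Nx.new_axis(idx, 0)                     # [1, n] key/column index
i = Nx.new_axis(idx, 1)                     # [n, 1] query/row index
causal = Nx.less_equal(j, i)                # j <= i
is_image = Nx.greater_equal(idx, text_len)  # positions after the text prefix
bidir = Nx.logical_and(Nx.new_axis(is_image, 1), Nx.new_axis(is_image, 0))

mask = Nx.logical_or(causal, bidir)         # [n, n], 1 = attention allowed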
Architecture

Inputs: sequence       [batch, seq_len, embed_dim]  (text embeddings + image patches)
        modality_mask  [batch, seq_len]             (0 = text, 1 = image)
        timestep       [batch]                      (diffusion step for image patches)
           |
           v
  Modality type embedding (learnable TEXT / IMAGE vectors added to tokens)
           |
           v
  Input projection → hidden_size
           |
           v
  +----------------------------------------------+
  |        Transfusion Block × num_layers        |
  |                                              |
  |  Add time_embed at image positions           |
  |  LayerNorm → Mixed Attention → Residual      |
  |  LayerNorm → FFN (GELU) → Residual           |
  +----------------------------------------------+
           |
     Final LayerNorm
           |
        ┌──┴──┐
   text_head  image_head
   [b,s,V]    [b,s,P]
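The one modality-specific step inside each block, "Add time_embed at image positions", can be sketched in Nx as follows. This is illustrative only; time_embed is assumed to be a precomputed [batch, hidden_size] embedding of timestep, and hidden is the block's current activations.

# Add the diffusion time embedding only where modality_mask == 1 (image),
# leaving text positions unchanged.
img = Nx.new_axis(modality_mask, -1)          # [batch, seq_len, 1]
t   = Nx.new_axis(time_embed, 1)              # [batch, 1, hidden_size]
hidden = Nx.add(hidden, Nx.multiply(img, t))  # added only where mask == 1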
Dual Loss

- Text tokens: cross-entropy against next-token targets
- Image patches: MSE between predicted and target noise/velocity
- Total: text_weight * L_CE + image_weight * L_MSE
Usage
model = Transfusion.build(
  embed_dim: 64,
  hidden_size: 256,
  num_heads: 8,
  num_layers: 6,
  vocab_size: 32_000,
  patch_dim: 64
)
# Build the mixed attention mask for a 20-token text + 16-patch image
mask = Transfusion.build_mixed_mask(20, 16)
# Compute training loss
loss = Transfusion.transfusion_loss(text_logits, image_pred, %{
  text_targets: token_ids,       # [batch, seq_len] integer indices
  image_targets: noise_targets,  # [batch, seq_len, patch_dim]
  text_mask: text_positions,     # [batch, seq_len] float, 1 at text positions
  image_mask: image_positions    # [batch, seq_len] float, 1 at image positions
})
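After building, the model runs like any other Axon model. The sketch below shows a forward pass that produces the text_logits and image_pred used above; the input names mirror the Inputs list in the Architecture section but are assumptions about the model's input layers, and seq, mod_mask, and t are hypothetical tensors.

# Hypothetical input map; names mirror the Inputs list above.
inputs = %{
  "sequence" => seq,            # [batch, seq_len, embed_dim]
  "modality_mask" => mod_mask,  # [batch, seq_len], 0 = text, 1 = image
  "timestep" => t               # [batch]
}

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(inputs, %{})
%{text_logits: text_logits, image_pred: image_pred} = predict_fn.(params, inputs)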
References

- Paper: "Transfusion: Predict the Next Token and Diffuse Images with One Multi-Modal Model"
- Authors: Chunting Zhou et al., Meta (2024)
- arXiv: https://arxiv.org/abs/2408.11039
Summary
Functions
- build/1 - Build a Transfusion model for joint text + image generation.
- build_mixed_mask/3 - Build the Transfusion mixed attention mask for a text+image sequence.
- output_size/1 - Get the output hidden size of a Transfusion model.
- param_count/1 - Approximate parameter count for a Transfusion model.
- recommended_defaults/0 - Recommended defaults for a small Transfusion model.
- transfusion_loss/4 - Compute the combined Transfusion training loss.
Types
@type build_opt() ::
        {:dropout, float()}
        | {:embed_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:patch_dim, pos_integer()}
        | {:vocab_size, pos_integer()}
Options for build/1.
Functions
Build a Transfusion model for joint text + image generation.
Options
- :embed_dim - Input embedding dimension (required)
- :hidden_size - Transformer hidden dimension (default: 256)
- :num_heads - Number of attention heads (default: 8)
- :num_layers - Number of transformer blocks (default: 6)
- :vocab_size - Text vocabulary size for CE head (default: 32_000)
- :patch_dim - Image patch feature dimension for diffusion head (default: 64)
- :dropout - Dropout rate (default: 0.0)
Returns
Axon.container(%{text_logits: [batch, seq, vocab_size], image_pred: [batch, seq, patch_dim]})
@spec build_mixed_mask(pos_integer(), pos_integer(), keyword()) :: Nx.Tensor.t()
Build the Transfusion mixed attention mask for a text+image sequence.
Produces a boolean matrix of shape [text_len + image_len, text_len + image_len]
where true means "allow attention":
- Text queries see all preceding positions (causal).
- Image queries see all other image patches (bidirectional) and all preceding text.
Combined rule: mask[i, j] = (j ≤ i) OR (image[i] AND image[j])
Parameters
- text_len - Number of text token positions
- image_len - Number of image patch positions (appended after text)
Options
Currently unused; reserved for future per-image-region masks.
Returns
Boolean Nx.Tensor.t() of shape [text_len + image_len, text_len + image_len].
true = allowed, false = masked.
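For instance, with two text tokens followed by two image patches, the combined rule yields the pattern below. This is a sketch of the expected output derived from the rule; the exact tensor type and formatting may differ.

iex> Transfusion.build_mixed_mask(2, 2)
#Nx.Tensor<
  u8[4][4]
  [
    [1, 0, 0, 0],
    [1, 1, 0, 0],
    [1, 1, 1, 1],
    [1, 1, 1, 1]
  ]
>

Rows 0-1 (text) are strictly causal; rows 2-3 (image) attend to all preceding text and to each other.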
@spec output_size(keyword()) :: pos_integer()
Get the output hidden size of a Transfusion model.
@spec param_count(keyword()) :: non_neg_integer()
Approximate parameter count for a Transfusion model.
@spec recommended_defaults() :: keyword()
Recommended defaults for a small Transfusion model.
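A plausible way to combine this with build/1, since :embed_dim is required. Whether the defaults already include :embed_dim is an assumption; Keyword.put/3 makes it explicit either way.

# Merge the small-model defaults with an explicit embedding dimension.
opts = Keyword.put(Transfusion.recommended_defaults(), :embed_dim, 64)
model = Transfusion.build(opts)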
@spec transfusion_loss(Nx.Tensor.t(), Nx.Tensor.t(), map(), keyword()) :: Nx.Tensor.t()
Compute the combined Transfusion training loss.
Combines cross-entropy on text positions with MSE on image positions, each masked and averaged over only the relevant positions.
Parameters
- text_logits - [batch, seq_len, vocab_size] raw logits from the text head
- image_pred - [batch, seq_len, patch_dim] predicted noise/velocity
- targets - Map with:
  - :text_targets - [batch, seq_len] integer token IDs (next-token labels)
  - :image_targets - [batch, seq_len, patch_dim] target noise/velocity for the patches
  - :text_mask - [batch, seq_len] float, 1.0 at text positions, 0.0 elsewhere
  - :image_mask - [batch, seq_len] float, 1.0 at image positions, 0.0 elsewhere
Options
- :text_weight - Weight for the CE text loss (default: 1.0)
- :image_weight - Weight for the MSE image loss (default: 1.0)
Returns
Scalar loss tensor: text_weight * L_CE + image_weight * L_MSE.
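For reference, here is a minimal Nx sketch of the masked averaging described above. It is an assumed-equivalent computation, not the function's source: it takes the feature-wise mean for the MSE first, then averages each term over only its own positions; text_weight and image_weight are the option values.

# Stable log-softmax over the vocab axis.
m = Nx.reduce_max(text_logits, axes: [-1], keep_axes: true)
lse = Nx.add(m, Nx.log(Nx.sum(Nx.exp(Nx.subtract(text_logits, m)), axes: [-1], keep_axes: true)))
log_probs = Nx.subtract(text_logits, lse)                       # [b, s, V]

# Per-position CE, zeroed outside text positions, averaged over text only.
picked = Nx.take_along_axis(log_probs, Nx.new_axis(text_targets, -1), axis: 2)
ce = Nx.negate(Nx.squeeze(picked, axes: [-1]))                  # [b, s]
l_ce = Nx.divide(Nx.sum(Nx.multiply(ce, text_mask)), Nx.max(Nx.sum(text_mask), 1.0e-8))

# Per-position MSE over patch features, averaged over image positions only.
se = Nx.mean(Nx.pow(Nx.subtract(image_pred, image_targets), 2), axes: [-1])
l_mse = Nx.divide(Nx.sum(Nx.multiply(se, image_mask)), Nx.max(Nx.sum(image_mask), 1.0e-8))

loss = Nx.add(Nx.multiply(text_weight, l_ce), Nx.multiply(image_weight, l_mse))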