MMDiT: Multimodal Diffusion Transformer.
Implements the MMDiT architecture from "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis" (Esser et al., 2024), used in Stable Diffusion 3 and FLUX.1. Replaces DiT's single-stream design with dual parallel streams (one per modality) connected via joint self-attention.
Key Innovation: Joint Attention (Not Cross-Attention)
Instead of cross-attention between modalities, MMDiT concatenates Q, K, V from both streams along the sequence dimension, runs a single standard self-attention, then splits the output back. This allows all four interaction directions (I2I, I2T, T2I, T2T) in every layer.
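Concretely, this concatenate-attend-split pattern can be sketched in NumPy (single head, no learned projections; shapes and names here are illustrative, not the library's API):

```python
# Minimal NumPy sketch of MMDiT joint attention over two streams.
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def joint_attention(txt_qkv, img_qkv):
    """txt_qkv/img_qkv: tuples of (q, k, v), each [batch, seq, d]."""
    txt_len = txt_qkv[0].shape[1]
    # Concatenate both modalities along the sequence dimension.
    q = np.concatenate([txt_qkv[0], img_qkv[0]], axis=1)
    k = np.concatenate([txt_qkv[1], img_qkv[1]], axis=1)
    v = np.concatenate([txt_qkv[2], img_qkv[2]], axis=1)
    d = q.shape[-1]
    # One standard self-attention over the joint sequence: every token
    # attends to every other token (T2T, T2I, I2T, I2I in a single pass).
    out = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d)) @ v
    # Split the output back into the two streams.
    return out[:, :txt_len], out[:, txt_len:]

rng = np.random.default_rng(0)
txt = tuple(rng.normal(size=(2, 32, 64)) for _ in range(3))
img = tuple(rng.normal(size=(2, 256, 64)) for _ in range(3))
txt_out, img_out = joint_attention(txt, img)
print(txt_out.shape, img_out.shape)  # (2, 32, 64) (2, 256, 64)
```

Each stream keeps its own QKV projections, so the modalities retain separate weights even though attention itself is shared.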
Image stream: img_q, img_k, img_v = img_attn(AdaLN(img))
Text stream: txt_q, txt_k, txt_v = txt_attn(AdaLN(txt))
Combined: q = cat([txt_q, img_q], dim=seq)
k = cat([txt_k, img_k], dim=seq)
v = cat([txt_v, img_v], dim=seq)
Joint attn: out = softmax(q @ k^T / sqrt(d)) @ v
Split back: txt_out = out[:, :txt_len]
img_out = out[:, txt_len:]

Architecture
Image Latent [batch, img_tokens, img_dim]   Text Embed [batch, txt_tokens, txt_dim]
                    |                                         |
                    v                                         v
            [Image Projection]                        [Text Projection]
                    |                                         |
                    v                                         v
+-------------------------------------------------------------------+
| Double-Stream Block x depth |
| |
| img_stream: AdaLN(vec) -> QKV_img ---+ |
| txt_stream: AdaLN(vec) -> QKV_txt ---+-> Joint Attention |
| | |
| img_stream: gate * proj(img_attn) + residual -> AdaLN -> MLP |
| txt_stream: gate * proj(txt_attn) + residual -> AdaLN -> MLP |
+-------------------------------------------------------------------+
|
v
[Final Norm + Linear] -> output [batch, img_tokens, img_dim]

Conditioning (AdaLN-Zero)
Each stream has separate modulation weights. The conditioning vector is:
vec = timestep_mlp(sinusoidal(t)) + pooled_text_mlp(pooled_text)
Each modulation produces 6 parameters per stream: (shift_attn, scale_attn, gate_attn, shift_mlp, scale_mlp, gate_mlp)
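The zero-initialized modulation can be sketched in NumPy (weight names and shapes are illustrative stand-ins, not the library's API):

```python
# Sketch of AdaLN-Zero modulation for one stream.
import numpy as np

def layer_norm(x, eps=1e-6):
    # Parameter-free LayerNorm: scale and shift come from the conditioning vector.
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def modulate(vec, w, b):
    # Project conditioning vec [batch, cond_dim] into six per-stream parameters,
    # each broadcast over the sequence dimension as [batch, 1, hidden].
    return np.split((vec @ w + b)[:, None, :], 6, axis=-1)

hidden, cond_dim = 64, 64
rng = np.random.default_rng(0)
vec = rng.normal(size=(2, cond_dim))

# AdaLN-*Zero*: the modulation projection starts at zero, so every gate is 0
# and each block is initialized as the identity mapping.
w, b = np.zeros((cond_dim, 6 * hidden)), np.zeros(6 * hidden)
shift_a, scale_a, gate_a, shift_m, scale_m, gate_m = modulate(vec, w, b)

x0 = rng.normal(size=(2, 16, hidden))
attn_in = (1 + scale_a) * layer_norm(x0) + shift_a  # pre-attention modulation
attn_out = attn_in                                  # stand-in for joint attention
x = x0 + gate_a * attn_out                          # zero gate: residual passes through
print(np.allclose(x, x0))  # True
```

Because both gates start at zero, the network begins training as an identity function and gradually learns how much each block contributes.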
Usage
model = MMDiT.build(
  img_dim: 16,        # VAE latent channels (flattened patch dim)
  txt_dim: 4096,      # Text encoder hidden dim (e.g., T5-XXL)
  hidden_size: 1536,  # Joint hidden dimension
  depth: 24,          # Number of double-stream blocks
  num_heads: 24,
  img_tokens: 256,    # Number of image patches
  txt_tokens: 77      # Max text sequence length
)

References

Esser et al. (2024), "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis"
Summary
Functions
build/1 - Build an MMDiT model for multimodal diffusion.
output_size/1 - Get the output size of an MMDiT model.
recommended_defaults/0 - Get recommended defaults for MMDiT.
Types
@type build_opt() ::
        {:cond_dim, pos_integer()}
        | {:depth, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:img_dim, pos_integer()}
        | {:img_tokens, pos_integer()}
        | {:mlp_ratio, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:txt_dim, pos_integer()}
        | {:txt_tokens, pos_integer()}
Options for build/1.
Functions
Build an MMDiT model for multimodal diffusion.
Options
Modality dimensions (at least one pair required):
:img_dim - Image patch feature dimension (required)
:txt_dim - Text token feature dimension (required)
:img_tokens - Number of image tokens/patches (default: 64)
:txt_tokens - Max text tokens (default: 32)
Architecture:
:hidden_size - Joint hidden dimension (default: 768)
:depth - Number of double-stream blocks (default: 12)
:num_heads - Attention heads (default: 12)
:mlp_ratio - MLP expansion ratio (default: 4)
:cond_dim - Conditioning vector dimension (default: hidden_size)
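These options interact: as in any multi-head attention, :hidden_size must split evenly across :num_heads. A small plain-Python sketch (not the Elixir API) checking the defaults above and the SD3-scale values from the Usage section:

```python
# Illustrative sanity check on build options; values mirror the docs above.
defaults = dict(hidden_size=768, num_heads=12, mlp_ratio=4)

def check(opts):
    hidden, heads = opts["hidden_size"], opts["num_heads"]
    assert hidden % heads == 0, "hidden_size must be divisible by num_heads"
    return {"head_dim": hidden // heads, "mlp_width": opts["mlp_ratio"] * hidden}

print(check(defaults))                                           # {'head_dim': 64, 'mlp_width': 3072}
print(check(dict(hidden_size=1536, num_heads=24, mlp_ratio=4)))  # {'head_dim': 64, 'mlp_width': 6144}
```

Note that both the defaults and the larger usage-example configuration keep a per-head dimension of 64 while scaling width and depth.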
Returns
An Axon model: (img_latent, txt_embed, timestep, pooled_text) -> denoised_img
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of an MMDiT model.
@spec recommended_defaults() :: keyword()
Get recommended defaults for MMDiT.