Edifice.Generative.MMDiT (Edifice v0.2.0)


MMDiT: Multimodal Diffusion Transformer.

Implements the MMDiT architecture from "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis" (Esser et al., 2024), used in Stable Diffusion 3 and FLUX.1. Replaces DiT's single-stream design with dual parallel streams (one per modality) connected via joint self-attention.

Key Innovation: Joint Attention (Not Cross-Attention)

Instead of cross-attention between modalities, MMDiT concatenates the queries, keys, and values from both streams along the sequence dimension, runs a single standard self-attention over the combined sequence, and splits the output back into the two streams. This gives every layer all four interaction directions: image-to-image (I2I), image-to-text (I2T), text-to-image (T2I), and text-to-text (T2T).

Image stream:  img_q, img_k, img_v = img_attn(AdaLN(img))
Text stream:   txt_q, txt_k, txt_v = txt_attn(AdaLN(txt))

Combined:      q = cat([txt_q, img_q], dim=seq)
               k = cat([txt_k, img_k], dim=seq)
               v = cat([txt_v, img_v], dim=seq)

Joint attn:    out = softmax(q @ k^T / sqrt(d)) @ v

Split back:    txt_out = out[:, :txt_len]
               img_out = out[:, txt_len:]
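The concatenate-attend-split pattern above can be sketched directly in NumPy. This is an illustration of the math only, not Edifice's implementation; the single-head shapes and function names here are assumptions for the sketch.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def joint_attention(txt_qkv, img_qkv):
    """Single-head joint attention: concat along sequence, attend, split back."""
    (tq, tk, tv), (iq, ik, iv) = txt_qkv, img_qkv
    q = np.concatenate([tq, iq], axis=1)  # [batch, txt_len + img_len, d]
    k = np.concatenate([tk, ik], axis=1)
    v = np.concatenate([tv, iv], axis=1)
    d = q.shape[-1]
    # One standard self-attention over the combined sequence covers all
    # four directions (T2T, T2I, I2T, I2I) at once.
    out = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d)) @ v
    txt_len = tq.shape[1]
    return out[:, :txt_len], out[:, txt_len:]  # txt_out, img_out

rng = np.random.default_rng(0)
txt = tuple(rng.standard_normal((2, 4, 8)) for _ in range(3))   # txt_len = 4
img = tuple(rng.standard_normal((2, 16, 8)) for _ in range(3))  # img_len = 16
txt_out, img_out = joint_attention(txt, img)
print(txt_out.shape, img_out.shape)  # (2, 4, 8) (2, 16, 8)
```

Each split output has exactly its stream's original sequence length, so the two streams can continue independently after the shared attention step.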

Architecture

Image Latent [batch, img_tokens, img_dim]     Text Embed [batch, txt_tokens, txt_dim]
      |                                              |
      v                                              v
[Image Projection]                            [Text Projection]
      |                                              |
      v                                              v
+------------------------------------------------------------------+
|  Double-Stream Block x depth                                     |
|                                                                  |
|  img_stream:  AdaLN(vec) -> QKV_img ---+                         |
|  txt_stream:  AdaLN(vec) -> QKV_txt ---+-> Joint Attention       |
|                                                                  |
|  img_stream:  gate * proj(img_attn) + residual -> AdaLN -> MLP   |
|  txt_stream:  gate * proj(txt_attn) + residual -> AdaLN -> MLP   |
+------------------------------------------------------------------+
      |
      v
[Final Norm + Linear] -> output [batch, img_tokens, img_dim]

Conditioning (AdaLN-Zero)

Each stream has separate modulation weights. The conditioning vector is: vec = timestep_mlp(sinusoidal(t)) + pooled_text_mlp(pooled_text)
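The conditioning vector above can be sketched in NumPy. The sinusoidal embedding follows the standard DiT/Transformer recipe; the two MLPs are stood in for by plain linear maps with placeholder weights, and the dimensions (256 for the timestep embedding, 768 for the pooled text) are assumptions for the sketch.

```python
import numpy as np

def sinusoidal_embedding(t, dim, max_period=10_000.0):
    """Standard sinusoidal timestep embedding over geometrically spaced frequencies."""
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)  # [batch, dim]

# vec = timestep_mlp(sinusoidal(t)) + pooled_text_mlp(pooled_text)
# Both projections target the same cond_dim so the element-wise sum is valid.
batch, cond_dim, pooled_dim, t_dim = 2, 1536, 768, 256
rng = np.random.default_rng(0)
t_emb = sinusoidal_embedding(np.array([0.1, 0.9]), t_dim)     # [2, 256]
W_t = rng.standard_normal((t_dim, cond_dim)) * 0.02           # placeholder weights
W_p = rng.standard_normal((pooled_dim, cond_dim)) * 0.02      # placeholder weights
pooled_text = rng.standard_normal((batch, pooled_dim))
vec = t_emb @ W_t + pooled_text @ W_p
print(vec.shape)  # (2, 1536)
```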

Each block's modulation produces six parameters per stream: (shift_attn, scale_attn, gate_attn, shift_mlp, scale_mlp, gate_mlp). The gates are initialized to zero, so every block starts out as an identity map (the "Zero" in AdaLN-Zero).
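Where those six parameters act can be shown with a NumPy sketch of one stream's half of a double-stream block: shift/scale modulate the normalized input of each sub-layer, and the gate scales the sub-layer's output before the residual add. This is a single-stream illustration with identity stand-ins for attention and MLP, not Edifice's implementation.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """Parameter-free LayerNorm; the affine part comes from the modulation instead."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def adaln_zero_block(x, mods, attn_fn, mlp_fn):
    """One stream of a double-stream block, showing where the six modulation
    parameters (each a [batch, dim] vector derived from `vec`) are applied."""
    shift_a, scale_a, gate_a, shift_m, scale_m, gate_m = mods
    h = layer_norm(x) * (1 + scale_a[:, None]) + shift_a[:, None]  # modulate attn input
    x = x + gate_a[:, None] * attn_fn(h)                           # gated residual
    h = layer_norm(x) * (1 + scale_m[:, None]) + shift_m[:, None]  # modulate MLP input
    x = x + gate_m[:, None] * mlp_fn(h)                            # gated residual
    return x

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 16))
mods = [rng.standard_normal((2, 16)) for _ in range(6)]
mods[2] = np.zeros((2, 16))  # gate_attn = 0 at init
mods[5] = np.zeros((2, 16))  # gate_mlp  = 0 at init
out = adaln_zero_block(x, mods, attn_fn=lambda h: h, mlp_fn=lambda h: h)
print(np.allclose(out, x))  # True: zero gates make the block an identity
```

With the gates zeroed, both residual branches contribute nothing, which is exactly why AdaLN-Zero gives stable training from an identity initialization.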

Usage

model = MMDiT.build(
  img_dim: 16,           # VAE latent channels (flattened patch dim)
  txt_dim: 4096,         # Text encoder hidden dim (e.g., T5-XXL)
  hidden_size: 1536,     # Joint hidden dimension
  depth: 24,             # Number of double-stream blocks
  num_heads: 24,
  img_tokens: 256,       # Number of image patches
  txt_tokens: 77         # Max text sequence length
)

References

  • Esser et al., 2024. "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis." (Stable Diffusion 3)

Summary

Types

Options for build/1.

Functions

Build an MMDiT model for multimodal diffusion.

Get the output size of an MMDiT model.

Get recommended defaults for MMDiT.

Types

build_opt()

@type build_opt() ::
  {:cond_dim, pos_integer()}
  | {:depth, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:img_dim, pos_integer()}
  | {:img_tokens, pos_integer()}
  | {:mlp_ratio, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:txt_dim, pos_integer()}
  | {:txt_tokens, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build an MMDiT model for multimodal diffusion.

Options

Modality dimensions (both required):

  • :img_dim - Image patch feature dimension (required)
  • :txt_dim - Text token feature dimension (required)
  • :img_tokens - Number of image tokens/patches (default: 64)
  • :txt_tokens - Max text tokens (default: 32)

Architecture:

  • :hidden_size - Joint hidden dimension (default: 768)
  • :depth - Number of double-stream blocks (default: 12)
  • :num_heads - Attention heads (default: 12)
  • :mlp_ratio - MLP expansion ratio (default: 4)
  • :cond_dim - Conditioning vector dimension (default: hidden_size)

Returns

An Axon model: (img_latent, txt_embed, timestep, pooled_text) -> denoised_img

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of an MMDiT model.