# `Edifice.Generative.MMDiT`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/generative/mmdit.ex#L1)

MMDiT: Multimodal Diffusion Transformer.

Implements the MMDiT architecture from "Scaling Rectified Flow Transformers
for High-Resolution Image Synthesis" (Esser et al., 2024), used in Stable
Diffusion 3 and FLUX.1. Replaces DiT's single-stream design with dual
parallel streams (one per modality) connected via joint self-attention.

## Key Innovation: Joint Attention (Not Cross-Attention)

Instead of cross-attention between modalities, MMDiT concatenates Q, K, V
from both streams along the sequence dimension, runs a single standard
self-attention over the combined sequence, then splits the output back into
per-modality streams. Every layer therefore covers all four interaction
directions: image-to-image, image-to-text, text-to-image, and text-to-text.

```
Image stream:  img_q, img_k, img_v = img_attn(AdaLN(img))
Text stream:   txt_q, txt_k, txt_v = txt_attn(AdaLN(txt))

Combined:      q = cat([txt_q, img_q], dim=seq)
               k = cat([txt_k, img_k], dim=seq)
               v = cat([txt_v, img_v], dim=seq)

Joint attn:    out = softmax(q @ k^T / sqrt(d)) @ v

Split back:    txt_out = out[:, :txt_len]
               img_out = out[:, txt_len:]
```
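
The concatenate, attend, split pattern above can be sketched in plain Elixir. This is a minimal single-head, unbatched illustration that uses nested lists instead of Nx tensors so that it runs in a bare `iex` session. `JointAttentionSketch` and its function names are hypothetical and are not part of Edifice.

```elixir
defmodule JointAttentionSketch do
  # Hypothetical helper module for illustration; not part of Edifice.
  # A sequence is a list of token vectors (lists of floats).

  # Scaled dot-product attention for one head, without a batch dimension.
  def attention(q, k, v) do
    scale = :math.sqrt(length(hd(q)))

    Enum.map(q, fn q_row ->
      weights =
        k
        |> Enum.map(fn k_row -> dot(q_row, k_row) / scale end)
        |> softmax()

      # Weighted sum of the value rows, computed column by column.
      weights
      |> Enum.zip(v)
      |> Enum.map(fn {w, v_row} -> Enum.map(v_row, &(&1 * w)) end)
      |> Enum.zip_with(&Enum.sum/1)
    end)
  end

  # Concatenate text and image tokens, attend once, then split the result.
  def joint_attention({txt_q, txt_k, txt_v}, {img_q, img_k, img_v}) do
    txt_len = length(txt_q)
    out = attention(txt_q ++ img_q, txt_k ++ img_k, txt_v ++ img_v)
    Enum.split(out, txt_len)
  end

  defp dot(a, b), do: a |> Enum.zip_with(b, &(&1 * &2)) |> Enum.sum()

  # Numerically stable softmax over a list of scores.
  defp softmax(scores) do
    m = Enum.max(scores)
    exps = Enum.map(scores, &:math.exp(&1 - m))
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end
end

txt = [[1.0, 0.0]]
img = [[0.0, 1.0], [1.0, 1.0]]

# txt_out has 1 row and img_out has 2 rows, matching the input lengths.
{txt_out, img_out} =
  JointAttentionSketch.joint_attention({txt, txt, txt}, {img, img, img})
```

Because every query row attends over the concatenated keys, a text token can draw on image values and vice versa, which is why no separate cross-attention layer is needed.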

## Architecture

```
Image Latent [batch, img_tokens, img_dim]     Text Embed [batch, txt_tokens, txt_dim]
      |                                              |
      v                                              v
[Image Projection]                            [Text Projection]
      |                                              |
      v                                              v
+-------------------------------------------------------------------+
|  Double-Stream Block x depth                                      |
|                                                                   |
|  img_stream:  AdaLN(vec) -> QKV_img ---+                          |
|  txt_stream:  AdaLN(vec) -> QKV_txt ---+-> Joint Attention        |
|                                        |                          |
|  img_stream:  gate * proj(img_attn) + residual -> AdaLN -> MLP    |
|  txt_stream:  gate * proj(txt_attn) + residual -> AdaLN -> MLP    |
+-------------------------------------------------------------------+
      |
      v
[Final Norm + Linear] -> output [batch, img_tokens, img_dim]
```

## Conditioning (AdaLN-Zero)

Each stream has its own modulation weights, but both streams are driven by
the same conditioning vector:
`vec = timestep_mlp(sinusoidal(t)) + pooled_text_mlp(pooled_text)`

For each stream, the modulation layer maps `vec` to six parameters:
`(shift_attn, scale_attn, gate_attn, shift_mlp, scale_mlp, gate_mlp)`
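
As a rough illustration of how the six parameters might be applied, the sketch below assumes the DiT-style modulation `x * (1 + scale) + shift` followed by a gated residual. The text above does not spell out that formula or the order in which the six chunks are laid out, so both are assumptions. `AdaLNSketch` is a hypothetical module that works on plain lists rather than Nx tensors.

```elixir
defmodule AdaLNSketch do
  # Hypothetical helper module for illustration; not part of Edifice.
  # Vectors are plain lists of floats for a single token.

  # Split a projected conditioning vector into six chunks of hidden_size.
  # The chunk order is an assumption that follows the documented tuple.
  def split_modulation(params, hidden_size)
      when length(params) == 6 * hidden_size do
    [shift_attn, scale_attn, gate_attn, shift_mlp, scale_mlp, gate_mlp] =
      Enum.chunk_every(params, hidden_size)

    %{
      shift_attn: shift_attn,
      scale_attn: scale_attn,
      gate_attn: gate_attn,
      shift_mlp: shift_mlp,
      scale_mlp: scale_mlp,
      gate_mlp: gate_mlp
    }
  end

  # Assumed DiT-style modulation of an already normalized vector.
  def modulate(x, shift, scale) do
    Enum.zip_with([x, shift, scale], fn [xi, sh, sc] -> xi * (1.0 + sc) + sh end)
  end

  # Gated residual: residual + gate * branch, element by element.
  def gated_residual(residual, gate, branch) do
    Enum.zip_with([residual, gate, branch], fn [r, g, b] -> r + g * b end)
  end
end

mods = AdaLNSketch.split_modulation(Enum.map(1..12, &(&1 * 1.0)), 2)
modulated = AdaLNSketch.modulate([1.0, 2.0], mods.shift_attn, mods.scale_attn)

# With a zero gate the branch output is ignored and the residual passes through.
unchanged = AdaLNSketch.gated_residual([1.0, 2.0], [0.0, 0.0], modulated)
```

With a zero gate, the block reduces to the identity on the residual path, which is the starting behavior that zero-initialized gates are meant to provide.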

## Usage

    alias Edifice.Generative.MMDiT

    model = MMDiT.build(
      img_dim: 16,           # VAE latent channels (flattened patch dim)
      txt_dim: 4096,         # Text encoder hidden dim (e.g., T5-XXL)
      hidden_size: 1536,     # Joint hidden dimension
      depth: 24,             # Number of double-stream blocks
      num_heads: 24,
      img_tokens: 256,       # Number of image patches
      txt_tokens: 77         # Max text sequence length
    )

## References
- SD3: https://arxiv.org/abs/2403.03206
- FLUX.1: https://github.com/black-forest-labs/flux

# `build_opt`

```elixir
@type build_opt() ::
  {:cond_dim, pos_integer()}
  | {:depth, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:img_dim, pos_integer()}
  | {:img_tokens, pos_integer()}
  | {:mlp_ratio, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:txt_dim, pos_integer()}
  | {:txt_tokens, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build an MMDiT model for multimodal diffusion.

## Options

**Modality dimensions:**
  - `:img_dim` - Image patch feature dimension (required)
  - `:txt_dim` - Text token feature dimension (required)
  - `:img_tokens` - Number of image tokens/patches (default: 64)
  - `:txt_tokens` - Max text tokens (default: 32)

**Architecture:**
  - `:hidden_size` - Joint hidden dimension (default: 768)
  - `:depth` - Number of double-stream blocks (default: 12)
  - `:num_heads` - Attention heads (default: 12)
  - `:mlp_ratio` - MLP expansion ratio (default: 4)
  - `:cond_dim` - Conditioning vector dimension (default: hidden_size)
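
The defaults listed above can be summarized as a small option-resolution sketch. `MMDiTOptsSketch.resolve/1` is a hypothetical helper written for illustration, not the code inside `build/1`. It only mirrors the documented behavior that `:img_dim` and `:txt_dim` are required and that `:cond_dim` falls back to `:hidden_size`.

```elixir
defmodule MMDiTOptsSketch do
  # Hypothetical helper for illustration; not the library's implementation.
  def resolve(opts) do
    hidden_size = Keyword.get(opts, :hidden_size, 768)

    %{
      # Required options: Keyword.fetch!/2 raises KeyError when they are missing.
      img_dim: Keyword.fetch!(opts, :img_dim),
      txt_dim: Keyword.fetch!(opts, :txt_dim),
      # Optional options with the documented defaults.
      img_tokens: Keyword.get(opts, :img_tokens, 64),
      txt_tokens: Keyword.get(opts, :txt_tokens, 32),
      hidden_size: hidden_size,
      depth: Keyword.get(opts, :depth, 12),
      num_heads: Keyword.get(opts, :num_heads, 12),
      mlp_ratio: Keyword.get(opts, :mlp_ratio, 4),
      # The conditioning dimension follows hidden_size unless overridden.
      cond_dim: Keyword.get(opts, :cond_dim, hidden_size)
    }
  end
end

config = MMDiTOptsSketch.resolve(img_dim: 16, txt_dim: 4096, hidden_size: 1536)
# config.cond_dim follows hidden_size; config.depth falls back to 12.
```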

## Returns
  An Axon model: (img_latent, txt_embed, timestep, pooled_text) -> denoised_img

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Returns the output size of an MMDiT model for the given options.

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Returns the recommended default options for MMDiT as a keyword list.

