DiT: Diffusion Transformer.
Implements the DiT architecture from "Scalable Diffusion Models with Transformers" (Peebles & Xie, ICCV 2023). Replaces the traditional U-Net backbone in diffusion models with a Transformer, using Adaptive Layer Normalization (AdaLN-Zero) for timestep and class conditioning.
Key Innovation: AdaLN-Zero Conditioning
Instead of injecting the conditioning signal through cross-attention (which adds significant compute), DiT modulates the LayerNorm scale and shift parameters directly from that signal:
# Standard LayerNorm (static, learned parameters):
y = gamma * normalize(x) + beta

# AdaLN-Zero (parameters regressed from the conditioning signal):
gamma, beta, alpha = MLP(condition)   # learned modulation
h = gamma * normalize(x) + beta       # modulated norm feeds the attention/MLP branch
y = x + alpha * branch(h)             # gated residual; alpha is initialized to zero

Initializing alpha to zero means each DiT block starts as an identity function, which makes it possible to train deep stacks stably.
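The same mechanism can be sketched directly in Nx. The module and function names below are illustrative only (they are not part of this module's API), and the sketch assumes gamma, beta, and alpha were already produced by an MLP over the condition vector:

defmodule AdaLNZeroSketch do
  import Nx.Defn

  # Modulated LayerNorm: normalize x over the feature axis, then apply the
  # condition-derived scale (gamma) and shift (beta).
  defn modulated_norm(x, gamma, beta) do
    mean = Nx.mean(x, axes: [-1], keep_axes: true)
    centered = x - mean
    var = Nx.mean(centered * centered, axes: [-1], keep_axes: true)
    centered / Nx.sqrt(var + 1.0e-6) * gamma + beta
  end

  # Gated residual: alpha is initialized to zero, so at the start of training
  # the block output equals its input.
  defn gated_residual(x, branch_out, alpha) do
    x + alpha * branch_out
  end
end

Keeping the gate separate from the norm makes the zero-init behaviour explicit: with alpha set to Nx.tensor(0.0), gated_residual/3 returns x unchanged.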
Architecture
Input [batch, input_dim]
             |
             v
+----------------------------+
| Patchify + Position Embed  |
+----------------------------+
             |
             v
+----------------------------+
| DiT Block x depth          |
|   AdaLN-Zero(cond)         |
|   Self-Attention           |
|   Residual                 |
|   AdaLN-Zero(cond)         |
|   MLP                      |
|   Residual                 |
+----------------------------+
             |
             v
+----------------------------+
| Final AdaLN + Linear       |
+----------------------------+
             |
             v
Output [batch, input_dim] (predicted noise or v-prediction)
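One block from the stack above can then be sketched in plain Elixir, reusing the AdaLNZeroSketch helpers from earlier. Here attn_fn and mlp_fn stand in for the already-parameterized self-attention and MLP branches, and the six modulation tensors are assumed to come from an MLP over the condition vector:

defmodule DiTBlockSketch do
  # One block: two gated residual branches (attention, then MLP), each with
  # its own AdaLN-Zero modulation derived from the condition vector.
  def block(x, attn_fn, mlp_fn, {g1, b1, a1, g2, b2, a2}) do
    # Self-attention branch, gated by a1 (zero at init, so the branch is skipped)
    h = attn_fn.(AdaLNZeroSketch.modulated_norm(x, g1, b1))
    x = AdaLNZeroSketch.gated_residual(x, h, a1)

    # MLP branch, gated by a2
    h = mlp_fn.(AdaLNZeroSketch.modulated_norm(x, g2, b2))
    AdaLNZeroSketch.gated_residual(x, h, a2)
  end
end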
Conditioning
Timestep t -----> Sinusoidal Embed --> MLP --+
                                             |--> condition vector
Class label c --> Embedding ---------> MLP --+
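For illustration, the sinusoidal timestep embedding in the diagram can be computed with Nx as below. This is a generic sketch (the base of 10_000 and the sin/cos layout are assumptions), not necessarily how this module implements it:

defmodule TimestepEmbedSketch do
  # t: integer timestep tensor of shape {batch}; dim: even embedding size.
  def sinusoidal_embed(t, dim) do
    half = div(dim, 2)
    # Geometric frequency schedule from 1 down to 1/10_000
    freqs = Nx.exp(Nx.multiply(-:math.log(10_000), Nx.divide(Nx.iota({half}), half)))
    args = Nx.multiply(Nx.new_axis(t, 1), Nx.new_axis(freqs, 0))
    Nx.concatenate([Nx.sin(args), Nx.cos(args)], axis: 1)
  end
end

For example, TimestepEmbedSketch.sinusoidal_embed(Nx.tensor([0, 10, 999]), 256) produces a {3, 256} embedding that the conditioning MLP can consume.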
Usage
model = DiT.build(
  input_dim: 64,
  hidden_size: 256,
  depth: 6,
  num_heads: 4
)
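A hypothetical continuation showing how the returned Axon graph could be initialized and run; the input names ("noisy_input", "timestep") and the timestep shape are assumptions, so check the inputs the built model actually declares:

{init_fn, predict_fn} = Axon.build(model)

templates = %{
  "noisy_input" => Nx.template({1, 64}, :f32),
  "timestep" => Nx.template({1}, :s64)
}

params = init_fn.(templates, %{})

noise_pred =
  predict_fn.(params, %{
    "noisy_input" => Nx.broadcast(0.0, {1, 64}),
    "timestep" => Nx.tensor([10])
  })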
Reference
- Paper: "Scalable Diffusion Models with Transformers" (Peebles & Xie, ICCV 2023)
- arXiv: https://arxiv.org/abs/2212.09748
Summary
Functions
Build a DiT model for diffusion denoising.
Build a single DiT block with AdaLN-Zero conditioning.
Get the output size of a DiT model.
Calculate approximate parameter count for a DiT model.
Get recommended defaults.
Types
@type build_opt() ::
        {:depth, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:input_dim, pos_integer()}
        | {:mlp_ratio, float()}
        | {:num_classes, pos_integer() | nil}
        | {:num_heads, pos_integer()}
        | {:num_steps, pos_integer()}
Options for build/1.
Functions
build/1
Build a DiT model for diffusion denoising.
Options
- :input_dim - Input/output feature dimension (required)
- :hidden_size - Transformer hidden dimension (default: 256)
- :depth - Number of DiT blocks (default: 6)
- :num_heads - Number of attention heads (default: 4)
- :mlp_ratio - MLP expansion ratio (default: 4.0)
- :num_classes - Number of classes for conditioning (optional, nil = unconditional)
- :num_steps - Number of diffusion timesteps (default: 1000)
Returns
An Axon model that predicts noise given (noisy_input, timestep, [class]).
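For example, a class-conditional model (the class count of 10 is arbitrary):

model =
  DiT.build(
    input_dim: 64,
    hidden_size: 256,
    depth: 6,
    num_heads: 4,
    num_classes: 10
  )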
Build a single DiT block with AdaLN-Zero conditioning.
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a DiT model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for a DiT model.
@spec recommended_defaults() :: keyword()
Get recommended defaults.
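A hypothetical way to compose these helpers; the exact values depend on the implementation:

opts =
  DiT.recommended_defaults()
  |> Keyword.put(:input_dim, 64)

DiT.output_size(opts)
#=> 64, assuming the output size simply mirrors :input_dim

DiT.param_count(opts)
#=> approximate total parameter count (an integer)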