Edifice.Generative.DiT (Edifice v0.2.0)


DiT: Diffusion Transformer.

Implements the DiT architecture from "Scalable Diffusion Models with Transformers" (Peebles & Xie, ICCV 2023). Replaces the traditional U-Net backbone in diffusion models with a Transformer, using Adaptive Layer Normalization (AdaLN-Zero) for timestep and class conditioning.

Key Innovation: AdaLN-Zero Conditioning

Instead of cross-attention for conditioning (expensive), DiT modulates LayerNorm parameters based on the conditioning signal:

# Standard LayerNorm:
y = gamma * normalize(x) + beta

# AdaLN-Zero:
gamma, beta, alpha = MLP(condition)    # Learned modulation
y = gamma * normalize(x) + beta       # Modulated norm
y = alpha * y                         # Scale (initialized to zero)

Initializing alpha to zero means each DiT block starts as an identity function, enabling stable deep training.
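The pseudocode above can be checked numerically. A minimal sketch in NumPy (`layer_norm` and `adaln_zero` are illustrative helpers written for this example, not Edifice functions):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize to zero mean / unit variance over the feature axis.
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def adaln_zero(x, gamma, beta, alpha):
    # gamma/beta modulate the normalized activations; alpha gates the branch.
    return alpha * (gamma * layer_norm(x) + beta)

x = np.array([[1.0, 2.0, 3.0, 4.0]])

# At initialization alpha = 0, so the branch outputs zeros, and the
# residual connection x + branch(x) reduces to the identity.
branch = adaln_zero(x, gamma=np.ones(4), beta=np.zeros(4), alpha=0.0)
print(np.allclose(x + branch, x))  # True: the block starts as an identity
```

In the real model, `gamma`, `beta`, and `alpha` are produced per-block by an MLP over the conditioning vector rather than passed in directly.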

Architecture

Input [batch, input_dim]
      |
      v
+--------------------------+
| Patchify + Position Embed|
+--------------------------+
      |
      v
+--------------------------+
| DiT Block x depth        |
|  AdaLN-Zero(cond)        |
|  Self-Attention          |
|  Residual                |
|  AdaLN-Zero(cond)        |
|  MLP                     |
|  Residual                |
+--------------------------+
      |
      v
+--------------------------+
| Final AdaLN + Linear     |
+--------------------------+
      |
      v
Output [batch, input_dim]  (predicted noise or v-prediction)
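The two residual branches inside a DiT block can be sketched as follows (a NumPy toy with an identity stand-in for the real attention and MLP sublayers; shapes and the `dit_block` helper are assumptions for illustration, not Edifice internals):

```python
import numpy as np

rng = np.random.default_rng(0)
hidden = 8

def layer_norm(x, eps=1e-5):
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def dit_block(x, cond, w_mod, attn, mlp):
    # The conditioning vector is projected to six modulation vectors:
    # (gamma, beta, alpha) for the attention branch and for the MLP branch.
    g1, b1, a1, g2, b2, a2 = np.split(cond @ w_mod, 6, axis=-1)
    x = x + a1 * attn(g1 * layer_norm(x) + b1)  # AdaLN-Zero -> attention -> residual
    x = x + a2 * mlp(g2 * layer_norm(x) + b2)   # AdaLN-Zero -> MLP -> residual
    return x

x = rng.normal(size=(2, hidden))         # [batch, hidden] activations
cond = rng.normal(size=(2, hidden))      # per-example condition vector
w_mod = np.zeros((hidden, 6 * hidden))   # zero init => every alpha is zero
identity = lambda y: y                   # stand-in for the real sublayers
out = dit_block(x, cond, w_mod, identity, identity)
print(np.allclose(out, x))               # True: zero-initialized block is an identity
```

With the modulation projection zero-initialized, both gates are zero and the block passes `x` through unchanged, which is why a deep stack of these blocks trains stably from the start.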

Conditioning

Timestep t -----> Sinusoidal Embed --> MLP --+
                                             |--> condition vector
Class label c --> Embedding ---------> MLP --+
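The timestep branch of the diagram starts from a sinusoidal embedding. A hedged sketch of the standard Transformer-style construction (Edifice's exact frequency schedule may differ):

```python
import numpy as np

def timestep_embedding(t, dim, max_period=10000.0):
    # Half the channels carry sin, half carry cos, at geometrically
    # spaced frequencies from 1 down to 1/max_period.
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)

emb = timestep_embedding([0, 500, 999], dim=8)
print(emb.shape)  # (3, 8): one embedding row per timestep
```

Each scalar timestep becomes a `dim`-sized vector that the conditioning MLP can then project into the shared condition vector.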

Usage

model = DiT.build(
  input_dim: 64,
  hidden_size: 256,
  depth: 6,
  num_heads: 4
)

Reference

Peebles & Xie, "Scalable Diffusion Models with Transformers", ICCV 2023.

Summary

Types

Options for build/1.

Functions

Build a DiT model for diffusion denoising.

Build a single DiT block with AdaLN-Zero conditioning.

Get the output size of a DiT model.

Calculate approximate parameter count for a DiT model.

Get recommended defaults.

Types

build_opt()

@type build_opt() ::
  {:depth, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:input_dim, pos_integer()}
  | {:mlp_ratio, float()}
  | {:num_classes, pos_integer() | nil}
  | {:num_heads, pos_integer()}
  | {:num_steps, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a DiT model for diffusion denoising.

Options

  • :input_dim - Input/output feature dimension (required)
  • :hidden_size - Transformer hidden dimension (default: 256)
  • :depth - Number of DiT blocks (default: 6)
  • :num_heads - Number of attention heads (default: 4)
  • :mlp_ratio - MLP expansion ratio (default: 4.0)
  • :num_classes - Number of classes for conditioning (optional, nil = unconditional)
  • :num_steps - Number of diffusion timesteps (default: 1000)

Returns

An Axon model that predicts noise given (noisy_input, timestep, [class]).

build_dit_block(input, condition, opts)

@spec build_dit_block(Axon.t(), Axon.t(), keyword()) :: Axon.t()

Build a single DiT block with AdaLN-Zero conditioning.

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a DiT model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for a DiT model.
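As a rough cross-check for `param_count/1`, a standard transformer back-of-envelope estimate can be sketched (my own approximation under assumed layer shapes, not Edifice's actual accounting, which may include biases, embeddings, and the conditioning MLP):

```python
def approx_dit_params(input_dim, hidden_size=256, depth=6, mlp_ratio=4.0):
    h = hidden_size
    attn = 4 * h * h                    # Q, K, V and output projections
    mlp = 2 * int(mlp_ratio * h) * h    # two MLP projections (h -> 4h -> h)
    adaln = h * 6 * h                   # condition -> 6 modulation vectors
    per_block = attn + mlp + adaln
    io = 2 * input_dim * h              # input embed + output head (approx.)
    return depth * per_block + io

# Using the defaults from the Usage example above:
print(approx_dit_params(64, hidden_size=256, depth=6, mlp_ratio=4.0))
```

The dominant term is `depth * hidden_size^2`, so doubling `hidden_size` roughly quadruples the model while doubling `depth` only doubles it.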