Linear DiT / SANA: Diffusion Transformer with Linear Attention.
Implements the DiT (Diffusion Transformer) architecture with linear attention in place of the quadratic softmax attention, achieving comparable image quality at a dramatically reduced computational cost.
Key Innovation: Linear Attention in Diffusion
Standard DiT uses O(N²) softmax attention, which becomes prohibitive for high-resolution images. Linear DiT replaces this with O(N) linear attention using kernel feature maps, enabling:
- Reported speedups of ~100x for high-resolution generation
- Image quality comparable to quadratic-attention DiT
- Scales to 4K+ resolution images
Architecture
Input [batch, num_patches, patch_dim]
              |
              v
+---------------------------+
| Patchify + Position Embed |
+---------------------------+
              |
              v
+---------------------------+
|  Linear DiT Block x depth |
|    AdaLN-Zero(condition)  |
|    Linear Attention       |  <- O(N) instead of O(N²)
|    Residual               |
|    AdaLN-Zero(condition)  |
|    MLP                    |
|    Residual               |
+---------------------------+
              |
              v
+---------------------------+
|    Final AdaLN + Linear   |
+---------------------------+
              |
              v
Output [batch, num_patches, patch_dim]
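Each block is conditioned via AdaLN-Zero: the timestep (and optional class) embedding is regressed into per-block shift, scale, and gate vectors that modulate the normalized activations. Below is a minimal Nx sketch of the modulation step, following the DiT paper's convention; the module and function names are illustrative, not this library's actual internals.

defmodule AdaLNSketch do
  import Nx.Defn

  # x: [batch, seq, hidden]; shift, scale: [batch, hidden], regressed from the
  # conditioning embedding. AdaLN-Zero additionally regresses a per-block gate,
  # initialized at zero, that scales the branch output before the residual add.
  defn modulate(x, shift, scale) do
    x * (1 + Nx.new_axis(scale, 1)) + Nx.new_axis(shift, 1)
  end
end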
Linear Attention Mechanism

Standard: Attn(Q, K, V) = softmax(Q K^T / sqrt(d)) * V                          [O(N²)]
Linear:   Attn(Q, K, V) = (phi(Q) * (phi(K)^T * V)) / (phi(Q) * sum(phi(K)))    [O(N)]

Where phi(x) = ELU(x) + 1 ensures non-negative attention weights.
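The kernel trick above can be written directly in Nx. A minimal sketch follows; the phi/1 and linear_attention/3 names and the [batch, heads, seq_len, head_dim] layout are assumptions for illustration, not this module's internals.

defmodule LinearAttentionSketch do
  import Nx.Defn

  # phi(x) = ELU(x) + 1, which equals x + 1 for x > 0 and exp(x) for x <= 0,
  # so the feature maps are strictly positive.
  defnp phi(x) do
    Nx.select(Nx.greater(x, 0), x + 1, Nx.exp(x))
  end

  # q, k, v: [batch, heads, seq_len, head_dim]
  defn linear_attention(q, k, v) do
    q = phi(q)
    k = phi(k)

    # phi(K)^T V is only [head_dim, head_dim] per head, so the cost grows
    # linearly with the number of tokens.
    kv = Nx.dot(k, [2], [0, 1], v, [2], [0, 1])

    # Normalizer: phi(Q) . sum_j phi(K_j)
    k_sum = Nx.sum(k, axes: [2])
    z = Nx.dot(q, [3], [0, 1], k_sum, [2], [0, 1])

    # Numerator phi(Q) (phi(K)^T V), divided by the normalizer
    Nx.dot(q, [3], [0, 1], kv, [2], [0, 1]) / Nx.new_axis(z, -1)
  end
end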
Usage
model = LinearDiT.build(
  input_dim: 64,
  hidden_size: 512,
  num_layers: 12,
  num_heads: 8,
  patch_size: 2
)
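To run the built model you can follow the usual Axon workflow. A sketch, assuming a recent Axon version; the input names "noisy_input" and "timestep" and the tensor shapes are illustrative only — use the input names the built model actually declares.

{init_fn, predict_fn} = Axon.build(model)

# Dummy inputs: [batch, num_patches, patch_dim] for the noisy sample,
# one timestep per batch element.
inputs = %{
  "noisy_input" => Nx.broadcast(0.0, {1, 256, 64}),
  "timestep" => Nx.tensor([500])
}

params = init_fn.(inputs, %{})
noise_prediction = predict_fn.(params, inputs)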
References
- SANA: "Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers" (Xie et al., 2024)
- DiT: "Scalable Diffusion Models with Transformers" (Peebles & Xie, 2023)
- Linear Attention: "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" (Katharopoulos et al., 2020)
Summary
Functions
Build a Linear DiT model for diffusion denoising with linear attention.
Build a single Linear DiT block with AdaLN-Zero conditioning and linear attention.
Get the output size of a Linear DiT model.
Calculate approximate parameter count for a Linear DiT model.
Get recommended defaults for Linear DiT / SANA.
Types
@type build_opt() ::
        {:hidden_size, pos_integer()}
        | {:input_dim, pos_integer()}
        | {:mlp_ratio, float()}
        | {:num_classes, pos_integer() | nil}
        | {:num_heads, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:num_steps, pos_integer()}
        | {:patch_size, pos_integer()}
Options for build/1.
Functions
Build a Linear DiT model for diffusion denoising with linear attention.
Options
- :input_dim - Input/output feature dimension (required)
- :hidden_size - Transformer hidden dimension (default: 512)
- :num_layers - Number of DiT blocks (default: 12)
- :num_heads - Number of attention heads (default: 8)
- :mlp_ratio - MLP expansion ratio (default: 4.0)
- :num_classes - Number of classes for conditioning (optional, nil = unconditional)
- :num_steps - Number of diffusion timesteps (default: 1000)
- :patch_size - Patch size for spatial inputs (default: 2)
Returns
An Axon model that predicts noise given (noisy_input, timestep, [class]).
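For example, a sketch contrasting unconditional and class-conditional builds (the class count here is arbitrary):

# Unconditional: num_classes defaults to nil, so the model expects only the
# noisy input and the timestep.
uncond = LinearDiT.build(input_dim: 64)

# Class-conditional: the built model additionally expects a class-label input.
conditional = LinearDiT.build(input_dim: 64, num_classes: 1000)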
Build a single Linear DiT block with AdaLN-Zero conditioning and linear attention.
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a Linear DiT model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for a Linear DiT model.
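As a rough sanity check, a DiT-style block is dominated by its attention projections, MLP, and AdaLN-Zero modulation. A back-of-the-envelope estimate under those standard assumptions (not necessarily the exact formula used by param_count/1):

h = 512
mlp_ratio = 4.0
num_layers = 12

# 4*h^2 (Q/K/V/O projections) + 2*mlp_ratio*h^2 (MLP) + 6*h^2 (AdaLN-Zero modulation),
# ignoring biases, embeddings, and the final projection.
per_block = round((4 + 2 * mlp_ratio + 6) * h * h)
total = num_layers * per_block
# => 56_623_104, i.e. roughly 57M parameters in the blocks alone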
@spec recommended_defaults() :: keyword()
Get recommended defaults for Linear DiT / SANA.
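One plausible way to use these defaults, assuming recommended_defaults/0 returns a keyword list compatible with build/1 (the required :input_dim still has to be supplied):

opts = LinearDiT.recommended_defaults() |> Keyword.merge(input_dim: 64)
model = LinearDiT.build(opts)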