CogVideoX: Text-to-Video Diffusion with Expert Transformer.
Implements the CogVideoX architecture from "CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer" (Yang et al., Zhipu AI, 2024). It is a state-of-the-art open-source video generation model combining 3D causal VAE compression with an expert transformer that routes text and video tokens to specialized FFN experts.
Key Innovations
1. 3D Causal VAE
Compresses video to latent space while preserving temporal causality:
- Spatial compression: 8× downsample in H and W via 2D convolutions
- Temporal compression: 4× downsample in time via causal 1D convolutions
- Causal constraint: Frame t can only depend on frames ≤ t (enables streaming)
- 3D convolutions: Factorized as spatial 2D conv + temporal 1D conv
Video [B, T, C, H, W] → Latent [B, T/4, C', H/8, W/8]
    (49 frames)              (13 latent frames)
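The compression arithmetic can be sanity-checked directly. The sketch below assumes the common causal-VAE convention of keeping the first frame and compressing each subsequent group of 4 frames, which is consistent with the 49 → 13 figure above:

# Sketch: latent shape arithmetic for the default configuration.
# The first-frame-plus-groups formula is an assumption consistent
# with the documented 49 → 13 frame compression.
frames = 49
latent_frames = 1 + div(frames - 1, 4)          # => 13
{h, w} = {480, 720}
{latent_h, latent_w} = {div(h, 8), div(w, 8)}   # => {60, 90}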
2. Expert Transformer (DiT + MoE FFN)
Full 3D attention over space-time with modality-specific experts:
- 3D patchification: Video latent → space-time tokens
- Full 3D attention: Every token attends to all other tokens (expensive but high quality)
- Expert FFN routing: Text tokens → text experts, video tokens → video experts
- RoPE3D: Rotary position embeddings over (time, height, width)
Text tokens ────┐
                ├──→ Shared 3D Attention ──→ Expert FFN ──→ ...
Video tokens ───┘          │                     │
                     (all-to-all)       (routed by modality)
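A minimal sketch of the routing idea in Axon follows. The module and function names are hypothetical (not this module's API), and a single text expert plus a single video expert are assumed (the defaults):

# Hypothetical sketch of modality-routed FFNs, not this module's implementation.
defmodule ExpertFFNSketch do
  def expert_ffn(hidden, text_len, video_len, dim) do
    # Split the concatenated [batch, text_len + video_len, dim] sequence.
    text  = Axon.nx(hidden, &Nx.slice_along_axis(&1, 0, text_len, axis: 1))
    video = Axon.nx(hidden, &Nx.slice_along_axis(&1, text_len, video_len, axis: 1))

    # Each modality gets its own FFN parameters (its "expert").
    text_out  = text  |> Axon.dense(4 * dim) |> Axon.gelu() |> Axon.dense(dim)
    video_out = video |> Axon.dense(4 * dim) |> Axon.gelu() |> Axon.dense(dim)

    # Re-concatenate so the next block's 3D attention stays all-to-all.
    Axon.concatenate([text_out, video_out], axis: 1)
  end
end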
Architecture Details
Input: Video [batch, frames, 3, H, W] + Text embeddings [batch, text_len, dim]
|
v
+---------------------------+
| 3D Causal VAE Encoder | Compress video to latent space
+---------------------------+
|
v
Latent [batch, T', latent_dim, H', W']
|
v
+---------------------------+
| 3D Patchify + RoPE3D | Convert to space-time tokens
+---------------------------+
|
v
Concat with text tokens → [batch, text_len + video_tokens, hidden]
|
v
+---------------------------+
| Expert Transformer Block | × num_layers
| • Full 3D Self-Attention |
| • Expert FFN (text/video)|
+---------------------------+
|
v
+---------------------------+
| Unpatchify + VAE Decode   | Convert tokens back to latent, then decode to pixels
+---------------------------+
|
v
Output: Generated Video [batch, frames, 3, H, W]

Usage
# Build the 3D causal VAE
{encoder, decoder} = CogVideoX.build_vae(
  in_channels: 3,
  latent_channels: 16,
  num_frames: 49
)
# Build the expert transformer
transformer = CogVideoX.build_transformer(
  patch_size: [1, 2, 2],
  hidden_size: 1920,
  num_heads: 48,
  num_layers: 42,
  num_frames: 49
)
# Or build the full pipeline
model = CogVideoX.build(
  hidden_size: 1920,
  num_heads: 48,
  num_layers: 42
)

References
- Paper: "CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer"
- Authors: Yang et al., Zhipu AI
- Year: 2024
- Code: https://github.com/THUDM/CogVideo
Summary
Functions
Build the full CogVideoX pipeline (VAE + Expert Transformer).
Build the Expert Transformer for video generation.
Build the 3D Causal VAE encoder-decoder for video compression.
Decode latent to video using the VAE decoder.
Encode video to latent space using the VAE encoder.
Approximate parameter count for the CogVideoX transformer.
Get recommended defaults for CogVideoX.
Compute 3D RoPE frequencies for position encoding.
Types
@type transformer_opt() ::
        {:hidden_size, pos_integer()}
        | {:mlp_ratio, float()}
        | {:num_frames, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:num_text_experts, pos_integer()}
        | {:num_video_experts, pos_integer()}
        | {:patch_size, [pos_integer()]}
        | {:text_hidden_size, pos_integer()}
Options for build_transformer/1.
@type vae_opt() ::
        {:base_channels, pos_integer()}
        | {:in_channels, pos_integer()}
        | {:latent_channels, pos_integer()}
        | {:num_frames, pos_integer()}
        | {:spatial_downsample, pos_integer()}
        | {:temporal_downsample, pos_integer()}
Options for build_vae/1.
Functions
Build the full CogVideoX pipeline (VAE + Expert Transformer).
Options
All options from build_vae/1 and build_transformer/1, plus:
:latent_channels - Latent space channels (default: 16)
Returns
An Axon model for the full video generation pipeline.
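For example (a sketch; the input template names required by init_fn depend on how the pipeline defines its Axon inputs):

model = CogVideoX.build(latent_channels: 16)
{init_fn, predict_fn} = Axon.build(model)
# init_fn.(template, %{}) initializes parameters from an input template;
# predict_fn.(params, inputs) runs one denoising forward pass.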
@spec build_transformer([transformer_opt()]) :: Axon.t()
Build the Expert Transformer for video generation.
Uses DiT-style architecture with full 3D attention and modality-specific FFN experts (text tokens → text experts, video tokens → video experts).
Options
:patch_size - Patch size as [t, h, w] (default: [1, 2, 2])
:hidden_size - Transformer hidden dimension (default: 1920)
:num_heads - Number of attention heads (default: 48)
:num_layers - Number of transformer layers (default: 42)
:num_frames - Number of latent frames (default: 49)
:text_hidden_size - Text embedding dimension (default: 4096)
:mlp_ratio - MLP expansion ratio (default: 4.0)
:num_text_experts - Number of text FFN experts (default: 1)
:num_video_experts - Number of video FFN experts (default: 1)
Returns
An Axon model that takes video latents + text embeddings and outputs denoised latents.
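For example, a variant with separate expert pools per modality (a sketch using the documented option names):

transformer = CogVideoX.build_transformer(
  hidden_size: 1920,
  num_heads: 48,
  num_layers: 42,
  num_text_experts: 2,
  num_video_experts: 2
)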
Build the 3D Causal VAE encoder-decoder for video compression.
The VAE uses factorized 3D convolutions (2D spatial + 1D temporal) with causal temporal convolutions to ensure frame t only depends on frames ≤ t.
Options
:in_channels - Input video channels (default: 3)
:latent_channels - Latent space channels (default: 16)
:num_frames - Number of input frames (default: 49)
:spatial_downsample - Spatial downsampling factor (default: 8)
:temporal_downsample - Temporal downsampling factor (default: 4)
:base_channels - Base channel count for encoder (default: 128)
Returns
A tuple {encoder, decoder} where:
encoder: Video [B, T, C, H, W] → Latent [B, T', C', H', W']
decoder: Latent → Reconstructed video
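For example (a sketch using the documented defaults):

{encoder, decoder} = CogVideoX.build_vae(
  in_channels: 3,
  latent_channels: 16,
  num_frames: 49,
  spatial_downsample: 8,
  temporal_downsample: 4
)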
@spec decode_latent(Axon.t(), map(), Nx.Tensor.t()) :: Nx.Tensor.t()
Decode latent to video using the VAE decoder.
Parameters
decoder - VAE decoder model
params - Decoder parameters
latent - Latent tensor
Returns
Reconstructed video tensor [batch, frames, channels, height, width]
@spec encode_video(Axon.t(), map(), Nx.Tensor.t()) :: Nx.Tensor.t()
Encode video to latent space using the VAE encoder.
Parameters
encoder - VAE encoder model
params - Encoder parameters
video - Video tensor [batch, frames, channels, height, width]
Returns
Latent tensor [batch, latent_frames, latent_channels, h', w']
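A roundtrip sketch combining encode_video/3 and decode_latent/3; enc_params and dec_params are assumed to come from training or a checkpoint loader, which this module does not provide:

latent = CogVideoX.encode_video(encoder, enc_params, video)
# latent: [batch, 13, 16, h/8, w/8] for 49 input frames
recon = CogVideoX.decode_latent(decoder, dec_params, latent)
# recon: [batch, 49, 3, h, w]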
@spec param_count(keyword()) :: non_neg_integer()
Approximate parameter count for the CogVideoX transformer.
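For example (a sketch; the exact count depends on options omitted here):

CogVideoX.param_count(hidden_size: 1920, num_heads: 48, num_layers: 42)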
@spec recommended_defaults() :: keyword()
Get recommended defaults for CogVideoX.
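The returned keyword list can be fed back into the builders, assuming (as this sketch does) that its keys match the builder options:

opts = CogVideoX.recommended_defaults()
transformer = CogVideoX.build_transformer(opts)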
@spec rope3d_freqs(pos_integer(), pos_integer(), pos_integer(), keyword()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Compute 3D RoPE frequencies for position encoding.
Parameters
time_dim - Temporal dimension size
height_dim - Height dimension size
width_dim - Width dimension size
opts - Options including :hidden_size and :num_heads
Returns
Tuple of {sin_freqs, cos_freqs} tensors for RoPE application.
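For example, frequencies for a 13×30×45 latent token grid (a sketch; the grid sizes are illustrative):

{sin_freqs, cos_freqs} =
  CogVideoX.rope3d_freqs(13, 30, 45, hidden_size: 1920, num_heads: 48)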