Edifice.Generative.CogVideoX (Edifice v0.2.0)


CogVideoX: Text-to-Video Diffusion with Expert Transformer.

Implements the CogVideoX architecture from "CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer" (Yang et al., Zhipu AI, 2024): a state-of-the-art open-source video generation model that combines 3D causal VAE compression with an expert transformer routing text and video tokens to modality-specific FFN experts.

Key Innovations

1. 3D Causal VAE

Compresses video to latent space while preserving temporal causality:

  • Spatial compression: 8× downsample in H and W via 2D convolutions
  • Temporal compression: 4× downsample in time via causal 1D convolutions
  • Causal constraint: Frame t can only depend on frames ≤ t (enables streaming)
  • 3D convolutions: Factorized as spatial 2D conv + temporal 1D conv
Video [B, T, C, H, W]   →   Latent [B, T/4, C', H/8, W/8]
     (49 frames)                 (13 latent frames)
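
A quick shape check for the figure above. The 49 → 13 frame count implies the common causal convention of keeping the first frame uncompressed, i.e. T' = (T - 1) / 4 + 1 (an assumption consistent with the numbers shown; the 480×720 resolution is just an example):

# 49 input frames -> 13 latent frames under the first-frame-kept convention
t = 49
{h, w} = {480, 720}

t_latent = div(t - 1, 4) + 1                   # => 13
{h_latent, w_latent} = {div(h, 8), div(w, 8)}  # => {60, 90}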

2. Expert Transformer (DiT + MoE FFN)

Full 3D attention over space-time with modality-specific experts:

  • 3D patchification: Video latent → space-time tokens
  • Full 3D attention: Every token attends to all other tokens (expensive but high quality)
  • Expert FFN routing: Text tokens → text experts, video tokens → video experts
  • RoPE3D: Rotary position embeddings over (time, height, width)
Text tokens  --\
                +-->  Shared 3D Attention  -->  Expert FFN  -->  ...
Video tokens --/          (all-to-all)        (routed by modality)
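
A minimal Nx sketch of the routing idea (hypothetical, not Edifice's internal code): text and video segments share the attention output, then each segment passes through its own two-layer FFN before being re-concatenated along the sequence axis.

defmodule ExpertFFNSketch do
  import Nx.Defn

  # text_tokens: [B, Lt, D], video_tokens: [B, Lv, D]
  # tw1/tw2 and vw1/vw2 are the text and video expert weights
  defn route(text_tokens, video_tokens, tw1, tw2, vw1, vw2) do
    text_out = ffn(text_tokens, tw1, tw2)
    video_out = ffn(video_tokens, vw1, vw2)

    # Re-concatenate so the next shared-attention block sees one sequence
    Nx.concatenate([text_out, video_out], axis: 1)
  end

  # Plain two-layer FFN: [B, L, D] -> [B, L, 4D] -> [B, L, D]
  defnp ffn(x, w1, w2) do
    x
    |> Nx.dot([2], w1, [0])
    |> Nx.max(0)
    |> Nx.dot([2], w2, [0])
  end
end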

Architecture Details

Input: Video [batch, frames, 3, H, W] + Text embeddings [batch, text_len, dim]
       |
       v
+---------------------------+
| 3D Causal VAE Encoder     |  Compress video to latent space
+---------------------------+
       |
       v
Latent [batch, T', latent_dim, H', W']
       |
       v
+---------------------------+
| 3D Patchify + RoPE3D      |  Convert to space-time tokens
+---------------------------+
       |
       v
Concat with text tokens  [batch, text_len + video_tokens, hidden]
       |
       v
+---------------------------+
| Expert Transformer Block  |  × num_layers
|   Full 3D Self-Attention  |
|   Expert FFN (text/video) |
+---------------------------+
       |
       v
+---------------------------+
| Unpatchify + VAE Decode   |
+---------------------------+
       |
       v
Output: Generated Video [batch, frames, 3, H, W]
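
With the defaults shown in Usage below (patch_size [1, 2, 2], 13 latent frames) and an assumed 480×720 input, the full-attention sequence length works out as follows, which is why full 3D attention is expensive:

# Latent grid 13 x 60 x 90, patchified by [1, 2, 2]
video_tokens = div(13, 1) * div(60, 2) * div(90, 2)
# => 13 * 30 * 45 = 17_550 space-time tokens, every one attending to all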

Usage

alias Edifice.Generative.CogVideoX

# Build the 3D causal VAE (build_vae/1 returns an {encoder, decoder} pair)
{encoder, decoder} = CogVideoX.build_vae(
  in_channels: 3,
  latent_channels: 16,
  num_frames: 49
)

# Build the expert transformer
transformer = CogVideoX.build_transformer(
  patch_size: [1, 2, 2],
  hidden_size: 1920,
  num_heads: 48,
  num_layers: 42,
  num_frames: 49
)

# Or build the full pipeline
model = CogVideoX.build(
  hidden_size: 1920,
  num_heads: 48,
  num_layers: 42
)
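
The builders return Axon graphs. Below is a minimal sketch of running the VAE encoder with the standard Axon build/predict workflow (the 480×720 shape, direct-tensor input, and empty initial state are assumptions; adjust to your Axon version and the model's declared inputs):

{init_fn, predict_fn} = Axon.build(encoder)

# Dummy input video: [batch, frames, channels, height, width]
video = Nx.broadcast(0.0, {1, 49, 3, 480, 720})

params = init_fn.(Nx.to_template(video), %{})
latent = predict_fn.(params, video)
# With the defaults: latent is roughly [1, 13, 16, 60, 90]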

References

  • Yang et al., "CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer", Zhipu AI, 2024. arXiv:2408.06072

Summary

Functions

  • build/1: Build the full CogVideoX pipeline (VAE + Expert Transformer).
  • build_transformer/1: Build the Expert Transformer for video generation.
  • build_vae/1: Build the 3D Causal VAE encoder-decoder for video compression.
  • decode_latent/3: Decode latent to video using the VAE decoder.
  • encode_video/3: Encode video to latent space using the VAE encoder.
  • param_count/1: Approximate parameter count for the CogVideoX transformer.
  • Get recommended defaults for CogVideoX.
  • rope3d_freqs/4: Compute 3D RoPE frequencies for position encoding.

Types

transformer_opt()

@type transformer_opt() ::
  {:hidden_size, pos_integer()}
  | {:mlp_ratio, float()}
  | {:num_frames, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_text_experts, pos_integer()}
  | {:num_video_experts, pos_integer()}
  | {:patch_size, [pos_integer()]}
  | {:text_hidden_size, pos_integer()}

Options for build_transformer/1.

vae_opt()

@type vae_opt() ::
  {:base_channels, pos_integer()}
  | {:in_channels, pos_integer()}
  | {:latent_channels, pos_integer()}
  | {:num_frames, pos_integer()}
  | {:spatial_downsample, pos_integer()}
  | {:temporal_downsample, pos_integer()}

Options for build_vae/1.

Functions

build(opts \\ [])

@spec build(keyword()) :: Axon.t()

Build the full CogVideoX pipeline (VAE + Expert Transformer).

Options

All options from build_vae/1 and build_transformer/1, plus:

  • :latent_channels - Latent space channels (default: 16)

Returns

An Axon model for the full video generation pipeline.

build_transformer(opts \\ [])

@spec build_transformer([transformer_opt()]) :: Axon.t()

Build the Expert Transformer for video generation.

Uses a DiT-style architecture with full 3D attention and modality-specific FFN experts (text tokens → text experts, video tokens → video experts).

Options

  • :patch_size - Patch size as [t, h, w] (default: [1, 2, 2])
  • :hidden_size - Transformer hidden dimension (default: 1920)
  • :num_heads - Number of attention heads (default: 48)
  • :num_layers - Number of transformer layers (default: 42)
  • :num_frames - Number of latent frames (default: 49)
  • :text_hidden_size - Text embedding dimension (default: 4096)
  • :mlp_ratio - MLP expansion ratio (default: 4.0)
  • :num_text_experts - Number of text FFN experts (default: 1)
  • :num_video_experts - Number of video FFN experts (default: 1)

Returns

An Axon model that takes video latents + text embeddings and outputs denoised latents.

build_vae(opts \\ [])

@spec build_vae([vae_opt()]) :: {Axon.t(), Axon.t()}

Build the 3D Causal VAE encoder-decoder for video compression.

The VAE uses factorized 3D convolutions (2D spatial + 1D temporal) with causal temporal convolutions to ensure frame t only depends on frames ≤ t.
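
A minimal sketch of the causal constraint (illustrative, not Edifice's code): left-padding the time axis by kernel_size - 1 before a valid convolution ensures frame t never sees frames after t.

defmodule CausalConvSketch do
  import Nx.Defn

  # x: [batch, in_channels, time], kernel: [out_channels, in_channels, k]
  defn causal_conv1d(x, kernel) do
    k = elem(Nx.shape(kernel), 2)

    x
    # Pad only on the left of the time axis: {low, high, interior}
    |> Nx.pad(0.0, [{0, 0, 0}, {0, 0, 0}, {k - 1, 0, 0}])
    # A valid convolution then yields the original time length
    |> Nx.conv(kernel)
  end
end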

Options

  • :in_channels - Input video channels (default: 3)
  • :latent_channels - Latent space channels (default: 16)
  • :num_frames - Number of input frames (default: 49)
  • :spatial_downsample - Spatial downsampling factor (default: 8)
  • :temporal_downsample - Temporal downsampling factor (default: 4)
  • :base_channels - Base channel count for encoder (default: 128)

Returns

A tuple {encoder, decoder} where:

  • encoder: Video [B, T, C, H, W] → Latent [B, T', C', H', W']
  • decoder: Latent → Reconstructed video

decode_latent(decoder, params, latent)

@spec decode_latent(Axon.t(), map(), Nx.Tensor.t()) :: Nx.Tensor.t()

Decode latent to video using the VAE decoder.

Parameters

  • decoder - VAE decoder model
  • params - Decoder parameters
  • latent - Latent tensor

Returns

Reconstructed video tensor [batch, frames, channels, height, width]

encode_video(encoder, params, video)

@spec encode_video(Axon.t(), map(), Nx.Tensor.t()) :: Nx.Tensor.t()

Encode video to latent space using the VAE encoder.

Parameters

  • encoder - VAE encoder model
  • params - Encoder parameters
  • video - Video tensor [batch, frames, channels, height, width]

Returns

Latent tensor [batch, latent_frames, latent_channels, h', w']
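
A hedged round-trip sketch using these helpers (encoder/decoder and their params come from build_vae/1 plus Axon init, as in Usage; shapes assume the defaults and a 480×720 input):

# enc_params/dec_params obtained via Axon init (see Usage)
latent = CogVideoX.encode_video(encoder, enc_params, video)
# => [1, 13, 16, 60, 90]

recon = CogVideoX.decode_latent(decoder, dec_params, latent)
# => [1, 49, 3, 480, 720]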

param_count(opts \\ [])

@spec param_count(keyword()) :: non_neg_integer()

Approximate parameter count for the CogVideoX transformer.
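
A back-of-envelope check under the standard transformer approximation of roughly 4·d² attention parameters plus 2·mlp_ratio·d² FFN parameters per layer (an assumption about the estimate, not necessarily the exact formula param_count/1 uses):

d = 1920
layers = 42

per_layer = 4 * d * d + 2 * 4 * d * d  # ~12 * d^2
total = layers * per_layer             # 1_857_945_600, i.e. the ~2B scale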

rope3d_freqs(time_dim, height_dim, width_dim, opts \\ [])

@spec rope3d_freqs(pos_integer(), pos_integer(), pos_integer(), keyword()) ::
  {Nx.Tensor.t(), Nx.Tensor.t()}

Compute 3D RoPE frequencies for position encoding.

Parameters

  • time_dim - Temporal dimension size
  • height_dim - Height dimension size
  • width_dim - Width dimension size
  • opts - Options including :hidden_size and :num_heads

Returns

Tuple of {sin_freqs, cos_freqs} tensors for RoPE application.
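
A sketch of how such per-axis frequency tables can be produced (the channel split across time/height/width shown here is an assumption for illustration; the actual allocation is implementation-defined):

head_dim = div(1920, 48)  # 40 channels per head with the defaults

axis_freqs = fn dim, len ->
  # Standard 1D rotary table: inv_freq_i = 10000^(-i / (dim / 2))
  inv = Nx.pow(10_000.0, Nx.divide(Nx.iota({div(dim, 2)}), -div(dim, 2)))
  angles = Nx.outer(Nx.iota({len}), inv)
  {Nx.sin(angles), Nx.cos(angles)}
end

# e.g. a quarter of the head dim for time, the rest split over h and w
{sin_t, cos_t} = axis_freqs.(div(head_dim, 4), 13)
{sin_h, cos_h} = axis_freqs.(div(head_dim * 3, 8), 60)
{sin_w, cos_w} = axis_freqs.(div(head_dim * 3, 8), 90)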