Edifice.Generative.TRELLIS (Edifice v0.2.0)

TRELLIS: Structured 3D Latents for Scalable 3D Generation.

Implements the TRELLIS architecture from "TRELLIS: Structured 3D Latents for Scalable and Versatile 3D Generation" (Xiang et al., Microsoft Research 2024). A unified framework for high-quality 3D asset generation using sparse structured latent representations and rectified flow.

Key Innovations

1. Sparse Structured Latents (SLAT)

Represents 3D content as a sparse voxel grid where only occupied voxels store features:

  • Sparse representation: Only the N occupied voxels are stored, instead of all R³ cells of a dense grid at resolution R
  • Per-voxel features: Position (x,y,z) + local feature vector
  • Memory efficient: Enables high-resolution 3D at tractable cost
Dense 64³ grid = 262,144 voxels (most empty)
Sparse SLAT    = ~5,000 occupied voxels (typical)
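The gap can be checked with back-of-the-envelope arithmetic. This is a minimal sketch in plain Elixir, assuming 32 features per voxel (the `:feature_dim` default) and that each sparse entry also stores its three integer coordinates. The module name `SlatStorage` is hypothetical, and the 5,000-voxel count is the typical figure quoted above, not a measurement.

```elixir
defmodule SlatStorage do
  # Values held by a dense R x R x R grid with `feature_dim` channels per cell.
  def dense_values(resolution, feature_dim),
    do: resolution * resolution * resolution * feature_dim

  # Values held by a sparse list: a feature vector plus an (x, y, z) position
  # for each occupied voxel.
  def sparse_values(num_occupied, feature_dim), do: num_occupied * (feature_dim + 3)
end

dense = SlatStorage.dense_values(64, 32)
sparse = SlatStorage.sparse_values(5_000, 32)
IO.puts("dense=#{dense} sparse=#{sparse} ratio=#{Float.round(dense / sparse, 1)}")
```

Under these assumptions the dense grid holds 8,388,608 values and the sparse list 175,000, roughly a 48x reduction.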

2. Sparse Transformer

Attention mechanism designed for sparse 3D data:

  • 3D windowed attention: Local attention within spatial windows
  • Sparse convolutions: Feature propagation between nearby voxels
  • Sparse cross-attention: Voxels attend to text/image conditioning
Query voxel at (x,y,z) attends to:
- All voxels within window of size W centered at (x,y,z)
- Conditioning tokens (text or image features)
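Both windowed attention and sparse convolution need to find the occupied voxels near a given coordinate. A common approach is a spatial hash: a map from `{x, y, z}` to the voxel's row in the sparse arrays. The sketch below is an assumption for illustration, not Edifice's implementation. It gathers the k×k×k neighbourhood a sparse convolution would read, using k³ map lookups per query instead of scanning all N voxels. `VoxelHash` and its functions are hypothetical names.

```elixir
defmodule VoxelHash do
  # Map each occupied coordinate {x, y, z} to its row index in the sparse arrays.
  def build(positions), do: positions |> Enum.with_index() |> Map.new()

  # Row indices of occupied voxels in the k x k x k neighbourhood of `pos`
  # (k odd), i.e. the voxels a sparse convolution would gather. Empty cells
  # fail the lookup and are skipped by the {:ok, idx} generator pattern.
  def neighbors(hash, {x, y, z}, k \\ 3) do
    r = div(k, 2)

    for dx <- -r..r,
        dy <- -r..r,
        dz <- -r..r,
        {:ok, idx} <- [Map.fetch(hash, {x + dx, y + dy, z + dz})],
        do: idx
  end
end

hash = VoxelHash.build([{5, 5, 5}, {6, 5, 5}, {5, 5, 7}, {4, 4, 4}])
VoxelHash.neighbors(hash, {5, 5, 5})
```

Here the call returns rows 3, 0 and 1. `{5, 5, 7}` lies two cells away along z, so a 3×3×3 kernel does not reach it.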

3. Rectified Flow

Simpler, faster alternative to DDPM diffusion:

  • Straight-line paths: x_t = (1-t)·x_1 + t·x_0 (linear interpolation, with x_0 noise and x_1 data; sampling runs from t = 1 to t = 0)
  • Velocity prediction: Model predicts v = x_1 - x_0
  • Few-step sampling: Only 10-20 steps (vs 1000 for DDPM)
DDPM: Complex curved trajectories, 1000 steps
Rectified Flow: Straight lines, 10-20 steps
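The "straight line" property can be checked on a toy vector. This sketch writes the path as x_t = (1-t)·x_1 + t·x_0 (t = 1 is noise, t = 0 is data), so that the Euler update x_{t-dt} = x_t + dt·v given for `rectified_flow_step/6` moves toward the data. Lists stand in for tensors, and `FlowPath` is a hypothetical module. With the exact velocity, a single Euler step of size t returns from any point on the path to the data. A learned velocity field is only approximately straight, which is the usual motivation for the 10-20 steps rather than one.

```elixir
defmodule FlowPath do
  # Straight path from data x1 (t = 0) to noise x0 (t = 1):
  # x_t = (1 - t) * x1 + t * x0.
  def interpolate(x0, x1, t) do
    Enum.zip(x0, x1) |> Enum.map(fn {n, d} -> (1 - t) * d + t * n end)
  end

  # Velocity regression target, pointing from noise to data: v = x1 - x0.
  def target_velocity(x0, x1) do
    Enum.zip(x0, x1) |> Enum.map(fn {n, d} -> d - n end)
  end

  # Euler update x_{t-dt} = x_t + dt * v.
  def step(x_t, v, dt) do
    Enum.zip(x_t, v) |> Enum.map(fn {x, vi} -> x + dt * vi end)
  end
end

noise = [0.5, -1.0, 2.0]
data = [1.0, 0.0, -1.0]
x_t = FlowPath.interpolate(noise, data, 0.7)
# With the exact velocity, one step of size t = 0.7 lands back on the data
# (up to floating-point rounding).
FlowPath.step(x_t, FlowPath.target_velocity(noise, data), 0.7)
```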

Architecture

Input: Text/Image conditioning
       |
       v
+---------------------------+
| Condition Encoder         |  (CLIP or similar)
+---------------------------+
       |
       v
+---------------------------+
| Sparse Transformer        |  × num_layers
|   Sparse Self-Attention   |  (3D windowed)
|   Sparse Cross-Attention  |  (to conditioning)
|   Sparse FFN              |
+---------------------------+
       |
       v
+---------------------------+
| Rectified Flow Denoising  |  (10-20 steps)
+---------------------------+
       |
       v
+---------------------------+
| Decode SLAT → 3D Output   |  (Gaussian splats, mesh, or NeRF)
+---------------------------+
       |
       v
Output: 3D asset (splats/mesh/radiance field)

Usage

alias Edifice.Generative.TRELLIS

# Build TRELLIS model
model = TRELLIS.build(
  voxel_resolution: 64,
  feature_dim: 32,
  num_layers: 12,
  num_heads: 8
)

# Sparse attention over occupied voxels
attended = TRELLIS.sparse_attention(
  sparse_features,
  positions,
  window_size: 8
)

# Single rectified flow step (params: initialized or trained model parameters)
x_t_minus_dt = TRELLIS.rectified_flow_step(
  model, params, x_t, t, conditioning,
  num_steps: 20
)

# Full generation
output = TRELLIS.generate(
  model, params, conditioning,
  num_steps: 20
)

Supported Output Formats

  • 3D Gaussian Splatting: Fast, high-quality rendering
  • Mesh extraction: Via marching cubes from density field
  • Radiance field: NeRF-style volumetric representation

References

  • Paper: "TRELLIS: Structured 3D Latents for Scalable and Versatile 3D Generation"
  • Authors: Xiang et al., Microsoft Research
  • Year: 2024
  • Project: https://trellis3d.github.io/

Summary

Types

Options for build/1.

Functions

Build the TRELLIS model for 3D generation.

Decode Sparse Structured Latent back to dense representation or output format.

Encode a dense voxel grid to Sparse Structured Latent (SLAT) representation.

Generate 3D content using rectified flow sampling.

Get the output feature dimension.

Approximate parameter count for TRELLIS model.

Get recommended defaults for TRELLIS.

Perform one rectified flow denoising step.

Compute sparse windowed 3D attention over occupied voxels.

Types

build_opt()

@type build_opt() ::
  {:condition_dim, pos_integer()}
  | {:feature_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:max_voxels, pos_integer()}
  | {:mlp_ratio, float()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:voxel_resolution, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build the TRELLIS model for 3D generation.

Options

  • :voxel_resolution - Resolution of voxel grid (default: 64)
  • :feature_dim - Per-voxel feature dimension (default: 32)
  • :hidden_size - Transformer hidden dimension (default: 512)
  • :num_layers - Number of sparse transformer layers (default: 12)
  • :num_heads - Number of attention heads (default: 8)
  • :window_size - Size of local attention window (default: 8)
  • :condition_dim - Conditioning vector dimension (default: 768)
  • :mlp_ratio - MLP expansion ratio (default: 4.0)
  • :max_voxels - Maximum number of occupied voxels (default: 8192)

Returns

An Axon model that takes sparse voxel features + conditioning and outputs denoised sparse features.
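The options above can be resolved against their defaults with `Keyword.validate!/2`, which is available in Elixir 1.13 and later. This is a sketch of that pattern, not the library's own option handling. The module `TrellisOpts`, the derived `:head_dim` and `:mlp_hidden` keys, and the divisibility check are illustrative assumptions based on standard multi-head attention.

```elixir
defmodule TrellisOpts do
  # Defaults copied from the option list above.
  @defaults [
    voxel_resolution: 64,
    feature_dim: 32,
    hidden_size: 512,
    num_layers: 12,
    num_heads: 8,
    window_size: 8,
    condition_dim: 768,
    mlp_ratio: 4.0,
    max_voxels: 8192
  ]

  # Fills in defaults and raises ArgumentError on unknown keys, then adds two
  # hypothetical derived sizes: per-head width and MLP hidden width.
  def resolve(opts) do
    opts = Keyword.validate!(opts, @defaults)
    hidden = Keyword.fetch!(opts, :hidden_size)
    heads = Keyword.fetch!(opts, :num_heads)

    if rem(hidden, heads) != 0 do
      raise ArgumentError, "hidden_size #{hidden} is not divisible by num_heads #{heads}"
    end

    opts
    |> Keyword.put(:head_dim, div(hidden, heads))
    |> Keyword.put(:mlp_hidden, round(hidden * Keyword.fetch!(opts, :mlp_ratio)))
  end
end

TrellisOpts.resolve(num_layers: 4)
```

With the defaults, each of the 8 heads is 64 wide and the MLP hidden layer is 2048 wide. A misspelled option fails fast instead of being silently ignored.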

decode_from_slat(sparse_latent, opts \\ [])

@spec decode_from_slat(
  map(),
  keyword()
) :: Nx.Tensor.t() | map()

Decode Sparse Structured Latent back to dense representation or output format.

Parameters

  • sparse_latent - Map from encode_to_slat/2 or model output
  • opts - Options including:
    • :output_format - :dense, :gaussian_splats, or :mesh (default: :dense)
    • :resolution - Output resolution for dense format (default: 64)

Returns

Decoded 3D representation in requested format.
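For the `:dense` format, decoding amounts to scattering each valid voxel's features back to its grid cell. This is a minimal sketch of that idea under stated assumptions: the batch dimension is dropped, plain lists replace Nx tensors, and the grid is flattened in x-major order. `SlatDecodeSketch` is a hypothetical module, not the library's decoder. Gaussian-splat and mesh decoding are not shown.

```elixir
defmodule SlatDecodeSketch do
  # Scatters sparse per-voxel features into a dense grid flattened in x-major
  # order (index = x * r * r + y * r + z). Rows whose mask is false (padding)
  # are skipped; cells with no voxel get a zero vector.
  def to_dense(features, positions, mask, resolution) do
    feature_dim = features |> hd() |> length()
    zero = List.duplicate(0.0, feature_dim)

    filled =
      [features, positions, mask]
      |> Enum.zip()
      |> Enum.filter(fn {_f, _p, valid} -> valid end)
      |> Map.new(fn {f, {x, y, z}, _valid} ->
        {x * resolution * resolution + y * resolution + z, f}
      end)

    for idx <- 0..(resolution * resolution * resolution - 1) do
      Map.get(filled, idx, zero)
    end
  end
end

SlatDecodeSketch.to_dense(
  [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]],
  [{0, 0, 0}, {1, 1, 1}, {0, 1, 0}],
  [true, true, false],
  2
)
```

In this example the third row is padding, so cell `{0, 1, 0}` stays zero even though a position is listed for it.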

encode_to_slat(voxel_grid, opts \\ [])

@spec encode_to_slat(
  Nx.Tensor.t(),
  keyword()
) :: map()

Encode a dense voxel grid to Sparse Structured Latent (SLAT) representation.

Parameters

  • voxel_grid - Dense voxel grid [batch, res, res, res, features]
               or occupancy grid [batch, res, res, res]
  • opts - Options including :threshold for occupancy detection

Returns

A map with:

  • :features - Sparse features [batch, num_occupied, feature_dim]
  • :positions - Voxel positions [batch, num_occupied, 3]
  • :mask - Occupancy mask [batch, num_occupied]
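For an occupancy grid, the positions and mask can be derived by thresholding the grid and padding to a fixed length, so every batch item has the same shape. This sketch assumes the grid is nested lists indexed `grid[x][y][z]`, that the batch dimension is dropped, and that voxels beyond `max_voxels` are simply truncated. The module `SlatEncodeSketch` is hypothetical, and feature extraction is omitted.

```elixir
defmodule SlatEncodeSketch do
  # Collects the (x, y, z) coordinates of cells whose occupancy exceeds
  # `threshold`, keeps at most `max_voxels` of them, and pads the rest with
  # dummy positions. The mask marks which rows are real voxels.
  def occupied(grid, threshold, max_voxels) do
    coords =
      for {plane, x} <- Enum.with_index(grid),
          {row, y} <- Enum.with_index(plane),
          {value, z} <- Enum.with_index(row),
          value > threshold,
          do: {x, y, z}

    kept = Enum.take(coords, max_voxels)
    pad = max_voxels - length(kept)

    %{
      positions: kept ++ List.duplicate({0, 0, 0}, pad),
      mask: List.duplicate(true, length(kept)) ++ List.duplicate(false, pad)
    }
  end
end

grid = [[[0.0, 0.9], [0.0, 0.0]], [[0.6, 0.0], [0.0, 0.2]]]
SlatEncodeSketch.occupied(grid, 0.5, 4)
```

Only the cells at `{0, 0, 1}` and `{1, 0, 0}` pass the 0.5 threshold. The two padding rows carry a `false` mask, so downstream attention and decoding can ignore them.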

generate(model, params, conditioning, opts \\ [])

@spec generate(Axon.t(), map(), Nx.Tensor.t(), keyword()) :: map()

Generate 3D content using rectified flow sampling.

Parameters

  • model - TRELLIS model
  • params - Model parameters
  • conditioning - Conditioning tensor [batch, cond_len, cond_dim]
  • opts - Options:
    • :num_steps - Number of denoising steps (default: 20)
    • :max_voxels - Maximum voxels in output (default: 8192)
    • :feature_dim - Feature dimension (default: 32)
    • :voxel_resolution - Resolution for position initialization (default: 64)

Returns

Generated sparse latent map ready for decoding.
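The sampling loop behind this kind of generation is Euler integration from t = 1 (noise) down to t = 0 (data). In this sketch, `velocity_fn` is a hypothetical stand-in for the trained model, plain lists replace the sparse latent map, and `FlowSampler` is not part of Edifice. For the toy check, the velocity is an oracle that knows the target. For a straight path ending at data d, the velocity at point x and time t is (d - x)/t, which equals x_1 - x_0 along that path.

```elixir
defmodule FlowSampler do
  # Euler integration from t = 1 (noise) down to t = 0 (data).
  # `velocity_fn.(x, t)` must return v = x_1 - x_0 so that each update is
  # x_{t-dt} = x_t + dt * v. The loop visits t = 1, 1 - dt, ..., dt.
  def sample(x_noise, velocity_fn, num_steps \\ 20) do
    dt = 1.0 / num_steps

    Enum.reduce(0..(num_steps - 1), x_noise, fn i, x ->
      t = 1.0 - i * dt
      v = velocity_fn.(x, t)
      Enum.zip(x, v) |> Enum.map(fn {xi, vi} -> xi + dt * vi end)
    end)
  end
end

data = [1.0, -2.0]
oracle = fn x, t -> Enum.zip(data, x) |> Enum.map(fn {d, xi} -> (d - xi) / t end) end
FlowSampler.sample([0.3, 0.8], oracle, 20)
```

With the oracle velocity, every step stays on the straight line, so the result matches `data` up to floating-point rounding. A trained model only approximates this field.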

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output feature dimension.

param_count(opts \\ [])

@spec param_count(keyword()) :: non_neg_integer()

Approximate parameter count for TRELLIS model.
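A rough estimate of this kind can be written down from the block layout. The sketch below assumes each layer has self-attention (4·d² weights), cross-attention with queries and output in the hidden size d and keys/values projected from the conditioning size c (2·d² + 2·c·d), and a two-layer MLP (2·d·r·d). It also adds input and output projections between `feature_dim` and `hidden_size`, and ignores biases and norms. This layout is an assumption, so `TrellisParamEstimate` (a hypothetical module) need not agree with the library's `param_count/1`.

```elixir
defmodule TrellisParamEstimate do
  # Rough weight count for an assumed block layout; biases and
  # normalization parameters are ignored.
  def estimate(opts \\ []) do
    d = Keyword.get(opts, :hidden_size, 512)
    c = Keyword.get(opts, :condition_dim, 768)
    f = Keyword.get(opts, :feature_dim, 32)
    layers = Keyword.get(opts, :num_layers, 12)
    mlp = round(d * Keyword.get(opts, :mlp_ratio, 4.0))

    self_attn = 4 * d * d            # Q, K, V, output projections
    cross_attn = 2 * d * d + 2 * c * d # Q, output from d; K, V from conditioning
    ffn = 2 * d * mlp                # up and down projections
    io = 2 * f * d                   # feature_dim <-> hidden_size

    layers * (self_attn + cross_attn + ffn) + io
  end
end

TrellisParamEstimate.estimate()
```

Under these assumptions the defaults give 4,456,448 weights per layer and about 53.5 million in total.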

rectified_flow_step(model, params, x_t, t, conditioning, opts \\ [])

@spec rectified_flow_step(
  Axon.t(),
  map(),
  map(),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  keyword()
) :: map()

Perform one rectified flow denoising step.

Rectified flow uses straight-line interpolation between noise and data:

  • Forward: x_t = (1-t)·x_1 + t·x_0 (where x_0 is noise, x_1 is data; t = 1 is pure noise, t = 0 is clean data)
  • Model predicts velocity: v = x_1 - x_0
  • Update: x_{t-dt} = x_t + dt·v (moving t toward 0)

Parameters

  • model - TRELLIS model
  • params - Model parameters
  • x_t - Current noisy sparse latent (map with :features, :positions, :mask)
  • t - Current timestep [batch] in [0, 1]
  • conditioning - Conditioning tensor [batch, cond_len, cond_dim]
  • opts - Options:
    • :dt - Step size (default: computed from num_steps)
    • :num_steps - Total steps for dt calculation (default: 20)

Returns

Denoised sparse latent at t - dt.
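The update applied to a sparse latent map can be sketched as follows. This assumes that positions stay fixed while only features are denoised, that padding rows (mask `false`) are left untouched, and that `:dt` defaults to 1 / `:num_steps`. The velocity is passed in directly rather than computed by the model, and `SlatFlowStep` is a hypothetical module.

```elixir
defmodule SlatFlowStep do
  # One Euler update x_{t-dt} = x_t + dt * v applied to the :features of a
  # sparse latent map. Positions are passed through unchanged and masked-out
  # (padding) rows keep their current values.
  def step(%{features: feats, mask: mask} = latent, velocity, opts \\ []) do
    num_steps = Keyword.get(opts, :num_steps, 20)
    dt = Keyword.get(opts, :dt, 1.0 / num_steps)

    new_feats =
      [feats, velocity, mask]
      |> Enum.zip()
      |> Enum.map(fn
        {f, v, true} -> Enum.zip(f, v) |> Enum.map(fn {fi, vi} -> fi + dt * vi end)
        {f, _v, false} -> f
      end)

    %{latent | features: new_feats}
  end
end

latent = %{features: [[1.0, 2.0], [5.0, 5.0]], positions: [{0, 0, 0}, {0, 0, 0}], mask: [true, false]}
SlatFlowStep.step(latent, [[10.0, -10.0], [1.0, 1.0]], dt: 0.5)
```

Keeping the positions fixed reflects the idea that occupancy is decided before feature denoising. If positions were also generated, they would need their own update rule.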

sparse_attention(sparse_features, positions, opts \\ [])

@spec sparse_attention(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()

Compute sparse windowed 3D attention over occupied voxels.

Parameters

  • sparse_features - Voxel features [batch, num_voxels, feature_dim]
  • positions - Voxel positions [batch, num_voxels, 3]
  • opts - Options:
    • :window_size - Attention window size (default: 8)
    • :num_heads - Number of attention heads (default: 8)
    • :mask - Occupancy mask [batch, num_voxels] (optional)

Returns

Attended features [batch, num_voxels, feature_dim]
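What "windowed" and "masked" mean for a single query can be shown with a single-head scaled dot-product attention over plain lists. This is an illustrative sketch, not the library's implementation. It omits learned Q/K/V projections and uses the features directly as queries, keys and values. It also assumes a voxel is in the window when every coordinate differs from the query's by at most `div(window_size, 2)`, and that the query itself is a valid voxel. `SparseWindowAttention` is a hypothetical module.

```elixir
defmodule SparseWindowAttention do
  # Attention output for voxel `query_idx`: softmax over dot products with all
  # valid (mask = true) voxels inside a cubic window, scaled by sqrt(dim).
  def attend(features, positions, mask, query_idx, window_size) do
    q = Enum.at(features, query_idx)
    {qx, qy, qz} = Enum.at(positions, query_idx)
    half = div(window_size, 2)
    scale = :math.sqrt(length(q))

    neighbors =
      [features, positions, mask]
      |> Enum.zip()
      |> Enum.filter(fn {_f, {x, y, z}, valid} ->
        valid and abs(x - qx) <= half and abs(y - qy) <= half and abs(z - qz) <= half
      end)

    # Numerically stable softmax: subtract the max score before exp.
    scores = Enum.map(neighbors, fn {k, _p, _m} -> dot(q, k) / scale end)
    max_s = Enum.max(scores)
    weights = Enum.map(scores, &:math.exp(&1 - max_s))
    total = Enum.sum(weights)

    # Weighted sum of the neighbour values, column by column.
    neighbors
    |> Enum.zip(weights)
    |> Enum.map(fn {{v, _p, _m}, w} -> Enum.map(v, &(&1 * w / total)) end)
    |> Enum.zip()
    |> Enum.map(fn col -> col |> Tuple.to_list() |> Enum.sum() end)
  end

  defp dot(a, b), do: Enum.zip(a, b) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()
end

SparseWindowAttention.attend(
  [[1.0, 0.0], [0.0, 1.0], [0.0, 3.0]],
  [{0, 0, 0}, {1, 0, 0}, {2, 0, 0}],
  [true, true, false],
  0,
  8
)
```

The third voxel lies inside the window but is masked out, so the output blends only the first two value vectors and its components sum to 1.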