TRELLIS: Structured 3D Latents for Scalable 3D Generation.
Implements the TRELLIS architecture from "TRELLIS: Structured 3D Latents for Scalable and Versatile 3D Generation" (Xiang et al., Microsoft Research 2024). A unified framework for high-quality 3D asset generation using sparse structured latent representations and rectified flow.
Key Innovations
1. Sparse Structured Latents (SLAT)
Represents 3D content as a sparse voxel grid where only occupied voxels store features:
- Sparse representation: Only the L occupied voxels are stored, with L ≪ N³ for an N³ dense grid
- Per-voxel features: Position (x,y,z) + local feature vector
- Memory efficient: Enables high-resolution 3D at tractable cost
Dense 64³ grid = 262,144 voxels (most empty)
Sparse SLAT = ~5,000 occupied voxels (typical)
2. Sparse Transformer
Attention mechanism designed for sparse 3D data:
- 3D windowed attention: Local attention within spatial windows
- Sparse convolutions: Feature propagation between nearby voxels
- Sparse cross-attention: Voxels attend to text/image conditioning
Query voxel at (x,y,z) attends to:
- All voxels within window of size W centered at (x,y,z)
- Conditioning tokens (text or image features)
3. Rectified Flow
Simpler, faster alternative to DDPM diffusion:
- Straight-line paths: x_t = t·x_1 + (1-t)·x_0 (linear interpolation)
- Velocity prediction: Model predicts v = x_1 - x_0
- Few-step sampling: Only 10-20 steps (vs 1000 for DDPM)
DDPM: Complex curved trajectories, 1000 steps
Rectified Flow: Straight lines, 10-20 steps
Architecture
Input: Text/Image conditioning
|
v
+---------------------------+
| Condition Encoder         | (CLIP or similar)
+---------------------------+
|
v
+---------------------------+
| Sparse Transformer        | × num_layers
| • Sparse Self-Attention   | (3D windowed)
| • Sparse Cross-Attention  | (to conditioning)
| • Sparse FFN              |
+---------------------------+
|
v
+---------------------------+
| Rectified Flow Denoising  | (10-20 steps)
+---------------------------+
|
v
+---------------------------+
| Decode SLAT → 3D Output   | (Gaussian splats, mesh, or NeRF)
+---------------------------+
|
v
Output: 3D asset (splats/mesh/radiance field)
Usage
# Build TRELLIS model
model = TRELLIS.build(
voxel_resolution: 64,
feature_dim: 32,
num_layers: 12,
num_heads: 8
)
# Sparse attention over occupied voxels
attended = TRELLIS.sparse_attention(
sparse_features,
positions,
window_size: 8
)
# Single rectified flow step (x_t is a sparse latent map, t a [batch] timestep tensor)
x_t_minus_1 = TRELLIS.rectified_flow_step(
model, params, x_t, t, conditioning,
num_steps: 20
)
# Full generation
output = TRELLIS.generate(
model, params, conditioning,
num_steps: 20
)
Supported Output Formats
- 3D Gaussian Splatting: Fast, high-quality rendering
- Mesh extraction: Via marching cubes from density field
- Radiance field: NeRF-style volumetric representation
References
- Paper: "TRELLIS: Structured 3D Latents for Scalable and Versatile 3D Generation"
- Authors: Xiang et al., Microsoft Research
- Year: 2024
- Project: https://trellis3d.github.io/
Summary
Functions
- build/1: Build the TRELLIS model for 3D generation.
- decode_from_slat/2: Decode a Sparse Structured Latent back to a dense representation or output format.
- encode_to_slat/2: Encode a dense voxel grid to Sparse Structured Latent (SLAT) representation.
- generate/4: Generate 3D content using rectified flow sampling.
- output_size/1: Get the output feature dimension.
- param_count/1: Approximate parameter count for TRELLIS model.
- recommended_defaults/0: Get recommended defaults for TRELLIS.
- rectified_flow_step/6: Perform one rectified flow denoising step.
- sparse_attention/3: Compute sparse windowed 3D attention over occupied voxels.
Types
@type build_opt() :: {:condition_dim, pos_integer()} | {:feature_dim, pos_integer()} | {:hidden_size, pos_integer()} | {:max_voxels, pos_integer()} | {:mlp_ratio, float()} | {:num_heads, pos_integer()} | {:num_layers, pos_integer()} | {:voxel_resolution, pos_integer()} | {:window_size, pos_integer()}
Options for build/1.
Functions
Build the TRELLIS model for 3D generation.
Options
- :voxel_resolution - Resolution of the voxel grid (default: 64)
- :feature_dim - Per-voxel feature dimension (default: 32)
- :hidden_size - Transformer hidden dimension (default: 512)
- :num_layers - Number of sparse transformer layers (default: 12)
- :num_heads - Number of attention heads (default: 8)
- :window_size - Size of the local attention window (default: 8)
- :condition_dim - Conditioning vector dimension (default: 768)
- :mlp_ratio - MLP expansion ratio (default: 4.0)
- :max_voxels - Maximum number of occupied voxels (default: 8192)
Returns
An Axon model that takes sparse voxel features + conditioning and outputs denoised sparse features.
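The option list above can be exercised with a small helper that applies these defaults and rejects unknown keys. This is a minimal pure-Elixir sketch that assumes Elixir 1.13+ for Keyword.validate!/2. The module name TrellisOptsSketch and the hidden_size/num_heads divisibility check are illustrative assumptions, not part of TRELLIS.build/1.

```elixir
defmodule TrellisOptsSketch do
  # Hypothetical helper: the defaults mirror the build/1 option list, but the
  # validation logic is an illustration, not TRELLIS.build/1's actual code.
  @defaults [
    voxel_resolution: 64,
    feature_dim: 32,
    hidden_size: 512,
    num_layers: 12,
    num_heads: 8,
    window_size: 8,
    condition_dim: 768,
    mlp_ratio: 4.0,
    max_voxels: 8192
  ]

  @doc "Applies defaults and rejects unknown option keys."
  def resolve(opts \\ []) do
    # Keyword.validate!/2 raises ArgumentError for keys not listed in @defaults
    opts = Keyword.validate!(opts, @defaults)

    # Assumption: heads split hidden_size evenly, as in standard multi-head attention
    if rem(opts[:hidden_size], opts[:num_heads]) != 0 do
      raise ArgumentError, "hidden_size must be divisible by num_heads"
    end

    opts
  end
end
```

With this sketch, resolve(num_layers: 24) keeps every other default. A typo such as :num_layer raises an error instead of being silently ignored.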
@spec decode_from_slat(map(), keyword()) :: Nx.Tensor.t() | map()
Decode Sparse Structured Latent back to dense representation or output format.
Parameters
- sparse_latent - Map from encode_to_slat/2 or model output
- opts - Options including:
  - :output_format - :dense, :gaussian_splats, or :mesh (default: :dense)
  - :resolution - Output resolution for dense format (default: 64)
Returns
Decoded 3D representation in requested format.
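Decoding to the :dense format is essentially the inverse of sparse encoding. Each stored feature vector is scattered back to its voxel coordinate, and every other cell is zero. The sketch below shows that scatter for a single, unbatched example using plain Elixir lists rather than Nx tensors. SlatDecodeSketch is a hypothetical name, and the padding/mask handling of the real [batch, num_occupied, ...] layout is deliberately left out.

```elixir
defmodule SlatDecodeSketch do
  @doc """
  Scatter sparse voxels back into a dense res x res x res grid indexed [x][y][z].
  `positions` is a list of {x, y, z} integer tuples and `features` a list of
  equal-length feature lists; unoccupied cells get a zero vector.
  """
  def to_dense(positions, features, res) do
    feature_dim = features |> List.first([]) |> length()
    zero = List.duplicate(0.0, feature_dim)
    # Map from voxel coordinate to its feature vector for constant-time lookup
    lookup = positions |> Enum.zip(features) |> Map.new()

    # The //1 step keeps the ranges empty (not descending) when res is 0
    for x <- 0..(res - 1)//1 do
      for y <- 0..(res - 1)//1 do
        for z <- 0..(res - 1)//1, do: Map.get(lookup, {x, y, z}, zero)
      end
    end
  end
end
```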
@spec encode_to_slat(Nx.Tensor.t(), keyword()) :: map()
Encode a dense voxel grid to Sparse Structured Latent (SLAT) representation.
Parameters
- voxel_grid - Dense voxel grid [batch, resolution, resolution, resolution, features] or occupancy grid [batch, res, res, res]
- opts - Options including :threshold for occupancy detection
Returns
A map with:
- :features - Sparse features [batch, num_occupied, feature_dim]
- :positions - Voxel positions [batch, num_occupied, 3]
- :mask - Occupancy mask [batch, num_occupied]
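The sparsity saving is easy to see concretely. A dense 64³ grid has 64 × 64 × 64 = 262,144 cells, but only the occupied ones need storing. The sketch below assumes a single unbatched occupancy grid given as nested lists indexed [x][y][z]. It extracts the occupied positions and pads them to a fixed length with a 0/1 mask, mirroring the :positions and :mask fields above. The module name, the 0.5 threshold default, and padding with {0, 0, 0} are illustrative assumptions, not the behavior of encode_to_slat/2.

```elixir
defmodule SlatEncodeSketch do
  @doc """
  Return the {x, y, z} coordinates of cells whose value exceeds `threshold`
  (0.5 here is an assumed default, not the module's).
  """
  def occupied_positions(grid, threshold \\ 0.5) do
    for {plane, x} <- Enum.with_index(grid),
        {row, y} <- Enum.with_index(plane),
        {value, z} <- Enum.with_index(row),
        value > threshold,
        do: {x, y, z}
  end

  @doc """
  Truncate or pad the position list to `max_voxels` entries and build a mask
  with 1 for real voxels and 0 for padding (padding slots use {0, 0, 0}).
  """
  def pad(positions, max_voxels) do
    kept = Enum.take(positions, max_voxels)
    pad_count = max_voxels - length(kept)
    padded = kept ++ List.duplicate({0, 0, 0}, pad_count)
    mask = List.duplicate(1, length(kept)) ++ List.duplicate(0, pad_count)
    {padded, mask}
  end
end
```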
@spec generate(Axon.t(), map(), Nx.Tensor.t(), keyword()) :: map()
Generate 3D content using rectified flow sampling.
Parameters
- model - TRELLIS model
- params - Model parameters
- conditioning - Conditioning tensor [batch, cond_len, cond_dim]
- opts - Options:
  - :num_steps - Number of denoising steps (default: 20)
  - :max_voxels - Maximum voxels in output (default: 8192)
  - :feature_dim - Feature dimension (default: 32)
  - :voxel_resolution - Resolution for position initialization (default: 64)
Returns
Generated sparse latent map ready for decoding.
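Generation integrates the learned velocity field from pure noise to data with a fixed number of Euler steps. This is a minimal sketch under an assumed convention: x_t = t·x_1 + (1-t)·x_0 with x_0 = data and x_1 = noise, sampled from t = 1 down to 0 with a uniform dt = 1/num_steps. The velocity function stands in for the trained model, and a flat list of floats stands in for the sparse feature tensor. RectifiedFlowSampler is a hypothetical name.

```elixir
defmodule RectifiedFlowSampler do
  @moduledoc """
  Euler sampler for rectified flow under an assumed convention:
  x_t = t * x_1 + (1 - t) * x_0 with x_0 = data and x_1 = noise, so the
  velocity is v = x_1 - x_0 and sampling integrates t from 1 down to 0.
  """

  @doc """
  `velocity_fn.(x, t)` stands in for the trained model; `x` is a list of floats.
  Uses a uniform step size dt = 1 / num_steps.
  """
  def sample(noise, velocity_fn, num_steps \\ 20) do
    dt = 1.0 / num_steps

    Enum.reduce(0..(num_steps - 1), noise, fn i, x ->
      t = 1.0 - i * dt
      v = velocity_fn.(x, t)
      # Euler step toward the data end of the path: x_{t - dt} = x_t - dt * v
      Enum.zip_with(x, v, fn xi, vi -> xi - dt * vi end)
    end)
  end
end
```

With an oracle velocity (x_t - data)/t, which equals the constant x_1 - x_0 along a straight path, the sketch recovers the data point to floating-point precision for any step count. That is the intuition behind needing only 10-20 steps: a learned, imperfect velocity field still accumulates some integration error, so it needs more than one step.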
@spec output_size(keyword()) :: pos_integer()
Get the output feature dimension.
@spec param_count(keyword()) :: non_neg_integer()
Approximate parameter count for TRELLIS model.
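A rough count can be derived from the build/1 defaults by summing the weight matrices in each layer, which gives a sense of scale. The sketch below is a back-of-envelope estimate under stated assumptions. It counts one self-attention, one cross-attention to a condition_dim-wide conditioning, and one MLP per layer, plus input/output projections between feature_dim and hidden_size. Biases, norms, timestep embeddings, and any sparse convolutions are ignored. It is not the formula param_count/1 uses, and TrellisParamEstimate is a hypothetical name.

```elixir
defmodule TrellisParamEstimate do
  @moduledoc """
  Back-of-envelope weight count (assumed architecture, not param_count/1's
  formula): per layer one self-attention, one cross-attention and one MLP;
  biases, norms, timestep embeddings and sparse convolutions are ignored.
  """

  def estimate(opts \\ []) do
    h = Keyword.get(opts, :hidden_size, 512)
    layers = Keyword.get(opts, :num_layers, 12)
    cond_dim = Keyword.get(opts, :condition_dim, 768)
    feat = Keyword.get(opts, :feature_dim, 32)
    mlp = round(Keyword.get(opts, :mlp_ratio, 4.0) * h)

    # Q, K, V and output projections, each h x h
    self_attn = 4 * h * h
    # Q and output from voxel tokens (h x h); K and V from conditioning (cond_dim x h)
    cross_attn = 2 * h * h + 2 * cond_dim * h
    # Up-projection h -> mlp and down-projection mlp -> h
    ffn = 2 * h * mlp
    # feature_dim -> hidden_size input and hidden_size -> feature_dim output
    io = 2 * feat * h

    layers * (self_attn + cross_attn + ffn) + io
  end
end
```

With the defaults this gives 12 × 4,456,448 + 32,768 = 53,510,144 weights, roughly 54M under these assumptions.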
@spec recommended_defaults() :: keyword()
Get recommended defaults for TRELLIS.
@spec rectified_flow_step(Axon.t(), map(), map(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: map()
Perform one rectified flow denoising step.
Rectified flow uses straight-line interpolation between data and noise:
- Forward: x_t = t·x_1 + (1-t)·x_0 (where x_0 is data, x_1 is noise)
- Model predicts velocity: v = dx_t/dt = x_1 - x_0
- Update (Euler step from t toward the data at t = 0): x_{t-dt} = x_t - dt·v
Parameters
- model - TRELLIS model
- params - Model parameters
- x_t - Current noisy sparse latent (map with :features, :positions, :mask)
- t - Current timestep [batch] in [0, 1]
- conditioning - Conditioning tensor [batch, cond_len, cond_dim]
- opts - Options:
  - :dt - Step size (default: computed from num_steps)
  - :num_steps - Total steps for dt calculation (default: 20)
Returns
Denoised sparse latent at t - dt.
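The training target and a single update can be written out directly. This sketch assumes x_0 is data and x_1 is noise. Under that convention, an exact velocity moves a point on the straight path from t to exactly the path point at t - dt. The plain float lists and the RectifiedFlowStep module name are illustrative, not the module's API.

```elixir
defmodule RectifiedFlowStep do
  # Assumed convention: x_0 = data, x_1 = noise, x_t = t * x_1 + (1 - t) * x_0.

  @doc "Point on the straight path between data x0 and noise x1 at time t."
  def interpolate(x0, x1, t) do
    Enum.zip_with(x0, x1, fn a, b -> t * b + (1 - t) * a end)
  end

  @doc "Velocity target dx_t/dt = x_1 - x_0, constant along the whole path."
  def velocity(x0, x1), do: Enum.zip_with(x0, x1, fn a, b -> b - a end)

  @doc "One Euler step from t to t - dt given a (predicted) velocity v."
  def step(x_t, v, dt), do: Enum.zip_with(x_t, v, fn x, vi -> x - dt * vi end)
end
```

During training, the model would be fit to velocity(x0, x1) at a random t. At sampling time, step/3 is applied with the model's prediction in place of the exact velocity.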
@spec sparse_attention(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Compute sparse windowed 3D attention over occupied voxels.
Parameters
- sparse_features - Voxel features [batch, num_voxels, feature_dim]
- positions - Voxel positions [batch, num_voxels, 3]
- opts - Options:
  - :window_size - Attention window size (default: 8)
  - :num_heads - Number of attention heads (default: 8)
  - :mask - Occupancy mask [batch, num_voxels] (optional)
Returns
Attended features [batch, num_voxels, feature_dim]
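A naive version of windowed sparse attention makes the neighborhood rule concrete. This single-head sketch has no learned Q/K/V projections and no batching. It treats the window as a cube extending div(window_size, 2) voxels from the query along each axis. The real module may instead partition space into fixed windows, so the window rule, the module name SparseWindowAttention, and the O(n²) neighbor scan are all assumptions made for clarity.

```elixir
defmodule SparseWindowAttention do
  @moduledoc """
  Single-head, projection-free sketch: each occupied voxel attends to every
  occupied voxel within Chebyshev distance div(window_size, 2) of it, using
  scaled dot-product scores. Queries, keys and values are the raw features.
  """

  def attend(features, positions, window_size \\ 8) do
    half = div(window_size, 2)
    voxels = Enum.zip(positions, features)

    for {{qx, qy, qz}, q} <- voxels do
      scale = 1.0 / :math.sqrt(length(q))

      # Neighbors inside the window; the query voxel itself is always included
      neighbors =
        Enum.filter(voxels, fn {{x, y, z}, _} ->
          max(abs(x - qx), max(abs(y - qy), abs(z - qz))) <= half
        end)

      scores = Enum.map(neighbors, fn {_, k} -> scale * dot(q, k) end)
      # Subtract the max score before exponentiating for numerical stability
      m = Enum.max(scores)
      weights = Enum.map(scores, &:math.exp(&1 - m))
      total = Enum.sum(weights)

      # Softmax-weighted sum of the neighbors' feature vectors
      weights
      |> Enum.zip(neighbors)
      |> Enum.reduce(List.duplicate(0.0, length(q)), fn {w, {_, v}}, acc ->
        Enum.zip_with(acc, v, fn a, vi -> a + w / total * vi end)
      end)
    end
  end

  defp dot(a, b), do: a |> Enum.zip_with(b, &(&1 * &2)) |> Enum.sum()
end
```

An isolated voxel only attends to itself and keeps its features unchanged. Voxels inside each other's window mix their features in proportion to their similarity.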