# `Edifice.Generative.TRELLIS`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/generative/trellis.ex#L1)

TRELLIS: Structured 3D Latents for Scalable and Versatile 3D Generation.

Implements the TRELLIS architecture from "TRELLIS: Structured 3D Latents for
Scalable and Versatile 3D Generation" (Xiang et al., Microsoft Research, 2024).
TRELLIS is a unified framework for high-quality 3D asset generation built on
sparse structured latent representations and rectified flow.

## Key Innovations

### 1. Sparse Structured Latents (SLAT)
Represents 3D content as a sparse voxel grid where only occupied voxels store features:
- **Sparse representation**: Only the occupied voxels are stored (versus all R³ cells of a dense grid at resolution R)
- **Per-voxel features**: Position (x,y,z) + local feature vector
- **Memory efficient**: Enables high-resolution 3D at tractable cost

```
Dense 64³ grid = 262,144 voxels (most empty)
Sparse SLAT    = ~5,000 occupied voxels (typical)
```
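To make the saving concrete, the sketch below counts how many scalar values each layout stores. It is a back-of-the-envelope illustration in plain Elixir, not library code: the `SlatCost` module is hypothetical, and it assumes a sparse voxel costs three coordinates plus one feature vector, using the default `feature_dim` of 32.

```elixir
defmodule SlatCost do
  # Scalars stored by a dense grid: one feature vector per cell.
  def dense_values(resolution, feature_dim) do
    resolution * resolution * resolution * feature_dim
  end

  # Scalars stored by a sparse layout (assumption): per occupied voxel,
  # three coordinates plus one feature vector.
  def sparse_values(num_occupied, feature_dim) do
    num_occupied * (3 + feature_dim)
  end
end

dense = SlatCost.dense_values(64, 32)
sparse = SlatCost.sparse_values(5_000, 32)

IO.puts("dense:  #{dense} values")    # 8388608
IO.puts("sparse: #{sparse} values")   # 175000
IO.puts("ratio:  #{Float.round(dense / sparse, 1)}x")
```

Under these assumptions the sparse layout stores roughly 48 times fewer values than the dense 64³ grid.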

### 2. Sparse Transformer
Attention mechanism designed for sparse 3D data:
- **3D windowed attention**: Local attention within spatial windows
- **Sparse convolutions**: Feature propagation between nearby voxels
- **Sparse cross-attention**: Voxels attend to text/image conditioning

```
Query voxel at (x,y,z) attends to:
- All occupied voxels within a window of size W centered at (x,y,z)
- Conditioning tokens (text or image features)
```
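The neighbour selection described above can be sketched in plain Elixir. This is only an illustration: the `WindowSketch` module is hypothetical, and it assumes "within a window of size W" means every coordinate differs from the query by at most `div(W, 2)`. The library may partition space into windows differently.

```elixir
defmodule WindowSketch do
  # Returns the indices of voxels inside a cube centred on the query voxel.
  # Assumption: the cube extends div(window_size, 2) cells along each axis.
  def neighbours(positions, {qx, qy, qz}, window_size) do
    half = div(window_size, 2)

    positions
    |> Enum.with_index()
    |> Enum.filter(fn {{x, y, z}, _index} ->
      abs(x - qx) <= half and abs(y - qy) <= half and abs(z - qz) <= half
    end)
    |> Enum.map(fn {_position, index} -> index end)
  end
end

positions = [{10, 10, 10}, {12, 9, 11}, {14, 10, 10}, {40, 40, 40}, {10, 15, 10}]

# Voxels 3 and 4 fall outside the window, so only 0, 1, and 2 are attended to.
IO.inspect(WindowSketch.neighbours(positions, {10, 10, 10}, 8))
# [0, 1, 2]
```

Restricting attention to a local window keeps the cost per query bounded by the window's occupancy rather than by the total number of voxels.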

### 3. Rectified Flow
Simpler, faster alternative to DDPM diffusion:
- **Straight-line paths**: x_t = t·x_1 + (1-t)·x_0 (linear interpolation between noise x_0 and data x_1)
- **Velocity prediction**: Model is trained to predict v = x_1 - x_0
- **Few-step sampling**: Only 10-20 steps (vs 1000 for DDPM)

```
DDPM: Complex curved trajectories, 1000 steps
Rectified Flow: Straight lines, 10-20 steps
```
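The interpolation and velocity formulas can be checked on a toy two-dimensional example. The sketch below uses plain lists rather than tensors, and the `FlowPath` module is hypothetical. It follows the convention stated above, with x_0 as noise and x_1 as data.

```elixir
defmodule FlowPath do
  # Point on the straight line from noise x0 (t = 0) to data x1 (t = 1).
  def interpolate(x0, x1, t) do
    Enum.zip_with(x0, x1, fn a, b -> t * b + (1 - t) * a end)
  end

  # Velocity target: the same vector at every t for a given (x0, x1) pair.
  def velocity(x0, x1) do
    Enum.zip_with(x0, x1, fn a, b -> b - a end)
  end
end

x0 = [0.0, 2.0]
x1 = [4.0, -2.0]

IO.inspect(FlowPath.interpolate(x0, x1, 0.25))  # [1.0, 1.0]
IO.inspect(FlowPath.velocity(x0, x1))           # [4.0, -4.0]
```

Because the velocity does not depend on t for a fixed pair, a coarse integrator can follow the path with few steps, which is the motivation for the small step counts quoted above.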

## Architecture

```
Input: Text/Image conditioning
       |
       v
+---------------------------+
| Condition Encoder         |  (CLIP or similar)
+---------------------------+
       |
       v
+---------------------------+
| Sparse Transformer        |  × num_layers
|  • Sparse Self-Attention  |  (3D windowed)
|  • Sparse Cross-Attention |  (to conditioning)
|  • Sparse FFN             |
+---------------------------+
       |
       v
+---------------------------+
| Rectified Flow Denoising  |  (10-20 steps)
+---------------------------+
       |
       v
+---------------------------+
| Decode SLAT → 3D Output   |  (Gaussian splats, mesh, or NeRF)
+---------------------------+
       |
       v
Output: 3D asset (splats/mesh/radiance field)
```

## Usage

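    alias Edifice.Generative.TRELLIS

    # Build TRELLIS model
    model = TRELLIS.build(
      voxel_resolution: 64,
      feature_dim: 32,
      num_layers: 12,
      num_heads: 8
    )

    # Sparse attention over occupied voxels
    # (sparse_features: [batch, num_voxels, feature_dim], positions: [batch, num_voxels, 3])
    attended = TRELLIS.sparse_attention(
      sparse_features,
      positions,
      window_size: 8
    )

    # Single rectified flow step; params are the initialized model parameters
    # and x_t is a map with :features, :positions, and :mask
    x_next = TRELLIS.rectified_flow_step(
      model, params, x_t, t, conditioning,
      num_steps: 20
    )

    # Full generation
    output = TRELLIS.generate(
      model, params, conditioning,
      num_steps: 20
    )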

## Supported Output Formats

- **3D Gaussian Splatting**: Fast, high-quality rendering
- **Mesh extraction**: Via marching cubes from density field
- **Radiance field**: NeRF-style volumetric representation

## References

- Paper: "TRELLIS: Structured 3D Latents for Scalable and Versatile 3D Generation"
- Authors: Xiang et al., Microsoft Research
- Year: 2024
- Project: https://trellis3d.github.io/

# `build_opt`

```elixir
@type build_opt() ::
  {:condition_dim, pos_integer()}
  | {:feature_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:max_voxels, pos_integer()}
  | {:mlp_ratio, float()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:voxel_resolution, pos_integer()}
  | {:window_size, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build the TRELLIS model for 3D generation.

## Options

  - `:voxel_resolution` - Resolution of voxel grid (default: 64)
  - `:feature_dim` - Per-voxel feature dimension (default: 32)
  - `:hidden_size` - Transformer hidden dimension (default: 512)
  - `:num_layers` - Number of sparse transformer layers (default: 12)
  - `:num_heads` - Number of attention heads (default: 8)
  - `:window_size` - Size of local attention window (default: 8)
  - `:condition_dim` - Conditioning vector dimension (default: 768)
  - `:mlp_ratio` - MLP expansion ratio (default: 4.0)
  - `:max_voxels` - Maximum number of occupied voxels (default: 8192)

## Returns

  An Axon model that takes sparse voxel features + conditioning and outputs
  denoised sparse features.

# `decode_from_slat`

```elixir
@spec decode_from_slat(
  map(),
  keyword()
) :: Nx.Tensor.t() | map()
```

Decode a Sparse Structured Latent (SLAT) into a dense voxel grid or another output format.

## Parameters

  - `sparse_latent` - Map from `encode_to_slat/2` or model output
  - `opts` - Options including:
    - `:output_format` - `:dense`, `:gaussian_splats`, or `:mesh` (default: `:dense`)
    - `:resolution` - Output resolution for dense format (default: 64)

## Returns

  The decoded 3D representation in the requested format: either a tensor or a
  map, depending on `:output_format`.

# `encode_to_slat`

```elixir
@spec encode_to_slat(
  Nx.Tensor.t(),
  keyword()
) :: map()
```

Encode a dense voxel grid to Sparse Structured Latent (SLAT) representation.

## Parameters

  - `voxel_grid` - Dense voxel grid [batch, resolution, resolution, resolution, features]
                   or occupancy grid [batch, res, res, res]
  - `opts` - Options including `:threshold` for occupancy detection

## Returns

  A map with:
  - `:features` - Sparse features [batch, num_occupied, feature_dim]
  - `:positions` - Voxel positions [batch, num_occupied, 3]
  - `:mask` - Occupancy mask [batch, num_occupied]
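As a rough illustration of what threshold-based encoding involves, the sketch below extracts occupied coordinates from a tiny occupancy grid held in nested lists. The `OccupancySketch` module is hypothetical and does not use Nx. The padding step is an assumption about why a mask accompanies the positions; the library's actual padding behaviour is not documented here.

```elixir
defmodule OccupancySketch do
  # grid is a nested list indexed as grid[x][y][z] holding occupancy values.
  # Returns the {x, y, z} coordinates whose value exceeds the threshold.
  def occupied_positions(grid, threshold) do
    for {plane, x} <- Enum.with_index(grid),
        {row, y} <- Enum.with_index(plane),
        {value, z} <- Enum.with_index(row),
        value > threshold,
        do: {x, y, z}
  end

  # Assumption: positions are padded to a fixed length, and the mask marks
  # real entries with 1 and padding with 0.
  def pad(positions, max_voxels) do
    kept = Enum.take(positions, max_voxels)
    padding = max_voxels - length(kept)

    %{
      positions: kept ++ List.duplicate({0, 0, 0}, padding),
      mask: List.duplicate(1, length(kept)) ++ List.duplicate(0, padding)
    }
  end
end

grid = [
  [[0.9, 0.1], [0.0, 0.7]],
  [[0.2, 0.0], [0.0, 0.6]]
]

occupied = OccupancySketch.occupied_positions(grid, 0.5)
IO.inspect(occupied)
# [{0, 0, 0}, {0, 1, 1}, {1, 1, 1}]

IO.inspect(OccupancySketch.pad(occupied, 4).mask)
# [1, 1, 1, 0]
```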

# `generate`

```elixir
@spec generate(Axon.t(), map(), Nx.Tensor.t(), keyword()) :: map()
```

Generate 3D content using rectified flow sampling.

## Parameters

  - `model` - TRELLIS model
  - `params` - Model parameters
  - `conditioning` - Conditioning tensor [batch, cond_len, cond_dim]
  - `opts` - Options:
    - `:num_steps` - Number of denoising steps (default: 20)
    - `:max_voxels` - Maximum voxels in output (default: 8192)
    - `:feature_dim` - Feature dimension (default: 32)
    - `:voxel_resolution` - Resolution for position initialization (default: 64)

## Returns

  Generated sparse latent map ready for decoding.

# `output_size`

```elixir
@spec output_size(keyword()) :: pos_integer()
```

Get the output feature dimension.

# `param_count`

```elixir
@spec param_count(keyword()) :: non_neg_integer()
```

Approximate parameter count for TRELLIS model.
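For a sense of scale, a generic transformer estimate can be computed by hand. The sketch below is not the formula `param_count/1` uses, which is not documented here. It is a hypothetical `ParamEstimate` module that assumes each layer has self-attention (four hidden × hidden projections), cross-attention (queries and output in hidden size, keys and values projected from `:condition_dim`), and a two-layer FFN, and it ignores biases, normalization, and embeddings.

```elixir
defmodule ParamEstimate do
  # Weight count for one layer under the assumptions stated above.
  def per_layer(hidden, cond_dim, mlp_ratio) do
    self_attn = 4 * hidden * hidden
    cross_attn = 2 * hidden * hidden + 2 * hidden * cond_dim
    ffn = 2 * hidden * round(mlp_ratio * hidden)
    self_attn + cross_attn + ffn
  end

  # Reads the documented build options, falling back to their documented defaults.
  def total(opts \\ []) do
    hidden = Keyword.get(opts, :hidden_size, 512)
    cond_dim = Keyword.get(opts, :condition_dim, 768)
    mlp_ratio = Keyword.get(opts, :mlp_ratio, 4.0)
    layers = Keyword.get(opts, :num_layers, 12)

    layers * per_layer(hidden, cond_dim, mlp_ratio)
  end
end

IO.inspect(ParamEstimate.total())
# 53477376
```

With the documented defaults this rough count lands near 53 million weights; the value returned by `param_count/1` may differ.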

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Returns the recommended default options for TRELLIS as a keyword list.

# `rectified_flow_step`

```elixir
@spec rectified_flow_step(
  Axon.t(),
  map(),
  map(),
  Nx.Tensor.t(),
  Nx.Tensor.t(),
  keyword()
) :: map()
```

Perform one rectified flow denoising step.

Rectified flow uses straight-line interpolation:
- Forward: x_t = t * x_1 + (1-t) * x_0 (where x_0 is noise, x_1 is data)
- Model predicts velocity: v = x_1 - x_0
- Update: x_{t+dt} = x_t + dt * v (an Euler step from noise at t = 0 toward data at t = 1)

## Parameters

  - `model` - TRELLIS model
  - `params` - Model parameters
  - `x_t` - Current noisy sparse latent (map with :features, :positions, :mask)
  - `t` - Current timestep [batch] in [0, 1]
  - `conditioning` - Conditioning tensor [batch, cond_len, cond_dim]
  - `opts` - Options:
    - `:dt` - Step size (default: computed from num_steps)
    - `:num_steps` - Total steps for dt calculation (default: 20)

## Returns

  The sparse latent after one Euler step of size `dt`, in the same map format as `x_t`.
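The update rule can be sketched on plain lists. This is a minimal illustration, not the library's implementation: the `EulerSketch` module is hypothetical, the velocity function stands in for the trained model, and it assumes integration runs forward from noise at t = 0 to data at t = 1 with `dt = 1 / num_steps`.

```elixir
defmodule EulerSketch do
  # One explicit Euler step: move x along velocity v for a time increment dt.
  def step(x, v, dt) do
    Enum.zip_with(x, v, fn xi, vi -> xi + dt * vi end)
  end

  # Integrates from t = 0 to t = 1 in num_steps equal steps (num_steps >= 1).
  # velocity_fn stands in for the model and maps (x, t) to a velocity.
  def sample(x0, velocity_fn, num_steps) do
    dt = 1.0 / num_steps

    Enum.reduce(0..(num_steps - 1), x0, fn i, x ->
      step(x, velocity_fn.(x, i * dt), dt)
    end)
  end
end

x0 = [0.0, 2.0]

# An oracle that always returns the true velocity x1 - x0 for x1 = [4.0, -2.0].
oracle = fn _x, _t -> [4.0, -4.0] end

# Ends at approximately [4.0, -2.0], up to floating-point rounding.
IO.inspect(EulerSketch.sample(x0, oracle, 20))
```

With an exact, constant velocity the path is a straight line, so the endpoint is reached regardless of the step count; a learned velocity field is only approximately straight, which is why several steps are still used in practice.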

# `sparse_attention`

```elixir
@spec sparse_attention(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
```

Compute sparse windowed 3D attention over occupied voxels.

## Parameters

  - `sparse_features` - Voxel features [batch, num_voxels, feature_dim]
  - `positions` - Voxel positions [batch, num_voxels, 3]
  - `opts` - Options:
    - `:window_size` - Attention window size (default: 8)
    - `:num_heads` - Number of attention heads (default: 8)
    - `:mask` - Occupancy mask [batch, num_voxels] (optional)

## Returns

  Attended features [batch, num_voxels, feature_dim]
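One common way an occupancy mask enters attention is by zeroing the weights of padded entries before normalization. The sketch below shows that idea on a single row of scores in plain Elixir. The `MaskedSoftmax` module is hypothetical, and whether `sparse_attention/3` applies `:mask` exactly this way is an assumption.

```elixir
defmodule MaskedSoftmax do
  # Softmax over scores in which entries with mask 0 receive zero weight.
  # Assumes at least one mask entry is 1.
  def weights(scores, mask) do
    kept = for {score, 1} <- Enum.zip(scores, mask), do: score
    peak = Enum.max(kept)

    exps =
      Enum.zip_with(scores, mask, fn score, m ->
        # Subtracting the peak keeps the exponentials numerically stable.
        if m == 1, do: :math.exp(score - peak), else: 0.0
      end)

    total = Enum.sum(exps)
    Enum.map(exps, fn e -> e / total end)
  end
end

# The third voxel is padding, so its large score is ignored.
IO.inspect(MaskedSoftmax.weights([1.0, 1.0, 5.0], [1, 1, 0]))
# [0.5, 0.5, 0.0]
```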

