Edifice.Generative.SiT (Edifice v0.2.0)

SiT: Scalable Interpolant Transformer.

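Implements SiT from "Scalable Interpolant Transformers" (Ma et al., 2024). SiT keeps the DiT backbone but trains it within the interpolant framework. The path between data and noise is fixed by a chosen schedule, and this module's network predicts the velocity along that path. DiT, by contrast, predicts noise under a fixed diffusion schedule.

Key Innovation: Flexible Interpolant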

Instead of a fixed diffusion schedule, SiT uses:

I(t) = α(t) · x + β(t) · ε

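where α and β are schedules satisfying α(0) = β(1) = 1 and α(1) = β(0) = 0. At t = 0 the interpolant is clean data, and at t = 1 it is pure noise. The default is linear: α(t) = 1 - t, β(t) = t. Given I(t) and t, the model predicts the interpolant velocity: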

v(t) = dI/dt = α'(t) · x + β'(t) · ε

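Both DDPM-style diffusion and flow matching fit this framework as particular choices of α, β and of the prediction target:

  • DDPM-style diffusion uses a variance-preserving schedule and is usually trained to predict noise, which is a rescaled score.
  • Flow matching uses the linear schedule with velocity prediction.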
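The two formulas above can be sketched for a single sample as follows. This is a dependency-free illustration using plain lists rather than Nx tensors. The module `InterpolantSketch` and its map-of-functions schedule representation are hypothetical and not part of Edifice. The sketch shows that I(t) and its target velocity are the same weighted sum of x and ε, with the weights swapped for their time derivatives.

```elixir
# Hypothetical helper module for illustration; not part of Edifice.
defmodule InterpolantSketch do
  # A schedule bundles α, β and their time derivatives α', β'.
  def linear_schedule do
    %{
      alpha: fn t -> 1.0 - t end,
      d_alpha: fn _t -> -1.0 end,
      beta: fn t -> t end,
      d_beta: fn _t -> 1.0 end
    }
  end

  # I(t) = α(t)·x + β(t)·ε, applied feature by feature.
  def interpolant(s, x, eps, t) do
    a = s.alpha.(t)
    b = s.beta.(t)
    Enum.zip_with(x, eps, fn xi, ei -> a * xi + b * ei end)
  end

  # v(t) = dI/dt = α'(t)·x + β'(t)·ε
  def velocity(s, x, eps, t) do
    da = s.d_alpha.(t)
    db = s.d_beta.(t)
    Enum.zip_with(x, eps, fn xi, ei -> da * xi + db * ei end)
  end
end

s = InterpolantSketch.linear_schedule()
x = [1.0, 2.0]
eps = [0.5, -1.0]

InterpolantSketch.interpolant(s, x, eps, 0.25)
# → [0.875, 1.25]
InterpolantSketch.velocity(s, x, eps, 0.25)
# → [-0.5, -3.0]
```

Swapping in another schedule only requires a different map with the same four keys.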

Architecture

Same transformer backbone as DiT:

```
Input [batch, input_dim]
      |
      v
+--------------------------+
| Input Embed + Pos Embed  |
+--------------------------+
      |
      v
+--------------------------+
| SiT Block x depth        |
|  AdaLN-Zero(time_cond)   |
|  Self-Attention          |
|  Residual                |
|  AdaLN-Zero(time_cond)   |
|  MLP                     |
|  Residual                |
+--------------------------+
      |
      v
+--------------------------+
| Final Norm + Linear      |
+--------------------------+
      |
      v
Output [batch, input_dim]  (predicted velocity)
```
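The "AdaLN-Zero(time_cond) / ... / Residual" pattern in each block can be sketched as below. This assumes the standard DiT formulation: the time (and class) embedding is regressed to a shift, scale and gate per sub-block, and that regression is zero-initialized. Whether Edifice follows this exactly is an assumption. `AdaLNZeroSketch` is a hypothetical, list-based illustration, and `f` stands in for self-attention or the MLP. The point it shows is that with a zero gate the block is the identity at initialization.

```elixir
# Hypothetical sketch of a DiT-style AdaLN-Zero residual sub-block; not Edifice code.
defmodule AdaLNZeroSketch do
  # LayerNorm without learned affine parameters (AdaLN supplies shift/scale).
  def layer_norm(h, ln_eps \\ 1.0e-6) do
    n = length(h)
    mean = Enum.sum(h) / n
    var = Enum.sum(Enum.map(h, fn v -> (v - mean) * (v - mean) end)) / n
    Enum.map(h, fn v -> (v - mean) / :math.sqrt(var + ln_eps) end)
  end

  # modulate(h) = h * (1 + scale) + shift
  def modulate(h, shift, scale) do
    Enum.zip_with([h, shift, scale], fn [v, sh, sc] -> v * (1.0 + sc) + sh end)
  end

  # h + gate * f(modulate(LN(h))); shift, scale and gate would come from time_cond.
  def sub_block(h, f, shift, scale, gate) do
    out = h |> layer_norm() |> modulate(shift, scale) |> f.()
    Enum.zip_with([h, gate, out], fn [v, g, o] -> v + g * o end)
  end
end

h = [1.0, -2.0, 0.5]
zeros = [0.0, 0.0, 0.0]
# Stand-in for attention or the MLP.
mlp = fn v -> Enum.map(v, &(&1 * 3.0)) end

# With shift, scale and gate all zero (their value at init), the input passes through unchanged.
AdaLNZeroSketch.sub_block(h, mlp, zeros, zeros, zeros)
# → [1.0, -2.0, 0.5]
```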

Interpolant Schedules

  • Linear (default): α(t) = 1-t, β(t) = t → simple, matches flow matching
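  • Cosine: α(t) = cos(πt/2), β(t) = sin(πt/2) → variance-preserving, since α(t)² + β(t)² = 1 for every t
  • Custom: any other α(t), β(t) pair can be used by computing I(t) and dI/dt yourself; this module provides helpers only for the linear and cosine schedules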
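The difference between the two built-in schedules shows up in the variance of I(t). The sketch below assumes x and ε are independent with unit variance per feature, so Var[I(t)] = α(t)² + β(t)². `ScheduleSketch` is a hypothetical helper, not part of Edifice.

```elixir
# Hypothetical helper comparing the two built-in schedules; not Edifice code.
defmodule ScheduleSketch do
  # Each schedule returns {α(t), β(t)}.
  def linear(t), do: {1.0 - t, t}
  def cosine(t), do: {:math.cos(:math.pi() * t / 2), :math.sin(:math.pi() * t / 2)}

  # Var[I(t)] for independent, unit-variance x and ε.
  def variance({a, b}), do: a * a + b * b
end

for t <- [0.0, 0.5, 1.0] do
  {t, ScheduleSketch.variance(ScheduleSketch.linear(t)),
   ScheduleSketch.variance(ScheduleSketch.cosine(t))}
end
```

The linear schedule's variance dips to 0.5 at t = 0.5. The cosine schedule stays at 1.0 up to floating-point rounding.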

Usage

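```elixir
alias Edifice.Generative.SiT

model = SiT.build(
  input_dim: 64,
  hidden_size: 256,
  num_layers: 6,
  num_heads: 4
)

# Training step; x is a clean batch of shape {batch_size, 64}.
{t, key} = SiT.sample_interpolant_time(batch_size, key: Nx.Random.key(42))
{noise, _key} = Nx.Random.normal(key, shape: Nx.shape(x))
x_t = SiT.linear_interpolant(x, noise, t)
target_velocity = SiT.linear_velocity(x, noise)

# predicted_velocity: output of `model` for x_t and its timestep (see build/1)
loss = SiT.sit_loss(predicted_velocity, target_velocity)
```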

References

Summary

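Types

  • build_opt() - Options for build/1.

Functions

  • build(opts \\ []) - Build a SiT model for interpolant-based generation.
  • cosine_interpolant(x, noise, t) - Compute the cosine interpolant between data and noise.
  • cosine_velocity(x, noise, t) - Compute the target velocity for the cosine interpolant.
  • linear_interpolant(x, noise, t) - Compute the linear interpolant between data and noise.
  • linear_velocity(x, noise) - Compute the target velocity for the linear interpolant.
  • output_size(opts \\ []) - Get the output size of a SiT model.
  • sample_interpolant_time(batch_size, opts \\ []) - Sample interpolant time t ~ Uniform(0, 1).
  • sit_loss(pred_velocity, target_velocity) - Compute the SiT loss (MSE between predicted and target velocity).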

Types

build_opt()

@type build_opt() ::
  {:depth, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:input_dim, pos_integer()}
  | {:mlp_ratio, float()}
  | {:num_classes, pos_integer() | nil}
  | {:num_heads, pos_integer()}
  | {:num_steps, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a SiT model for interpolant-based generation.

Options

  • :input_dim - Input/output feature dimension (required)
  • :hidden_size - Transformer hidden dimension (default: 256)
  • :num_layers - Number of SiT blocks (default: 6)
  • :num_heads - Number of attention heads (default: 4)
  • :mlp_ratio - MLP expansion ratio (default: 4.0)
  • :num_classes - Number of classes for conditioning (optional)
  • :num_steps - Number of timesteps for embedding (default: 1000)

Returns

An Axon model that predicts velocity given (noisy_input, timestep, [class]).

cosine_interpolant(x, noise, t)

@spec cosine_interpolant(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Compute the cosine interpolant between data and noise.

I(t) = cos(πt/2) · x + sin(πt/2) · ε

Parameters

  • x - Clean data: [batch, dim]
  • noise - Random noise: [batch, dim]
  • t - Interpolant time: [batch] (values in [0, 1])

Returns

Interpolated tensor with same shape as x.
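Because t has shape [batch] while x has shape [batch, dim], each sample's t must be broadcast across its features. The list-based sketch below makes that explicit. `CosineInterpolantSketch` is hypothetical. The real function operates on Nx tensors, and how it broadcasts internally is not shown in these docs.

```elixir
# Hypothetical list-based sketch of cosine_interpolant's shape semantics.
defmodule CosineInterpolantSketch do
  # x, noise: lists of rows ([batch, dim]); t: one value per row ([batch]).
  def interpolant(x, noise, t) do
    Enum.zip_with([x, noise, t], fn [row_x, row_e, ti] ->
      a = :math.cos(:math.pi() * ti / 2)
      b = :math.sin(:math.pi() * ti / 2)
      # The same (a, b) pair is applied to every feature in this row.
      Enum.zip_with(row_x, row_e, fn xi, ei -> a * xi + b * ei end)
    end)
  end
end

x = [[1.0, 2.0], [3.0, 4.0]]
noise = [[0.0, 0.0], [1.0, 1.0]]

# Row 0 (t = 0) is pure data; row 1 (t = 1) is pure noise up to rounding.
CosineInterpolantSketch.interpolant(x, noise, [0.0, 1.0])
```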

cosine_velocity(x, noise, t)

@spec cosine_velocity(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Compute the target velocity for the cosine interpolant.

v(t) = -(π/2)·sin(πt/2)·x + (π/2)·cos(πt/2)·ε

Parameters

  • x - Clean data: [batch, dim]
  • noise - Random noise: [batch, dim]
  • t - Interpolant time: [batch] (values in [0, 1])

Returns

Target velocity tensor with same shape as x.
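The velocity formula is the time derivative of the cosine interpolant. A central finite difference of I(t) gives a quick numerical check. `CosineVelocityCheck` is a hypothetical scalar helper written for this check only.

```elixir
# Hypothetical check that v(t) matches dI/dt for the cosine schedule.
defmodule CosineVelocityCheck do
  # I(t) = cos(πt/2)·x + sin(πt/2)·ε for one scalar feature.
  def interp(x, e, t), do: :math.cos(:math.pi() * t / 2) * x + :math.sin(:math.pi() * t / 2) * e

  # v(t) = -(π/2)·sin(πt/2)·x + (π/2)·cos(πt/2)·ε
  def velocity(x, e, t) do
    -(:math.pi() / 2) * :math.sin(:math.pi() * t / 2) * x +
      :math.pi() / 2 * :math.cos(:math.pi() * t / 2) * e
  end

  # Central difference (I(t + h) - I(t - h)) / 2h approximates dI/dt.
  def numeric_velocity(x, e, t, h \\ 1.0e-5) do
    (interp(x, e, t + h) - interp(x, e, t - h)) / (2 * h)
  end
end

CosineVelocityCheck.velocity(1.5, -0.7, 0.3)
CosineVelocityCheck.numeric_velocity(1.5, -0.7, 0.3)
```

The two values agree to well within 1.0e-6.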

linear_interpolant(x, noise, t)

@spec linear_interpolant(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Compute the linear interpolant between data and noise.

I(t) = (1 - t) · x + t · ε (linear schedule)

Parameters

  • x - Clean data: [batch, dim]
  • noise - Random noise: [batch, dim]
  • t - Interpolant time: [batch] (values in [0, 1])

Returns

Interpolated tensor with same shape as x.

linear_velocity(x, noise)

@spec linear_velocity(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Compute the target velocity for the linear interpolant.

v(t) = dI/dt = -x + ε (linear schedule derivative)

Parameters

  • x - Clean data: [batch, dim]
  • noise - Random noise: [batch, dim]

Returns

Target velocity tensor with same shape as x.
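Under the linear schedule the target velocity ε - x does not depend on t. This is why linear_velocity takes no t argument. It also means a velocity at any point I(t) determines both endpoints: x = I(t) - t·v and ε = I(t) + (1 - t)·v. The sketch below verifies this identity on plain lists. `LinearEndpoints` is a hypothetical helper, not part of Edifice.

```elixir
# Hypothetical helper showing the endpoint identities of the linear interpolant.
defmodule LinearEndpoints do
  # I(t) = (1 - t)·x + t·ε
  def interp(x, e, t), do: Enum.zip_with(x, e, fn xi, ei -> (1.0 - t) * xi + t * ei end)

  # v = -x + ε, independent of t
  def velocity(x, e), do: Enum.zip_with(x, e, fn xi, ei -> ei - xi end)

  # Recover {x, ε} from I(t) and v: x = I(t) - t·v, ε = I(t) + (1 - t)·v
  def recover(i_t, v, t) do
    x = Enum.zip_with(i_t, v, fn i, vi -> i - t * vi end)
    e = Enum.zip_with(i_t, v, fn i, vi -> i + (1.0 - t) * vi end)
    {x, e}
  end
end

x = [1.0, -2.0]
e = [0.5, 0.25]
t = 0.75
i_t = LinearEndpoints.interp(x, e, t)
{x_hat, e_hat} = LinearEndpoints.recover(i_t, LinearEndpoints.velocity(x, e), t)
```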

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a SiT model.

sample_interpolant_time(batch_size, opts \\ [])

@spec sample_interpolant_time(
  pos_integer(),
  keyword()
) :: {Nx.Tensor.t(), Nx.Tensor.t()}

Sample interpolant time t ~ Uniform(0, 1).

Parameters

  • batch_size - Number of samples

Options

  • :key - Random key (default: Nx.Random.key(System.system_time()))

Returns

{t, new_key} where t has shape {batch_size}.
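Returning {t, new_key} follows the functional-RNG pattern: thread the new key into the next call, and reuse a key only when you want the same draw. The default key is derived from System.system_time(), so calls without :key are not reproducible. The sketch below shows the same threading with Erlang's `:rand` state in place of an Nx key. `TimeSampling` is hypothetical. Note that `:rand.uniform_s/1` draws from [0.0, 1.0).

```elixir
# Hypothetical stand-in for key threading, using :rand state instead of Nx.Random.
defmodule TimeSampling do
  # Draws batch_size values t ~ Uniform[0, 1) and returns {ts, new_state}.
  def sample(batch_size, state) do
    Enum.map_reduce(1..batch_size, state, fn _i, st -> :rand.uniform_s(st) end)
  end
end

state0 = :rand.seed_s(:exsss, {1, 2, 3})
{t1, state1} = TimeSampling.sample(4, state0)
# Threading the new state gives a fresh draw...
{t2, _state2} = TimeSampling.sample(4, state1)
# ...while reusing state0 reproduces t1 exactly.
{t1_again, _} = TimeSampling.sample(4, state0)
```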

sit_loss(pred_velocity, target_velocity)

@spec sit_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Compute the SiT loss (MSE between predicted and target velocity).

Parameters

  • pred_velocity - Model-predicted velocity: [batch, dim]
  • target_velocity - Target velocity: [batch, dim]

Returns

Scalar MSE loss tensor.
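A minimal sketch of the loss follows. It assumes the mean is taken over every element of the [batch, dim] difference. The docs only state "MSE", so per-sample reduction details are an assumption. `LossSketch` is hypothetical.

```elixir
# Hypothetical list-based MSE, averaging over all batch × dim elements.
defmodule LossSketch do
  def mse(pred, target) do
    sq =
      Enum.zip_with(pred, target, fn p_row, t_row ->
        Enum.zip_with(p_row, t_row, fn p, t -> (p - t) * (p - t) end)
      end)
      |> List.flatten()

    Enum.sum(sq) / length(sq)
  end
end

# Squared errors are 0, 4, 1, 0, so the mean is 5 / 4.
LossSketch.mse([[1.0, 2.0], [0.0, 0.0]], [[1.0, 0.0], [1.0, 0.0]])
# → 1.25
```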