SiT: Scalable Interpolant Transformer.
Implements SiT from "SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers" (Ma et al., 2024). Generalizes DiT by treating the interpolant between noise and data as a configurable design choice, rather than committing to a single fixed schedule with score or velocity prediction.
Key Innovation: Configurable Interpolant
Instead of a fixed diffusion schedule, SiT uses:
I(t) = α(t) · x + β(t) · ε

where α and β are schedules (default linear: α(t) = 1 - t, β(t) = t). The model predicts the interpolant velocity:
v(x, t) = dI/dt = α'(t) · x + β'(t) · ε

This subsumes both DDPM-style score prediction and flow-matching-style velocity prediction as special cases, depending on the choice of α and β.
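As a concrete illustration of these formulas, here is a minimal Nx sketch of the linear-schedule case (the `InterpolantSketch` module name is illustrative, not part of this library; the module's own implementation may differ):

```elixir
defmodule InterpolantSketch do
  import Nx.Defn

  # Linear schedule: alpha(t) = 1 - t, beta(t) = t, so
  #   I(t)  = (1 - t) * x + t * eps
  #   dI/dt = -x + eps
  defn linear_interpolant(x, noise, t) do
    # Broadcast t from [batch] to [batch, 1] so it scales each row.
    t = Nx.new_axis(t, -1)
    (1 - t) * x + t * noise
  end

  defn linear_velocity(x, noise) do
    noise - x
  end
end
```

At t = 0 the interpolant returns the clean data x, and at t = 1 it returns pure noise; the velocity is constant in t for the linear schedule.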
Architecture
Same transformer backbone as DiT:
Input [batch, input_dim]
|
v
+--------------------------+
| Input Embed + Pos Embed |
+--------------------------+
|
v
+--------------------------+
| SiT Block x depth |
| AdaLN-Zero(time_cond) |
| Self-Attention |
| Residual |
| AdaLN-Zero(time_cond) |
| MLP |
| Residual |
+--------------------------+
|
v
| Final Norm + Linear |
|
v
Output [batch, input_dim] (predicted velocity)

Interpolant Schedules
- Linear (default): α(t) = 1-t, β(t) = t → simple, matches flow matching
- Cosine: α(t) = cos(πt/2), β(t) = sin(πt/2) → smoother transitions
- Custom: user-provided α(t), β(t) functions
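The schedules above can be sketched as plain (α, β) function pairs together with the derivatives that feed the velocity formula (a sketch for illustration; these maps are not part of the module's API):

```elixir
# Illustrative (alpha, beta) schedule pairs and their derivatives.
schedules = %{
  linear: %{
    alpha: fn t -> 1 - t end,                                   # alpha(t) = 1 - t
    beta: fn t -> t end,                                        # beta(t)  = t
    alpha_prime: fn _t -> -1.0 end,                             # d alpha / dt
    beta_prime: fn _t -> 1.0 end                                # d beta / dt
  },
  cosine: %{
    alpha: fn t -> :math.cos(:math.pi() * t / 2) end,
    beta: fn t -> :math.sin(:math.pi() * t / 2) end,
    alpha_prime: fn t -> -:math.pi() / 2 * :math.sin(:math.pi() * t / 2) end,
    beta_prime: fn t -> :math.pi() / 2 * :math.cos(:math.pi() * t / 2) end
  }
}

# Both schedules satisfy the boundary conditions alpha(0) = 1, beta(0) = 0
# (pure data at t = 0) and alpha(1) = 0, beta(1) = 1 (pure noise at t = 1).
schedules.cosine.alpha.(0.0)
```

Any custom pair must satisfy the same boundary conditions for the interpolant to connect data to noise.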
Usage
model = SiT.build(
input_dim: 64,
hidden_size: 256,
num_layers: 6,
num_heads: 4
)
# Training: sample time, compute interpolant and target velocity
t = SiT.sample_interpolant_time(batch_size)
loss = SiT.sit_loss(predicted_velocity, target_velocity)

References
- Paper: "SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers" (Ma et al., 2024)
- arXiv: https://arxiv.org/abs/2401.08740
Summary
Functions
Build a SiT model for interpolant-based generation.
Compute the cosine interpolant between data and noise.
Compute the target velocity for the cosine interpolant.
Compute the linear interpolant between data and noise.
Compute the target velocity for the linear interpolant.
Get the output size of a SiT model.
Sample interpolant time t ~ Uniform(0, 1).
Compute the SiT loss (MSE between predicted and target velocity).
Types
@type build_opt() :: {:depth, pos_integer()} | {:hidden_size, pos_integer()} | {:input_dim, pos_integer()} | {:mlp_ratio, float()} | {:num_classes, pos_integer() | nil} | {:num_heads, pos_integer()} | {:num_steps, pos_integer()}
Options for build/1.
Functions
Build a SiT model for interpolant-based generation.
Options
- :input_dim - Input/output feature dimension (required)
- :hidden_size - Transformer hidden dimension (default: 256)
- :num_layers - Number of SiT blocks (default: 6)
- :num_heads - Number of attention heads (default: 4)
- :mlp_ratio - MLP expansion ratio (default: 4.0)
- :num_classes - Number of classes for conditioning (optional)
- :num_steps - Number of timesteps for embedding (default: 1000)
Returns
An Axon model that predicts velocity given (noisy_input, timestep, [class]).
@spec cosine_interpolant(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Compute the cosine interpolant between data and noise.
I(t) = cos(πt/2) · x + sin(πt/2) · ε
Parameters
- x - Clean data: [batch, dim]
- noise - Random noise: [batch, dim]
- t - Interpolant time: [batch] (values in [0, 1])
Returns
Interpolated tensor with same shape as x.
@spec cosine_velocity(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Compute the target velocity for the cosine interpolant.
v(t) = -(π/2)·sin(πt/2)·x + (π/2)·cos(πt/2)·ε
Parameters
- x - Clean data: [batch, dim]
- noise - Random noise: [batch, dim]
- t - Interpolant time: [batch] (values in [0, 1])
Returns
Target velocity tensor with same shape as x.
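As a sanity check on the formula above (a sketch assuming the `cosine_interpolant/3` and `cosine_velocity/3` signatures documented here), the analytic velocity should agree with a central finite difference of the interpolant in t:

```elixir
# Finite-difference check: cosine_velocity(x, noise, t) should approximate
# (I(t + h) - I(t - h)) / (2h) for small h, where I is cosine_interpolant.
key = Nx.Random.key(0)
{x, key} = Nx.Random.normal(key, shape: {4, 8})
{noise, _key} = Nx.Random.normal(key, shape: {4, 8})
t = Nx.tensor([0.1, 0.3, 0.5, 0.9])
h = 1.0e-3

v = SiT.cosine_velocity(x, noise, t)

fd =
  Nx.divide(
    Nx.subtract(
      SiT.cosine_interpolant(x, noise, Nx.add(t, h)),
      SiT.cosine_interpolant(x, noise, Nx.subtract(t, h))
    ),
    2 * h
  )

# The two should agree to roughly O(h^2).
max_err = Nx.reduce_max(Nx.abs(Nx.subtract(v, fd)))
```

The same check applies to the linear pair, where the velocity is constant in t.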
@spec linear_interpolant(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Compute the linear interpolant between data and noise.
I(t) = (1 - t) · x + t · ε (linear schedule)
Parameters
- x - Clean data: [batch, dim]
- noise - Random noise: [batch, dim]
- t - Interpolant time: [batch] (values in [0, 1])
Returns
Interpolated tensor with same shape as x.
@spec linear_velocity(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Compute the target velocity for the linear interpolant.
v(t) = dI/dt = -x + ε (linear schedule derivative)
Parameters
- x - Clean data: [batch, dim]
- noise - Random noise: [batch, dim]
Returns
Target velocity tensor with same shape as x.
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a SiT model.
@spec sample_interpolant_time(pos_integer(), keyword()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Sample interpolant time t ~ Uniform(0, 1).
Parameters
- batch_size - Number of samples
Options
- :key - Random key (default: Nx.Random.key(System.system_time()))
Returns
{t, new_key} where t has shape {batch_size}.
@spec sit_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Compute the SiT loss (MSE between predicted and target velocity).
Parameters
- pred_velocity - Model-predicted velocity: [batch, dim]
- target_velocity - Target velocity: [batch, dim]
Returns
Scalar MSE loss tensor.
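Putting the documented functions together, a single training step can be sketched as follows (the `TrainStep` module and `model_fn` callback are illustrative stand-ins for your Axon forward pass and optimizer wiring, not part of this module):

```elixir
defmodule TrainStep do
  # One SiT training step with the linear schedule.
  # `model_fn` is a hypothetical forward pass: (noisy_input, t) -> velocity.
  def run(model_fn, x, key) do
    batch_size = Nx.axis_size(x, 0)

    # 1. Sample interpolant times t ~ Uniform(0, 1) and Gaussian noise.
    {t, key} = SiT.sample_interpolant_time(batch_size, key: key)
    {noise, key} = Nx.Random.normal(key, shape: Nx.shape(x))

    # 2. Build the noisy input and its target velocity.
    x_t = SiT.linear_interpolant(x, noise, t)
    target = SiT.linear_velocity(x, noise)

    # 3. Predict velocity and compute the MSE loss.
    pred = model_fn.(x_t, t)
    loss = SiT.sit_loss(pred, target)

    {loss, key}
  end
end
```

Swapping `linear_interpolant`/`linear_velocity` for `cosine_interpolant`/`cosine_velocity` trains against the cosine schedule instead; the rest of the step is unchanged.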