# `Edifice.Generative.SiT`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/generative/sit.ex#L1)

SiT: Scalable Interpolant Transformer.

Implements SiT from "SiT: Exploring Flow and Diffusion-based Generative Models
with Scalable Interpolant Transformers" (Ma et al., 2024). Generalizes DiT by
making the interpolant between noise and data configurable, rather than
committing to a single fixed schedule for score or velocity prediction.

## Key Innovation: Configurable Interpolant

Instead of a fixed diffusion schedule, SiT uses:

    I(t) = α(t) · x + β(t) · ε

where α and β are schedules (default linear: α(t)=1-t, β(t)=t). The model
predicts the interpolant velocity:

    v(x, t) = dI/dt = α'(t) · x + β'(t) · ε

This subsumes both DDPM (score prediction) and flow matching (velocity
prediction) as special cases depending on the choice of α, β.
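
For the default linear schedule this is simple arithmetic. A scalar sketch in plain Elixir (the module's own `linear_interpolant/3` and `linear_velocity/2` operate on Nx tensors; these anonymous functions are illustrative only):

```elixir
# Linear schedule: alpha(t) = 1 - t, beta(t) = t.
interpolant = fn x, eps, t -> (1.0 - t) * x + t * eps end
velocity = fn x, eps -> -x + eps end

x = 2.0    # clean data point
eps = 0.5  # sampled noise
t = 0.25   # interpolant time

interpolant.(x, eps, t)  # 0.75 * 2.0 + 0.25 * 0.5 = 1.625
velocity.(x, eps)        # -2.0 + 0.5 = -1.5
```

Note the velocity is independent of `t` for the linear schedule, since both α'(t) = -1 and β'(t) = 1 are constant.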

## Architecture

Same transformer backbone as DiT:

```
Input [batch, input_dim]
      |
      v
+--------------------------+
| Input Embed + Pos Embed  |
+--------------------------+
      |
      v
+--------------------------+
| SiT Block x depth        |
|  AdaLN-Zero(time_cond)   |
|  Self-Attention          |
|  Residual                |
|  AdaLN-Zero(time_cond)   |
|  MLP                     |
|  Residual                |
+--------------------------+
      |
      v
+--------------------------+
| Final Norm + Linear      |
+--------------------------+
      |
      v
Output [batch, input_dim]  (predicted velocity)
```

## Interpolant Schedules

- Linear (default): α(t) = 1-t, β(t) = t  →  simple, matches flow matching
- Cosine: α(t) = cos(πt/2), β(t) = sin(πt/2)  →  smoother transitions
- Custom: user-provided α(t), β(t) functions
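
Both built-in schedules satisfy the same boundary conditions: pure data at t = 0 and pure noise at t = 1. A sketch in plain Elixir, representing each schedule as a hypothetical `{alpha, beta}` pair of functions (the module's actual configuration mechanism may differ):

```elixir
# Hypothetical {alpha, beta} function pairs for the two built-in schedules.
linear = {fn t -> 1.0 - t end, fn t -> t end}

cosine =
  {fn t -> :math.cos(:math.pi() * t / 2) end,
   fn t -> :math.sin(:math.pi() * t / 2) end}

# Boundary conditions: alpha(0) = 1, beta(0) = 0 (pure data at t = 0);
# alpha(1) = 0, beta(1) = 1 (pure noise at t = 1).
{alpha, beta} = cosine
alpha.(0.0)  # 1.0
beta.(1.0)   # 1.0
```

A custom schedule only needs to respect these boundary conditions; the cosine pair trades the linear schedule's constant velocity for slower change near the endpoints.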

## Usage

    model = SiT.build(
      input_dim: 64,
      hidden_size: 256,
      num_layers: 6,
      num_heads: 4
    )

    # Training: sample time, compute interpolant and target velocity
    {t, _key} = SiT.sample_interpolant_time(batch_size)
    loss = SiT.sit_loss(predicted_velocity, target_velocity)
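
End to end, one training example under the linear schedule reduces to the following plain-Elixir scalar sketch (the real pipeline uses Nx tensors and the model's actual prediction; `pred_v` here is a stand-in):

```elixir
x = 1.5     # clean data point
eps = -0.5  # sampled noise
t = 0.4     # t ~ Uniform(0, 1)

x_t = (1.0 - t) * x + t * eps  # linear interpolant I(t), ~= 0.7
target_v = -x + eps            # target velocity dI/dt = -2.0
pred_v = -1.9                  # stand-in for the model's output on x_t

loss = (pred_v - target_v) * (pred_v - target_v)  # squared error, ~= 0.01
```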

## References

- Paper: "SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers" (Ma et al., 2024)
- arXiv: https://arxiv.org/abs/2401.08740

# `build_opt`

```elixir
@type build_opt() ::
  {:depth, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:input_dim, pos_integer()}
  | {:mlp_ratio, float()}
  | {:num_classes, pos_integer() | nil}
  | {:num_heads, pos_integer()}
  | {:num_steps, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a SiT model for interpolant-based generation.

## Options

  - `:input_dim` - Input/output feature dimension (required)
  - `:hidden_size` - Transformer hidden dimension (default: 256)
  - `:num_layers` - Number of SiT blocks (default: 6)
  - `:num_heads` - Number of attention heads (default: 4)
  - `:mlp_ratio` - MLP expansion ratio (default: 4.0)
  - `:num_classes` - Number of classes for conditioning (optional)
  - `:num_steps` - Number of timesteps for embedding (default: 1000)

## Returns

  An Axon model that predicts velocity given (noisy_input, timestep, [class]).

# `cosine_interpolant`

```elixir
@spec cosine_interpolant(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Compute the cosine interpolant between data and noise.

I(t) = cos(πt/2) · x + sin(πt/2) · ε

## Parameters

  - `x` - Clean data: [batch, dim]
  - `noise` - Random noise: [batch, dim]
  - `t` - Interpolant time: [batch] (values in [0, 1])

## Returns

  Interpolated tensor with same shape as x.
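
A quick scalar sanity check of the formula (plain Elixir, not the Nx-based function): at t = 0 the interpolant returns the data, at t = 1 the noise.

```elixir
cos_interp = fn x, eps, t ->
  :math.cos(:math.pi() * t / 2) * x + :math.sin(:math.pi() * t / 2) * eps
end

cos_interp.(3.0, -1.0, 0.0)  # 3.0 (pure data)
cos_interp.(3.0, -1.0, 1.0)  # -1.0 up to floating-point rounding (pure noise)
```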

# `cosine_velocity`

```elixir
@spec cosine_velocity(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Compute the target velocity for the cosine interpolant.

v(t) = -(π/2)·sin(πt/2)·x + (π/2)·cos(πt/2)·ε

## Parameters

  - `x` - Clean data: [batch, dim]
  - `noise` - Random noise: [batch, dim]
  - `t` - Interpolant time: [batch] (values in [0, 1])

## Returns

  Target velocity tensor with same shape as x.
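
The formula can be checked against a central finite difference of the interpolant, since the velocity is defined as dI/dt. A scalar sketch (illustrative helpers, not the module's Nx-based functions):

```elixir
interp = fn x, eps, t ->
  :math.cos(:math.pi() * t / 2) * x + :math.sin(:math.pi() * t / 2) * eps
end

vel = fn x, eps, t ->
  -(:math.pi() / 2) * :math.sin(:math.pi() * t / 2) * x +
    (:math.pi() / 2) * :math.cos(:math.pi() * t / 2) * eps
end

{x, eps, t, h} = {2.0, -1.0, 0.3, 1.0e-6}

# Central difference approximation of dI/dt at t.
fd = (interp.(x, eps, t + h) - interp.(x, eps, t - h)) / (2 * h)
# fd agrees with vel.(x, eps, t) to within O(h^2)
```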

# `linear_interpolant`

```elixir
@spec linear_interpolant(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Compute the linear interpolant between data and noise.

I(t) = (1 - t) · x + t · ε  (linear schedule)

## Parameters

  - `x` - Clean data: [batch, dim]
  - `noise` - Random noise: [batch, dim]
  - `t` - Interpolant time: [batch] (values in [0, 1])

## Returns

  Interpolated tensor with same shape as x.

# `linear_velocity`

```elixir
@spec linear_velocity(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Compute the target velocity for the linear interpolant.

v(t) = dI/dt = -x + ε  (linear schedule derivative)

## Parameters

  - `x` - Clean data: [batch, dim]
  - `noise` - Random noise: [batch, dim]

## Returns

  Target velocity tensor with same shape as x.

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of a SiT model.

# `sample_interpolant_time`

```elixir
@spec sample_interpolant_time(
  pos_integer(),
  keyword()
) :: {Nx.Tensor.t(), Nx.Tensor.t()}
```

Sample interpolant time t ~ Uniform(0, 1).

## Parameters

  - `batch_size` - Number of samples

## Options

  - `:key` - Random key (default: `Nx.Random.key(System.system_time())`)

## Returns

  `{t, new_key}` where t has shape `{batch_size}`.

# `sit_loss`

```elixir
@spec sit_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Compute the SiT loss (MSE between predicted and target velocity).

## Parameters

  - `pred_velocity` - Model-predicted velocity: [batch, dim]
  - `target_velocity` - Target velocity: [batch, dim]

## Returns

  Scalar MSE loss tensor.
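
The computation is a plain mean of elementwise squared differences. A list-based sketch in plain Elixir (the real `sit_loss/2` takes Nx tensors and returns a scalar tensor):

```elixir
# Mean squared error over paired lists of floats.
mse = fn pred, target ->
  pred
  |> Enum.zip(target)
  |> Enum.map(fn {p, t} -> (p - t) * (p - t) end)
  |> Enum.sum()
  |> Kernel./(length(pred))
end

mse.([1.0, 2.0, 3.0], [1.0, 0.0, 3.0])  # (0 + 4 + 0) / 3, ~= 1.333
```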

---

*Consult [api-reference.md](api-reference.md) for complete listing*
