Bumblebee.Scheduler behaviour (Bumblebee v0.6.0)
An interface for configuring and using schedulers.
A scheduler defines a sampling method, usually used in a multi-step denoising process, as in Stable Diffusion.
Every module implementing this behaviour is expected to also define a configuration struct.
Context
Imagine a denoising model trained with 1000 steps. During training, we take some original data and add random noise 1000 times, obtaining 1000 steps with an increasing level of noise. The model then learns to predict the noise at each timestep, given the data at that step (the sample) and the timestep.
Once such a model is trained, we can obtain brand new data (such as an image) by generating random data and denoising it with our model in 1000 steps.
Doing 1000 forward passes of the model for a single generation can be expensive, hence multiple methods have been developed to reduce the number of steps during denoising, with no changes to the model.
Each method specifies a subset of the original timesteps. At each selected timestep we do a forward pass of the model (or possibly a few), then the method extrapolates the sample to the next selected timestep, possibly skipping many timesteps in between.
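To make the idea of a timestep subset concrete, here is a minimal sketch in plain Elixir. It assumes evenly spaced timesteps, which is only one possible spacing; actual schedulers may choose timesteps differently. The module TimestepSketch and its select/2 function are hypothetical names used for illustration and are not part of Bumblebee.

```elixir
defmodule TimestepSketch do
  # Hypothetical helper, not part of Bumblebee. Picks `num_steps` evenly
  # spaced timesteps out of `num_train_steps`, ordered from the noisiest
  # timestep down to 0 (the order used when denoising).
  def select(num_train_steps, num_steps) do
    stride = div(num_train_steps, num_steps)
    for i <- (num_steps - 1)..0//-1, do: i * stride
  end
end

# A model trained with 1000 steps, denoised in only 20 forward passes.
timesteps = TimestepSketch.select(1000, 20)
IO.inspect(timesteps, limit: :infinity)
```

With these numbers the model is evaluated at 20 timesteps instead of 1000, and each scheduler step jumps over 49 intermediate timesteps.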
Note on wording
Throughout the docs and APIs the word "steps" refers to diffusion steps, whereas "timesteps" is more specific and refers to the exact values $t$ (points in time).
Summary
Types
@type state() :: Nx.Container.t()
@type t() :: Bumblebee.Configurable.t()
Callbacks
@callback init( t(), num_steps :: pos_integer(), sample_template :: Nx.Tensor.t(), prng_key :: Nx.Tensor.t() ) :: {state :: map(), timesteps :: Nx.Tensor.t()}
Initializes state for a new scheduler loop.
Returns a pair of {state, timesteps}, where state is an opaque Nx.Container and timesteps is a tensor with the subsequent timesteps for model forward pass.
@callback step( t(), state(), sample :: Nx.Tensor.t(), prediction :: Nx.Tensor.t() ) :: {state :: map(), prev_sample :: Nx.Tensor.t()}
Predicts sample at the previous timestep.
Takes the current sample and prediction (usually noise) returned by the model at the current timestep. Returns {state, prev_sample}, where state is the updated state and prev_sample is the predicted sample at the previous timestep.
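The way init and step fit together in a denoising loop can be sketched as follows. This is a toy illustration in plain Elixir, not a real scheduler: ToyScheduler is a hypothetical module that only mirrors the shape of the two callbacks, it uses a plain map for configuration and plain floats in place of Nx tensors, and its update rule (subtracting the prediction) is an assumption made up for the example. The model function is likewise a hypothetical stand-in.

```elixir
defmodule ToyScheduler do
  # Hypothetical stand-in that mirrors the shape of the init/step callbacks.
  # It does not implement the Bumblebee.Scheduler behaviour.

  # Mirrors init: returns {state, timesteps}.
  def init(config, num_steps, _sample_template, _prng_key) do
    stride = div(config.num_train_steps, num_steps)
    timesteps = for i <- (num_steps - 1)..0//-1, do: i * stride
    {%{iteration: 0}, timesteps}
  end

  # Mirrors step: returns {state, prev_sample}.
  # The update rule is a made-up toy: remove the predicted "noise".
  def step(_config, state, sample, prediction) do
    prev_sample = sample - prediction
    {%{state | iteration: state.iteration + 1}, prev_sample}
  end
end

# Plain map used as configuration for this sketch only.
config = %{num_train_steps: 1000}

# Hypothetical model: "predicts" that half of the sample is noise.
model = fn sample, _timestep -> sample * 0.5 end

{state, timesteps} = ToyScheduler.init(config, 4, 0.0, nil)

# One model forward pass and one scheduler step per selected timestep.
{final_state, final_sample} =
  Enum.reduce(timesteps, {state, 8.0}, fn timestep, {state, sample} ->
    prediction = model.(sample, timestep)
    ToyScheduler.step(config, state, sample, prediction)
  end)

IO.inspect({timesteps, final_state, final_sample})
```

The point of the sketch is the data flow: the state returned by init is threaded through every step call, and each step consumes the current sample together with the model prediction for the current timestep.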