Bumblebee.Diffusion.StableDiffusion (Bumblebee v0.1.2)
High-level tasks based on Stable Diffusion.
Summary
Functions
text_to_image(encoder, unet, vae, tokenizer, scheduler, opts \\ [])
Build serving for prompt-driven image generation.
Types
@type text_to_image_output() :: %{results: [text_to_image_result()]}
@type text_to_image_result() :: %{ :image => Nx.Tensor.t(), optional(:is_safe) => boolean() }
Functions
text_to_image(encoder, unet, vae, tokenizer, scheduler, opts \\ [])
@spec text_to_image( Bumblebee.model_info(), Bumblebee.model_info(), Bumblebee.model_info(), Bumblebee.Tokenizer.t(), Bumblebee.Scheduler.t(), keyword() ) :: Nx.Serving.t()
Build serving for prompt-driven image generation.
The serving accepts text_to_image_input/0 and returns text_to_image_output/0. A list of inputs is also supported.
You can specify the :safety_checker option to automatically detect when a generated image is offensive or harmful and filter it out.
Options
:safety_checker
- the safety checker model info map. When a safety checker is used, each output entry has an additional :is_safe property and unsafe images are automatically zeroed. Make sure to also set :safety_checker_featurizer.
:safety_checker_featurizer
- the featurizer to use to preprocess the safety checker input images.
:num_steps
- the number of denoising steps. More denoising steps usually lead to higher image quality at the expense of slower inference. Defaults to 50.
:num_images_per_prompt
- the number of images to generate for each prompt. Defaults to 1.
:guidance_scale
- the scale used for classifier-free diffusion guidance. A higher guidance scale makes the generated images reflect the text prompt more closely. This parameter corresponds to $\omega$ in Equation (2) of the Imagen paper. Defaults to 7.5.
:seed
- a seed for the random number generator. Defaults to 0.
:compile
- compiles all computations for predefined input shapes during serving initialization. Should be a keyword list with the following keys:
  :batch_size - the maximum batch size of the input. Inputs are optionally padded to always match this batch size.
  :sequence_length - the maximum input sequence length. Input sequences are always padded/truncated to match that length.
  It is advised to set this option in production and also to configure a defn compiler using :defn_options to maximally reduce inference time.
:defn_options
- the options for JIT compilation. Defaults to [].
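For reference, Equation (2) of the Imagen paper, which :guidance_scale is described as corresponding to, combines the text-conditioned and unconditional noise predictions as shown below. It is restated here in this page's notation ($\omega$ for the guidance weight); this page does not say whether Bumblebee evaluates it in exactly this algebraic form, so treat it as a description of the technique rather than of the implementation.

```latex
\begin{aligned}
\tilde{\epsilon}_\theta(\mathbf{z}_t, \mathbf{c})
  &= \omega\,\epsilon_\theta(\mathbf{z}_t, \mathbf{c}) + (1 - \omega)\,\epsilon_\theta(\mathbf{z}_t) \\
  &= \epsilon_\theta(\mathbf{z}_t) + \omega\bigl(\epsilon_\theta(\mathbf{z}_t, \mathbf{c}) - \epsilon_\theta(\mathbf{z}_t)\bigr)
\end{aligned}
```

Here $\epsilon_\theta(\mathbf{z}_t, \mathbf{c})$ is the noise predicted with the prompt conditioning $\mathbf{c}$ and $\epsilon_\theta(\mathbf{z}_t)$ is the unconditional prediction. With $\omega = 1$ the result is just the conditional prediction (no extra guidance); larger values such as the default 7.5 extrapolate further along the conditional-minus-unconditional direction.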
Examples
repository_id = "CompVis/stable-diffusion-v1-4"
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})
{:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"})
{:ok, unet} =
  Bumblebee.load_model({:hf, repository_id, subdir: "unet"},
    params_filename: "diffusion_pytorch_model.bin"
  )
{:ok, vae} =
  Bumblebee.load_model({:hf, repository_id, subdir: "vae"},
    architecture: :decoder,
    params_filename: "diffusion_pytorch_model.bin"
  )
{:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, repository_id, subdir: "feature_extractor"})
{:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"})
serving =
  Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler,
    num_steps: 20,
    num_images_per_prompt: 2,
    safety_checker: safety_checker,
    safety_checker_featurizer: featurizer,
    compile: [batch_size: 1, sequence_length: 60],
    defn_options: [compiler: EXLA]
  )
prompt = "numbat in forest, detailed, digital art"
Nx.Serving.run(serving, prompt)
#=> %{
#=>   results: [
#=>     %{
#=>       image: #Nx.Tensor<
#=>         u8[512][512][3]
#=>         ...
#=>       >,
#=>       is_safe: true
#=>     },
#=>     %{
#=>       image: #Nx.Tensor<
#=>         u8[512][512][3]
#=>         ...
#=>       >,
#=>       is_safe: true
#=>     }
#=>   ]
#=> }
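When a safety checker is configured, unsafe entries are kept in the results with is_safe: false and a zeroed image, so callers may want to separate them before displaying or saving anything. The sketch below relies only on the output shape described by text_to_image_output/0. The SafeResults module and its functions are hypothetical helpers, not part of Bumblebee, and entries without an :is_safe key (no safety checker configured) are assumed to be safe.

```elixir
defmodule SafeResults do
  @moduledoc """
  Hypothetical helpers (not part of Bumblebee) for post-processing the
  output map returned by the text-to-image serving.
  """

  @doc """
  Splits `%{results: results}` into `{safe, unsafe}` lists, preserving order.
  Entries without an `:is_safe` key are treated as safe (assumption: no
  safety checker was configured, so nothing was flagged).
  """
  def split(%{results: results}) do
    Enum.split_with(results, fn result -> Map.get(result, :is_safe, true) end)
  end

  @doc """
  Returns only the images of entries considered safe.
  """
  def safe_images(output) do
    {safe, _unsafe} = split(output)
    Enum.map(safe, fn %{image: image} -> image end)
  end
end
```

With the serving from the example above, this could be used as {safe, _unsafe} = SafeResults.split(Nx.Serving.run(serving, prompt)). Splitting rather than filtering keeps the unsafe entries available, for example to count or log how many images were rejected.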