Bumblebee.Diffusion.StableDiffusion (Bumblebee v0.4.2)

High-level tasks based on Stable Diffusion.


Types

text_to_image_input()

@type text_to_image_input() ::
  String.t()
  | %{:prompt => String.t(), optional(:negative_prompt) => String.t()}

text_to_image_output()

@type text_to_image_output() :: %{results: [text_to_image_result()]}

text_to_image_result()

@type text_to_image_result() :: %{
  :image => Nx.Tensor.t(),
  optional(:is_safe) => boolean()
}
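To illustrate the two shapes accepted by text_to_image_input/0, the sketch below normalizes either form into a map that always carries both keys. The module StableDiffusionInputSketch and the empty-string default for :negative_prompt are assumptions made for illustration only; they are not part of the Bumblebee API.

```elixir
defmodule StableDiffusionInputSketch do
  # Hypothetical helper, not part of Bumblebee: turns either accepted
  # input shape into a map that always has :prompt and :negative_prompt.
  def normalize(prompt) when is_binary(prompt) do
    normalize(%{prompt: prompt})
  end

  def normalize(%{prompt: prompt} = input) when is_binary(prompt) do
    # :negative_prompt is optional in the type; assume "" when it is absent
    %{prompt: prompt, negative_prompt: Map.get(input, :negative_prompt, "")}
  end
end

# A plain string and a map are both valid inputs; a list of them is also supported
inputs = [
  "numbat in forest, detailed, digital art",
  %{prompt: "numbat in forest", negative_prompt: "blurry, low quality"}
]

normalized = Enum.map(inputs, &StableDiffusionInputSketch.normalize/1)
IO.inspect(normalized)
```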

Functions


text_to_image(encoder, unet, vae, tokenizer, scheduler, opts \\ [])


Builds a serving for prompt-driven image generation.

The serving accepts text_to_image_input/0 and returns text_to_image_output/0. A list of inputs is also supported.

You can specify a :safety_checker model to automatically detect when a generated image is offensive or harmful and filter it out.

Options

  • :safety_checker - the safety checker model info map. When a safety checker is used, each output entry has an additional :is_safe key and unsafe images are automatically zeroed. Make sure to also set :safety_checker_featurizer.

  • :safety_checker_featurizer - the featurizer to use to preprocess the safety checker input images

  • :num_steps - the number of denoising steps. More denoising steps usually lead to higher image quality at the expense of slower inference. Defaults to 50

  • :num_images_per_prompt - the number of images to generate for each prompt. Defaults to 1

  • :guidance_scale - the scale used for classifier-free diffusion guidance. A higher guidance scale makes the generated images reflect the text prompt more closely. This parameter corresponds to $\omega$ in Equation (2) of the Imagen paper. Defaults to 7.5

  • :seed - a seed for the random number generator. Defaults to 0

  • :compile - compiles all computations for predefined input shapes during serving initialization. Should be a keyword list with the following keys:

    • :batch_size - the maximum batch size of the input. Inputs are optionally padded to always match this batch size

    • :sequence_length - the maximum input sequence length. Input sequences are always padded/truncated to match that length

    Setting this option is recommended in production, together with a defn compiler configured via :defn_options, to minimize inference time.

  • :defn_options - the options for JIT compilation. Defaults to []

  • :preallocate_params - when true, explicitly allocates params on the device configured by :defn_options. You may want to set this option when using partitioned serving, to allocate params on each of the devices. Defaults to false
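For reference, the classifier-free guidance combination controlled by :guidance_scale can be written as follows, following Equation (2) of the Imagen paper. Here $\epsilon_\theta(\mathbf{z}_t, \mathbf{c})$ is the noise predicted for the latent $\mathbf{z}_t$ with the text conditioning $\mathbf{c}$, and $\epsilon_\theta(\mathbf{z}_t)$ is the unconditional prediction:

$$
\tilde{\epsilon}_\theta(\mathbf{z}_t, \mathbf{c})
= \omega\,\epsilon_\theta(\mathbf{z}_t, \mathbf{c}) + (1 - \omega)\,\epsilon_\theta(\mathbf{z}_t)
= \epsilon_\theta(\mathbf{z}_t) + \omega\,\bigl(\epsilon_\theta(\mathbf{z}_t, \mathbf{c}) - \epsilon_\theta(\mathbf{z}_t)\bigr)
$$

With $\omega = 1$ this reduces to the plain conditional prediction. Values above 1, such as the default 7.5, extrapolate further away from the unconditional prediction and toward the prompt.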

Examples

repository_id = "CompVis/stable-diffusion-v1-4"

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})

{:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"})

{:ok, unet} =
  Bumblebee.load_model({:hf, repository_id, subdir: "unet"},
    params_filename: "diffusion_pytorch_model.bin"
  )

{:ok, vae} =
  Bumblebee.load_model({:hf, repository_id, subdir: "vae"},
    architecture: :decoder,
    params_filename: "diffusion_pytorch_model.bin"
  )

{:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, repository_id, subdir: "feature_extractor"})
{:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"})

serving =
  Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler,
    num_steps: 20,
    num_images_per_prompt: 2,
    safety_checker: safety_checker,
    safety_checker_featurizer: featurizer,
    compile: [batch_size: 1, sequence_length: 60],
    defn_options: [compiler: EXLA]
  )

prompt = "numbat in forest, detailed, digital art"
Nx.Serving.run(serving, prompt)
#=> %{
#=>   results: [
#=>     %{
#=>       image: #Nx.Tensor<
#=>         u8[512][512][3]
#=>         ...
#=>       >,
#=>       is_safe: true
#=>     },
#=>     %{
#=>       image: #Nx.Tensor<
#=>         u8[512][512][3]
#=>         ...
#=>       >,
#=>       is_safe: true
#=>     }
#=>   ]
#=> }
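Because :is_safe is only present when a safety checker is configured, code that consumes text_to_image_output/0 may want to handle a missing key explicitly. The sketch below assumes a hypothetical SafeResultsSketch.safe_images/1 helper (not part of Bumblebee) that keeps entries without :is_safe and drops entries flagged as unsafe. Placeholder atoms stand in for the image tensors so that it runs without Nx.

```elixir
defmodule SafeResultsSketch do
  # Hypothetical helper, not part of Bumblebee. Returns the images from a
  # text_to_image_output/0 map, dropping entries flagged as unsafe.
  # Entries without :is_safe (no safety checker configured) are kept.
  def safe_images(%{results: results}) do
    for result <- results, Map.get(result, :is_safe, true), do: result.image
  end
end

# Placeholder atoms stand in for the u8[512][512][3] image tensors
output = %{
  results: [
    %{image: :first_image, is_safe: true},
    %{image: :second_image, is_safe: false}
  ]
}

IO.inspect(SafeResultsSketch.safe_images(output))
```

With a real serving, the same helper could be applied to the map returned by Nx.Serving.run(serving, prompt). Since unsafe images are already zeroed by the serving, dropping them mainly avoids passing blank images downstream.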