Bumblebee.Diffusion.StableDiffusion (Bumblebee v0.6.0)

High-level tasks based on Stable Diffusion.

Types

text_to_image_input()

@type text_to_image_input() ::
  String.t()
  | %{
      :prompt => String.t(),
      optional(:negative_prompt) => String.t() | nil,
      optional(:seed) => integer() | nil
    }
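
As a minimal sketch of the input shapes this type allows, the values below are illustrative only: the negative prompt text and the seed numbers are assumptions chosen for the example, not recommended settings.

```elixir
# The simplest input is a plain prompt string.
plain = "numbat in forest, detailed, digital art"

# The map form adds the optional keys from text_to_image_input/0.
# The negative prompt and seed here are made-up example values.
detailed = %{
  prompt: "numbat in forest, detailed, digital art",
  negative_prompt: "blurry, low quality",
  seed: 42
}

# A list of inputs is also accepted; here one prompt is paired with several seeds.
batch = Enum.map([1, 2, 3], fn seed -> %{prompt: plain, seed: seed} end)
```

Any of these values would be passed as the second argument to Nx.Serving.run/2, given a serving built with text_to_image/6.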

text_to_image_output()

@type text_to_image_output() :: %{results: [text_to_image_result()]}

text_to_image_result()

@type text_to_image_result() :: %{
  :image => Nx.Tensor.t(),
  optional(:is_safe) => boolean()
}
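
One way to consume this result shape is sketched below. SafeResults is a hypothetical helper, not part of Bumblebee, and plain atoms stand in for the Nx tensors so the snippet has no dependencies. It assumes that entries without :is_safe (no safety checker configured) should be kept.

```elixir
defmodule SafeResults do
  # Hypothetical helper, not part of Bumblebee.
  # Returns the images from a text_to_image_output/0 map, skipping entries
  # flagged as unsafe. Entries without :is_safe are kept.
  def images(%{results: results}) do
    for result <- results, Map.get(result, :is_safe, true), do: result.image
  end
end

# Stand-in atoms replace real image tensors in this sketch.
output = %{
  results: [
    %{image: :image_a, is_safe: true},
    %{image: :image_b, is_safe: false}
  ]
}

SafeResults.images(output)
#=> [:image_a]
```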

Functions

text_to_image(encoder, unet, vae, tokenizer, scheduler, opts \\ [])

Build serving for prompt-driven image generation.

The serving accepts text_to_image_input/0 and returns text_to_image_output/0. A list of inputs is also supported.

You can specify a :safety_checker model to automatically detect when a generated image is offensive or harmful and filter it out.

Options

  • :safety_checker - the safety checker model info map. When a safety checker is used, each output entry has an additional :is_safe property and unsafe images are automatically zeroed. Make sure to also set :safety_checker_featurizer

  • :safety_checker_featurizer - the featurizer to use to preprocess the safety checker input images

  • :num_steps - the number of denoising steps. More denoising steps usually lead to higher image quality at the expense of slower inference. Defaults to 50

  • :num_images_per_prompt - the number of images to generate for each prompt. Defaults to 1

  • :guidance_scale - the scale used for classifier-free diffusion guidance. Higher guidance scale makes the generated images more closely reflect the text prompt. This parameter corresponds to $\omega$ in Equation (2) of the Imagen paper. Defaults to 7.5

  • :compile - compiles all computations for predefined input shapes during serving initialization. Should be a keyword list with the following keys:

    • :batch_size - the maximum batch size of the input. Inputs are optionally padded to always match this batch size

    • :sequence_length - the maximum input sequence length. Input sequences are always padded/truncated to match that length

    It is advised to set this option in production and also configure a defn compiler using :defn_options to maximally reduce inference time.

  • :defn_options - the options for JIT compilation. Defaults to []

  • :preallocate_params - when true, explicitly allocates params on the device configured by :defn_options. You may want to set this option when using partitioned serving, to allocate params on each of the devices. When using this option, you should first load the parameters into the host. This can be done by passing backend: {EXLA.Backend, client: :host} to Bumblebee.load_model/2 and related loading functions. Defaults to false
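
For reference, the classifier-free guidance rule that :guidance_scale refers to can be written as below. The notation follows Equation (2) of the Imagen paper and is not necessarily how Bumblebee computes it internally: $\epsilon_\theta(z_t, c)$ is the noise prediction conditioned on the text prompt $c$, and $\epsilon_\theta(z_t)$ is the unconditional prediction.

$$
\tilde{\epsilon}_\theta(z_t, c) = \omega\,\epsilon_\theta(z_t, c) + (1 - \omega)\,\epsilon_\theta(z_t)
$$

With $\omega = 1$ this reduces to the conditional prediction alone. With the default $\omega = 7.5$ it becomes $7.5\,\epsilon_\theta(z_t, c) - 6.5\,\epsilon_\theta(z_t)$, which pushes the result further from the unconditional prediction and toward the prompt.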

Examples

repository_id = "CompVis/stable-diffusion-v1-4"

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})
{:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"})
{:ok, unet} = Bumblebee.load_model({:hf, repository_id, subdir: "unet"})
{:ok, vae} = Bumblebee.load_model({:hf, repository_id, subdir: "vae"}, architecture: :decoder)
{:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, repository_id, subdir: "feature_extractor"})
{:ok, safety_checker} = Bumblebee.load_model({:hf, repository_id, subdir: "safety_checker"})

serving =
  Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler,
    num_steps: 20,
    num_images_per_prompt: 2,
    safety_checker: safety_checker,
    safety_checker_featurizer: featurizer,
    compile: [batch_size: 1, sequence_length: 60],
    defn_options: [compiler: EXLA]
  )

prompt = "numbat in forest, detailed, digital art"
Nx.Serving.run(serving, prompt)
#=> %{
#=>   results: [
#=>     %{
#=>       image: #Nx.Tensor<
#=>         u8[512][512][3]
#=>         ...
#=>       >,
#=>       is_safe: true
#=>     },
#=>     %{
#=>       image: #Nx.Tensor<
#=>         u8[512][512][3]
#=>         ...
#=>       >,
#=>       is_safe: true
#=>     }
#=>   ]
#=> }