Stable Diffusion w/ LCM LoRA

Mix.install(
  [
    {:bumblebee, git: "https://github.com/elixir-nx/bumblebee.git"},
    {:nx, "~> 0.6.1", override: true},
    {:exla, "~> 0.6.1"},
    {:kino, "~> 0.11.0"},
    {:lorax, git: "https://github.com/wtedw/lorax.git"},
    {:axon, [env: :prod, git: "https://github.com/elixir-nx/axon.git", override: true]},
    {:req, "~> 0.4.0"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

Nx.global_default_backend(EXLA.Backend)

Load SD 1.5, the LCM LoRA, and the LCM scheduler

LCM LoRA requires two key things:

  1. Loading the LoRA adapter that matches the Stable Diffusion UNet model. We'll be using the LoRA file for SD 1.5.
  2. Using the LCM scheduler (Bumblebee.Diffusion.LcmScheduler). The adapter will not work with any other scheduler.

We'll first download the safetensors file from Hugging Face and use the Lorax library to convert the parameters into a form Axon can use. In addition to converting the params, Lorax returns a config that describes how the new parameters should be injected into the Axon model.

repo_id = "runwayml/stable-diffusion-v1-5"
opts = [params_variant: "fp16"]

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})
{:ok, clip} = Bumblebee.load_model({:hf, repo_id, subdir: "text_encoder"}, opts)
{:ok, unet} = Bumblebee.load_model({:hf, repo_id, subdir: "unet"}, opts)
{:ok, vae} = Bumblebee.load_model({:hf, repo_id, subdir: "vae"}, [architecture: :decoder] ++ opts)
{:ok, scheduler} = Bumblebee.load_scheduler({:hf, repo_id, subdir: "scheduler"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, repo_id, subdir: "feature_extractor"})
{:ok, safety_checker} = Bumblebee.load_model({:hf, repo_id, subdir: "safety_checker"}, opts)

# Option #1, Download LCM LoRA from HF
resp =
  Req.get!(
    "https://huggingface.co/latent-consistency/lcm-lora-sdv1-5/resolve/main/pytorch_lora_weights.safetensors?download=true"
  )

param_data = resp.body
{config, lcm_lora_params} = Lorax.Lcm.load!(param_data)

# Option #2, Load Locally
# {config, lcm_lora_params} =
#   Lorax.Lcm.read!("/Users/[user]/.../pytorch_lora_weights.safetensors")
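
Either way, we end up with an injection config plus a plain map of LoRA tensors. To get a feel for both, we can inspect them (a minimal sketch; the exact config fields depend on the Lorax version):

# Peek at the injection config; :limit keeps the output short.
IO.inspect(config, limit: 4)

# lcm_lora_params is a plain map (we Map.merge it below), so map_size/1
# tells us how many parameter entries the adapter contributes.
IO.puts("LoRA param entries: #{map_size(lcm_lora_params)}")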

# Axon expects one map containing all the layer names -> tensors,
# so we'll merge the LCM params with the original SD params.
lcm_unet = %{
  unet
  | model: Lorax.inject(unet.model, config),
    params: Map.merge(lcm_lora_params, unet.params)
}
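
Note that Map.merge/2 gives precedence to the second map, so if an adapter key ever collided with an existing layer name, the original UNet weights would win; the LoRA weights come in under their own names. A quick sanity check that no original entries were lost:

# Every original UNet key should still be present after the merge.
true =
  unet.params
  |> Map.keys()
  |> Enum.all?(&Map.has_key?(lcm_unet.params, &1))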

lcm_scheduler = %Bumblebee.Diffusion.LcmScheduler{}

Create serving + Kino prompts

serving =
  Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, lcm_unet, vae, tokenizer, lcm_scheduler,
    num_steps: 4,
    num_images_per_prompt: 1,
    safety_checker_featurizer: featurizer,
    guidance_scale: 1.0,
    compile: [batch_size: 1, sequence_length: 60],
    defn_options: [compiler: EXLA]
  )

# For comparison: the same serving with the regular SD 1.5 UNet and scheduler
# serving =
#   Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler,
#     num_steps: 4,
#     num_images_per_prompt: 1,
#     safety_checker_featurizer: featurizer,
#     compile: [batch_size: 1, sequence_length: 60],
#     defn_options: [compiler: EXLA]
#   )

prompt_input =
  Kino.Input.text("Prompt",
    default: "astronaut in the desert, high quality, detailed"
  )

negative_prompt_input = Kino.Input.text("Negative Prompt", default: "blurry")

Kino.Layout.grid([prompt_input, negative_prompt_input])

Text to Image

Some notes about generation:

  • Although LCM LoRA speeds up image generation, it's sensitive to "correct" prompts.
  • Some random seeds produce good images, while others produce complete gibberish (especially if your prompt is a single word like "turtle"). A seed sweep like the sketch below helps find reproducible ones.
  • The LCM authors suggest 2-8 steps as a good range, but I found that 4 steps with guidance_scale = 1.0 works best. With 8 steps, too much denoising happens and the images appear oversaturated.
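
Since seeds matter, it can be worth rendering the same prompt under a few fixed seeds and keeping the ones that work. A minimal sketch, assuming your Bumblebee version accepts a :seed field in the serving input (older versions take :seed as an option to text_to_image instead):

# Hypothetical seed sweep: same prompt, four fixed seeds, shown in a grid.
for seed <- [0, 1, 2, 3] do
  output =
    Nx.Serving.run(serving, %{
      prompt: "astronaut in the desert, high quality, detailed",
      seed: seed
    })

  [result] = output.results
  Kino.Image.new(result.image)
end
|> Kino.Layout.grid(columns: 2)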

We are ready to generate images!

prompt = Kino.Input.read(prompt_input)
negative_prompt = Kino.Input.read(negative_prompt_input)
output = Nx.Serving.run(serving, %{prompt: prompt, negative_prompt: negative_prompt})

for result <- output.results do
  Kino.Image.new(result.image)
end
|> Kino.Layout.grid(columns: 2)
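
If you want to keep the results, one option is the :stb_image package (not part of the Mix.install above, so it would need to be added, e.g. {:stb_image, "~> 0.6"}). A minimal sketch:

# Write each generated image to disk as a PNG.
output.results
|> Enum.with_index()
|> Enum.each(fn {result, i} ->
  result.image
  |> StbImage.from_nx()
  |> StbImage.write_file("lcm_output_#{i}.png")
end)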