Stable Diffusion w/ LCM (Realtime)

Mix.install(
  [
    {:bumblebee, git: "https://github.com/elixir-nx/bumblebee.git"},
    {:nx, "~> 0.6.1", override: true},
    {:exla, "~> 0.6.1"},
    {:kino, "~> 0.11.0"},
    {:lorax, git: "https://github.com/wtedw/lorax.git"},
    {:axon, [env: :prod, git: "https://github.com/elixir-nx/axon.git", override: true]},
    {:req, "~> 0.4.0"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

Nx.global_default_backend(EXLA.Backend)

Load SD1.5, LCM LoRA, and LCM scheduler

LCM LoRA requires two things:

  1. The LoRA adapter that matches the Stable Diffusion UNet in use. We'll use the LoRA file for SD 1.5.
  2. The LCMScheduler. The adapter will not work with any other scheduler.

We'll first download the safetensors file from Hugging Face (HF) and call the Lorax library to convert the parameters into a form Axon can use. Besides converting the params, Lorax returns a config that describes how the new parameters should be injected into the Axon model.
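For intuition about what "injecting" LoRA parameters means, here is a conceptual sketch of a low-rank update on a single linear layer. It uses plain lists instead of tensors, and the module name `LoraSketch`, the matrices, and the `scale` value are all made up for illustration. It is not Lorax's code and makes no claim about how Lorax wires the layers internally.

```elixir
defmodule LoraSketch do
  # Hypothetical helper for illustration only; not part of Lorax or Axon.

  # Multiply a matrix (list of rows) by a vector (list of numbers).
  def matvec(matrix, vector) do
    Enum.map(matrix, fn row ->
      row
      |> Enum.zip(vector)
      |> Enum.map(fn {a, b} -> a * b end)
      |> Enum.sum()
    end)
  end

  # LoRA-style forward pass: y = W x + scale * B (A x)
  # W stays frozen; only the small matrices A and B come from the adapter.
  def forward(w, a, b, scale, x) do
    base = matvec(w, x)
    delta = matvec(b, matvec(a, x))
    Enum.zip_with(base, delta, fn y, d -> y + scale * d end)
  end
end

w = [[1.0, 0.0], [0.0, 1.0]]  # frozen base weight (2x2)
a = [[1.0, 1.0]]              # down projection to rank 1 (1x2)
b = [[0.5], [0.0]]            # up projection back to 2 outputs (2x1)
x = [2.0, 3.0]

output = LoraSketch.forward(w, a, b, 1.0, x)
IO.inspect(output)
```

The base layer alone would return `[2.0, 3.0]`; the adapter path adds a small correction on top without touching `w`, which is why the adapter parameters can be shipped as a separate, much smaller file.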

repo_id = "runwayml/stable-diffusion-v1-5"
opts = [params_variant: "fp16"]

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})
{:ok, clip} = Bumblebee.load_model({:hf, repo_id, subdir: "text_encoder"}, opts)
{:ok, unet} = Bumblebee.load_model({:hf, repo_id, subdir: "unet"}, opts)
{:ok, vae} = Bumblebee.load_model({:hf, repo_id, subdir: "vae"}, [architecture: :decoder] ++ opts)
{:ok, scheduler} = Bumblebee.load_scheduler({:hf, repo_id, subdir: "scheduler"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, repo_id, subdir: "feature_extractor"})
{:ok, safety_checker} = Bumblebee.load_model({:hf, repo_id, subdir: "safety_checker"}, opts)

# Option #1, Download LCM LoRA from HF
resp =
  Req.get!(
    "https://huggingface.co/latent-consistency/lcm-lora-sdv1-5/resolve/main/pytorch_lora_weights.safetensors?download=true"
  )

param_data = resp.body
{config, lcm_lora_params} = Lorax.Lcm.load!(param_data)

# Option #2, Load Locally
# {config, lcm_lora_params} =
#   Lorax.Lcm.read!("/Users/[user]/.../pytorch_lora_weights.safetensors")

# Axon expects a single map of layer name => tensors, so we merge the LCM LoRA
# params with the original SD params (on duplicate keys, unet.params wins).
lcm_unet = %{
  unet
  | model: Lorax.inject(unet.model, config),
    params: Map.merge(lcm_lora_params, unet.params)
}

lcm_scheduler = %Bumblebee.Diffusion.LcmScheduler{}
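The parameter merge above relies on `Map.merge/2` semantics: the result contains the keys of both maps, and when a key exists in both, the value from the second map is kept. A minimal sketch with hypothetical layer names, using short lists as stand-ins for tensors:

```elixir
# Hypothetical layer names; real names come from Lorax and Bumblebee.
lora_params = %{"lora_down_0" => [0.1], "lora_up_0" => [0.2]}
base_params = %{"conv_in" => [1.0], "conv_out" => [2.0]}

# Disjoint keys: the merged map simply holds all four entries.
merged = Map.merge(lora_params, base_params)
IO.inspect(Map.keys(merged))

# Overlapping key: the second argument takes precedence.
conflict = Map.merge(%{"shared" => :from_lora}, %{"shared" => :from_base})
IO.inspect(conflict)
```

With the argument order used in this notebook, the original SD parameters would take precedence if an adapter key ever collided with a base model key. This assumes the adapter's layer names are distinct from the base model's, in which case the order does not matter.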

Create serving

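serving =
  Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, lcm_unet, vae, tokenizer, lcm_scheduler,
    # LCM needs only a few denoising steps
    num_steps: 4,
    num_images_per_prompt: 1,
    # Only the featurizer is passed; the safety_checker model loaded earlier is not used here
    safety_checker_featurizer: featurizer,
    guidance_scale: 1.0,
    # Compile once for a fixed input shape: one prompt, 60 tokens
    compile: [batch_size: 1, sequence_length: 60],
    defn_options: [compiler: EXLA]
  )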

Realtime LCM LoRA

Run all the following cells, then start typing in the prompt below. The Kino frame beneath the prompt should update with new images as you type. Change the debounce value (in milliseconds) to control how quickly the images refresh.
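To illustrate what debouncing does, here is a conceptual sketch in plain Elixir: each new value restarts a quiet-period timer, and only the last value is emitted once typing pauses. The `DebounceSketch` module is hypothetical and written for this explanation only; it is not how Kino implements the `debounce` option.

```elixir
defmodule DebounceSketch do
  # Hypothetical helper for illustration only; not part of Kino.
  use GenServer

  def start_link(delay_ms, callback) do
    GenServer.start_link(__MODULE__, {delay_ms, callback})
  end

  # Report a new value; this restarts the quiet-period timer.
  def push(pid, value), do: GenServer.cast(pid, {:push, value})

  @impl true
  def init({delay_ms, callback}) do
    {:ok, %{delay_ms: delay_ms, callback: callback, timer: nil, ref: nil}}
  end

  @impl true
  def handle_cast({:push, value}, state) do
    # Cancel the pending timer, if any, and schedule a fresh one.
    if state.timer, do: Process.cancel_timer(state.timer)
    ref = make_ref()
    timer = Process.send_after(self(), {:fire, ref, value}, state.delay_ms)
    {:noreply, %{state | timer: timer, ref: ref}}
  end

  @impl true
  def handle_info({:fire, ref, value}, %{ref: ref} = state) do
    # The quiet period elapsed without a newer value: emit it.
    state.callback.(value)
    {:noreply, %{state | timer: nil, ref: nil}}
  end

  # A timer message that was already superseded; ignore it.
  def handle_info({:fire, _ref, _value}, state), do: {:noreply, state}
end

parent = self()
{:ok, pid} = DebounceSketch.start_link(150, fn v -> send(parent, {:emitted, v}) end)

# Simulate fast typing: three values pushed back to back.
for prompt <- ["a", "a c", "a cat"], do: DebounceSketch.push(pid, prompt)

emitted =
  receive do
    {:emitted, value} -> value
  after
    1_000 -> :nothing
  end

IO.inspect(emitted)
```

Only the final prompt reaches the callback, so the expensive step (image generation in this notebook) runs once per pause in typing rather than once per keystroke. A larger delay means fewer generations but a slower visual response.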

This has been tested on the NVIDIA A10G and above. Use a high-end graphics card; otherwise the frame below will stay blank for quite a long time.

input = Kino.Input.textarea("Prompt", debounce: 150)

# Images will be rendered in this frame
frame = Kino.Frame.new()

Kino.listen(input, fn %{value: prompt} ->
  # Generate the image and put it in the frame
  output = Nx.Serving.run(serving, %{prompt: prompt})
  [result] = output.results
  image = Kino.Image.new(result.image)
  Kino.Frame.render(frame, image)
end)