Running LLMs with LoRA

Mix.install([
  {:bumblebee, "~> 0.5.3"},
  {:axon, "~> 0.6.1"},
  {:nx, "~> 0.7.1"},
  {:exla, "~> 0.7.1"},
  {:explorer, "~> 0.7.0"},
  {:lorax, "~> 0.2.1"},
  {:req, "~> 0.4.0"},
  {:kino, "~> 0.12.3"}
])

Nx.default_backend(EXLA.Backend)

Introduction

This notebook demonstrates how to run a text-generation model such as GPT-2 with a LoRA adapter. The basic steps are:

  1. Define the model exactly as it was during fine-tuning.
  2. Merge the LoRA parameters with the base model parameters.
  3. Run the model.

We'll use a LoRA file trained on posts from the Elixir Chat/Discussion section. The data can be found here.

Load original model

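# Load the GPT-2 spec, model, tokenizer, and generation config from Hugging Face
{:ok, spec} = Bumblebee.load_spec({:hf, "gpt2"})
{:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"}, spec: spec)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})

# Split the loaded model into the Axon graph and the pretrained GPT-2 parameters
%{model: model, params: gpt2_params} = model_info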

:ok
:ok

Define LoRA Model

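Inject the LoRA layers exactly as they were during fine-tuning. The LoRA hyperparameters and target layers must match the training run; otherwise the loaded parameters may not line up with the layer names and shapes in the compiled model.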

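# LoRA hyperparameters; these must match the values used during fine-tuning
r = 4
lora_alpha = 8
lora_dropout = 0.05

lora_model =
  model
  # Freeze the base GPT-2 weights, mirroring the fine-tuning setup
  |> Axon.freeze()
  # Inject LoRA adapters into the key, query, and value projections
  |> Lorax.inject(%Lorax.Config{
    r: r,
    alpha: lora_alpha,
    dropout: lora_dropout,
    target_key: true,
    target_query: true,
    target_value: true
  })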
#Axon<
  inputs: %{"attention_head_mask" => {12, 12}, "attention_mask" => {nil, nil}, "cache" => nil, "input_embeddings" => {nil, nil, 768}, "input_ids" => {nil, nil}, "position_ids" => {nil, nil}}
  outputs: "container_37"
  nodes: 895
>
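For intuition about what r and lora_alpha control, here is a plain-Elixir sketch of the standard LoRA update for a single linear projection, written with lists instead of Nx tensors. It assumes the common formulation y = x·W + (alpha / r)·(x·A)·B, where A is d_in × r and B is r × d_out. The LoraSketch module and its functions are hypothetical illustrations, not Lorax's actual implementation, and dropout is omitted because it is disabled at inference.

```elixir
defmodule LoraSketch do
  # Row-vector convention: x is a list of length d_in, matrices are lists of rows.

  # Dot product of two equal-length lists.
  def dot(a, b), do: a |> Enum.zip(b) |> Enum.map(fn {p, q} -> p * q end) |> Enum.sum()

  # Multiply a row vector by a matrix with d_in rows and d_out columns.
  def vec_mat(x, m) do
    m
    |> Enum.zip()
    |> Enum.map(fn column -> dot(x, Tuple.to_list(column)) end)
  end

  # Scaling applied to the low-rank update in the standard LoRA formulation.
  def scaling(alpha, r), do: alpha / r

  # y = x·W + (alpha / r)·(x·A)·B, with A of shape d_in × r and B of shape r × d_out.
  # Dropout is omitted because it is disabled at inference.
  def forward(x, w, a, b, alpha, r) do
    base = vec_mat(x, w)
    update = x |> vec_mat(a) |> vec_mat(b)
    s = scaling(alpha, r)

    base
    |> Enum.zip(update)
    |> Enum.map(fn {y0, delta} -> y0 + s * delta end)
  end

  # Extra trainable parameters a rank-r adapter adds to one d_in × d_out projection.
  def adapter_params(d_in, d_out, r), do: d_in * r + r * d_out
end
```

With this notebook's r = 4 and lora_alpha = 8, the update is scaled by 8 / 4 = 2.0. For a 768 × 768 attention projection in GPT-2 small, a rank-4 adapter adds 768·4 + 4·768 = 6,144 parameters, compared with 589,824 in the full weight matrix, which is why the LoRA file is so much smaller than the base model.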

Load LoRA Params

There are three main ways to load your LoRA params:

  1. URL: use any HTTP library to fetch the binary, then call Nx.deserialize to get the map of tensor values.
  2. File: same flow as URL, but read the binary with File.read before calling Nx.deserialize. For convenience, you can use Lorax.Params.file_load!().
  3. Kino: GUI file picker. This is a two-step process: one cell renders the Kino file input, and another cell reads the selected file.

See the code below for an example of each method.

# Method 1: URL download
lora_serialized =
  Req.get!("https://raw.githubusercontent.com/wtedw/lorax/main/params/elixir-chat-r4a8.lorax").body
lora_only_params = Nx.deserialize(lora_serialized)
lora_only_params |> Map.keys()

# Method 2: Local file
# lora_only_params = Lorax.Params.file_load!("<insert local path>")

# Method 3: Kino
# cell #1
# input = Kino.Input.file("Lorax Params")
#
# cell #2
# lora_only_params = Lorax.Params.kino_file_load!(input)
# merged_params = Lorax.Params.merge_params(lora_only_params, gpt2_params)
["dropout_28", "decoder.blocks.5.self_attention_dropout", "lora_11", "lora_31", "lora_10",
 "decoder.blocks.11.self_attention_dropout", "dropout_9", "lora_28", "dropout_25",
 "decoder.blocks.1.self_attention_dropout", "lora_13", "lora_20", "lora_7", "dropout_6", "lora_19",
 "decoder.blocks.2.self_attention_dropout", "lora_5", "lora_30",
 "decoder.blocks.7.self_attention_dropout", "lora_2", "lora_24", "dropout_1", "lora_6",
 "dropout_10", "dropout_3", "dropout_15", "lora_27", "dropout_7", "lora_35", "lora_22",
 "dropout_21", "decoder.blocks.3.self_attention_dropout", "lora_0", "dropout_27", "dropout_24",
 "lora_9", "decoder.blocks.4.self_attention_dropout", "lora_32", "dropout_34",
 "decoder.blocks.10.self_attention_dropout", "lora_23", "dropout_0",
 "decoder.blocks.6.self_attention_dropout", "lora_12", "dropout_16", "lora_34", "dropout_19",
 "lora_18", "lora_3", "lora_17", ...]

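The key list above mixes three kinds of entries: lora_N layers, generic dropout_N layers, and per-block self_attention_dropout layers. A small helper like the hypothetical KeySummary module below can tally them. GPT-2 small has 12 decoder blocks and the config targets three projections per block (key, query, value), so one would expect 36 LoRA layers, which is consistent with lora_35 appearing in the truncated list.

```elixir
defmodule KeySummary do
  # Classify a parameter-map key by its naming pattern.
  def kind("lora_" <> _), do: :lora
  def kind("dropout_" <> _), do: :dropout

  def kind(key) do
    if String.ends_with?(key, "self_attention_dropout"), do: :attention_dropout, else: :other
  end

  # Count keys per kind, returning a map such as %{lora: ..., dropout: ...}.
  def count_by_kind(keys), do: Enum.frequencies_by(keys, &kind/1)

  # Expected number of LoRA layers: one per targeted projection per block.
  def expected_lora_layers(num_blocks, targets_per_block), do: num_blocks * targets_per_block
end
```

In the notebook, this sketch could be applied as lora_only_params |> Map.keys() |> KeySummary.count_by_kind() and compared against KeySummary.expected_lora_layers(12, 3).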
Merge Params

Axon expects a single map containing all parameter values. The LoRA file only holds the adapter parameters, so we need to merge them into the original GPT-2 parameters.
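Map.merge/2 keeps every key from both maps and, when a key appears in both, takes the value from the second argument. With Map.merge(gpt2_params, lora_only_params), any key present in the LoRA file therefore overrides the base value, and because the merge is shallow, a clashing layer's entire inner map is replaced. The hypothetical MergeDemo module below shows this with placeholder atoms instead of tensors and made-up key names, and computes the overlapping keys as a sanity check.

```elixir
defmodule MergeDemo do
  # Placeholder parameter maps: key names are hypothetical, atoms stand in for tensors.
  def run do
    base_params = %{
      "decoder.blocks.0.self_attention.query" => %{"kernel" => :pretrained_kernel},
      "shared.layer" => %{"key" => :from_base}
    }

    lora_params = %{
      "lora_0" => %{"lora_a" => :adapter_a, "lora_b" => :adapter_b},
      "shared.layer" => %{"key" => :from_lora_file}
    }

    # Right-hand map wins on clashes; the merge is shallow, so the whole
    # inner map for "shared.layer" comes from lora_params.
    merged = Map.merge(base_params, lora_params)

    # Keys present in both maps, i.e. the entries the LoRA file overrides.
    overlap =
      MapSet.intersection(
        MapSet.new(Map.keys(base_params)),
        MapSet.new(Map.keys(lora_params))
      )

    {merged, overlap}
  end
end
```

Argument order matters here: swapping the two maps would let the base parameters win on any shared key.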

merged_params = Map.merge(gpt2_params, lora_only_params)

:ok
:ok

Inference Prepwork

The sampling strategy used here is non-deterministic, so the output differs every time you run this cell and the one below. A top_p of 0.6 may produce more coherent sentences but tends to repeat itself; values between 0.7 and 0.8 are a good sweet spot.
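To see why a lower top_p makes the output more conservative, here is a plain-Elixir sketch of nucleus (top-p) filtering as it is commonly formulated: sort tokens by probability, keep the smallest prefix whose cumulative probability reaches top_p, renormalize, and sample from what remains. The TopP module is hypothetical and is not Bumblebee's implementation.

```elixir
defmodule TopP do
  # Keep the most likely tokens until their cumulative probability reaches
  # top_p, then renormalize the kept probabilities so they sum to 1.
  def filter(probs, top_p) when top_p > 0 and top_p <= 1 do
    {kept, _cumulative} =
      probs
      |> Enum.sort_by(fn {_token, p} -> p end, :desc)
      |> Enum.reduce_while({[], 0.0}, fn {token, p}, {acc, cumulative} ->
        acc = [{token, p} | acc]
        cumulative = cumulative + p

        if cumulative >= top_p,
          do: {:halt, {acc, cumulative}},
          else: {:cont, {acc, cumulative}}
      end)

    total = kept |> Enum.map(fn {_token, p} -> p end) |> Enum.sum()

    kept
    |> Enum.reverse()
    |> Enum.map(fn {token, p} -> {token, p / total} end)
  end

  # Draw one token by walking the cumulative distribution with a uniform
  # random number in [0, 1); the last token absorbs any rounding error.
  def sample(filtered, rand \\ :rand.uniform())
  def sample([{token, _p}], _rand), do: token
  def sample([{token, p} | _rest], rand) when rand <= p, do: token
  def sample([{_token, p} | rest], rand), do: sample(rest, rand - p)
end
```

For example, with probabilities 0.5, 0.3, 0.15, and 0.05 and a top_p of 0.7, only the two most likely tokens survive (their cumulative probability is 0.8), renormalized to 0.625 and 0.375. A smaller top_p keeps fewer candidates, which tends toward safer but more repetitive text.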

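# Bundle the LoRA-injected graph with the merged parameters in the shape Bumblebee expects
lora_model_info = %{model: lora_model, params: merged_params, spec: spec}

# Generate up to 512 new tokens with multinomial sampling restricted to the top-p nucleus
lora_generation_config =
  Bumblebee.configure(generation_config,
    max_new_tokens: 512,
    strategy: %{type: :multinomial_sampling, top_p: 0.7}
  )

# Build a streaming text-generation serving compiled for one prompt of up to 512 tokens
serving =
  Bumblebee.Text.generation(lora_model_info, tokenizer, lora_generation_config,
    compile: [batch_size: 1, sequence_length: 512],
    stream: true,
    defn_options: [compiler: EXLA, lazy_transfers: :always]
  )

# Start the serving under Kino's supervisor, registered under the name Llama
Kino.start_child({Nx.Serving, name: Llama, serving: serving})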
{:ok, #PID<0.9144.0>}

Inference Results

The training data is formatted like this:

<title>Ideas for an Open Source Discord</title>

<author>WildYorkies</author>

I remember seeing on the Elixir subreddit [...]

<likes>3 likes</likes>

You can simulate a thread by seeding the generation with a <title>text</title> string; the model then predicts how the Elixir community might reply.
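Because the model continues text in the training format, a few string helpers can make prompting and post-processing easier. The ThreadFormat module below is a hypothetical sketch: prompt/1 wraps a title in <title> tags, collect/1 joins the chunks from a streaming serving into one string, and posts/1 uses a regex (an assumption about the format, based on the sample above) to extract complete author/body/likes posts, skipping posts whose author tag is never closed, such as a truncated final post.

```elixir
defmodule ThreadFormat do
  # Seed a generation with a thread title in the training format.
  def prompt(title), do: "<title>#{title}</title>"

  # Join streamed text chunks (any enumerable of strings) into one string.
  def collect(chunks), do: Enum.join(chunks)

  # Extract complete posts as maps; unclosed or truncated posts are ignored.
  def posts(text) do
    post_regex()
    |> Regex.scan(text, capture: :all_but_first)
    |> Enum.map(fn [author, body, likes] ->
      %{author: author, body: body, likes: String.to_integer(likes)}
    end)
  end

  # Assumed post layout: <author>NAME</author> BODY <likes>N like(s)</likes>
  defp post_regex do
    ~r{<author>([^<]*)</author>\s*(.*?)\s*<likes>(\d+) likes?</likes>}s
  end
end
```

In the notebook, this could be chained as Nx.Serving.batched_run(Llama, ThreadFormat.prompt("...")) |> ThreadFormat.collect() |> ThreadFormat.posts(), since the streaming serving returns an enumerable of text chunks.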

Nx.Serving.batched_run(
  Llama,
  "<title>Elixir 2.0 is released! New features include</title>"
)
|> Enum.each(&IO.write/1)


<author>xjalasz</author>

Elixir 2.0 is released! New features include
This means that you can now deploy your Elixir 2 projects without having to use a tool like docker.

<likes>1 like</likes>

<author>jake</author>

As always, thanks for the help.
I am on a 10 day cruise to Toronto in July with the goal of finishing up Elixir 2.0 in less than one week.

<likes>1 like</likes>

<author>pianotato</author>

Thanks for all the support!

<likes>0 likes</likes>

<author>thekali</author>

Thanks for your time and effort.

<likes>0 likes</likes>

<author>mangai</author>

Thank you! I have no idea how long you have been on this blog.

<likes>1 like</likes>

<author>juicy_champion</author>

Thanks!

<likes>0 likes</likes>

<author>haha<likes>1 like</likes>

<author>thompson3</author>

Thank you for the support!

<likes>0 likes</likes>

<author>phrobotics</author>

Thanks for your time and effort.

<likes>0 likes</likes>

<author>yelp</author>

Thanks!

<likes>0 likes</likes>

<author>phrobotics</author>

Thanks!

<likes>0 likes</likes>

<author>jake</author>

Thanks for your time and effort.

<likes>0 likes</likes>

<author>jake</author>

Thanks!

<likes>0 likes</likes>

<author>juicy_champion</author>

Thanks!

<likes>0 likes</likes>

<author>juicy_champion</author>

Thanks!

<likes>0 likes</likes>

<author>gw
:ok