Running LLMs with LoRA
Mix.install([
  {:bumblebee, "~> 0.4.2"},
  {:axon, "~> 0.6.0"},
  {:nx, "~> 0.6.1"},
  {:exla, "~> 0.6.1"},
  {:explorer, "~> 0.7.0"},
  {:lorax, "~> 0.1.0"},
  {:req, "~> 0.4.0"},
  {:kino, "~> 0.11.0"}
])
Nx.default_backend(EXLA.Backend)
Introduction
This notebook demonstrates how to run a text-generation model like GPT-2 with LoRA. The basic steps:
- Define the model exactly as it was during fine-tuning
- Merge the LoRA parameters with the base model parameters
- Run the model
We'll be using a LoRA file that was trained on the Elixir Forum's Chat/Discussion section. The data can be found here.
Load original model
{:ok, spec} = Bumblebee.load_spec({:hf, "gpt2"})
{:ok, model} = Bumblebee.load_model({:hf, "gpt2"}, spec: spec)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})
# Extract the Axon graph and the pretrained parameters from the model info map
%{model: model, params: gpt2_params} = model
:ok
:ok
Define LoRA Model
Inject LoRA layers with the same configuration used during fine-tuning, so the model compiles correctly and the adapter parameter shapes match the checkpoint.
r = 4               # rank of the low-rank update matrices
lora_alpha = 8      # update scaling factor (effective scale is alpha / r)
lora_dropout = 0.05 # dropout applied to the LoRA branch
lora_model =
  model
  |> Axon.freeze()
  |> Lorax.inject(%Lorax.Config{
    r: r,
    alpha: lora_alpha,
    dropout: lora_dropout,
    # apply LoRA to the attention key, query, and value projections
    target_key: true,
    target_query: true,
    target_value: true
  })
#Axon<
  inputs: %{"attention_head_mask" => {12, 12}, "attention_mask" => {nil, nil}, "cache" => nil, "input_embeddings" => {nil, nil, 768}, "input_ids" => {nil, nil}, "position_ids" => {nil, nil}}
  outputs: "container_37"
  nodes: 895
>
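For intuition: each injected adapter leaves the frozen weight W untouched and adds a low-rank update on the side, h = W·x + (alpha / r) · B·A·x, where A is r × d and B is d × r, so only a small number of extra parameters are trained. Below is a minimal Nx sketch of that computation; the shapes and matrix names are illustrative, not Lorax's internal representation.

# Illustrative sketch only: the LoRA update for a single dense projection.
# The base weight w stays frozen; only the small a and b matrices are trained.
key = Nx.Random.key(0)
{w, key} = Nx.Random.normal(key, shape: {768, 768})  # frozen base weight
{a, key} = Nx.Random.normal(key, shape: {4, 768})    # LoRA "down" projection (r = 4)
{b, key} = Nx.Random.normal(key, shape: {768, 4})    # LoRA "up" projection
{x, _key} = Nx.Random.normal(key, shape: {768})      # input activation

scaling = 8 / 4  # alpha / r

h = Nx.add(Nx.dot(w, x), Nx.multiply(scaling, Nx.dot(b, Nx.dot(a, x))))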
Load LoRA Params
There are three main ways to load your LoRA params:
- URL: use any HTTP library to fetch the binary, then Nx.deserialize/1 to get the map of tensor values.
- File: same flow as URL, except the binary comes from File.read/1 before Nx.deserialize/1. For convenience, you can use Lorax.Params.file_load!/1.
- Kino: GUI file picker. This requires a two-step process: one cell holds the Kino file input, and another cell reads the uploaded file.
See the code below for examples.
# Method 1: URL download
lora_serialized =
  Req.get!("https://raw.githubusercontent.com/wtedw/lorax/main/params/elixir-chat-r4a8.lorax").body
lora_only_params = Nx.deserialize(lora_serialized)
lora_only_params |> Map.keys()
# Method 2: Local file
# lora_only_params = Lorax.Params.file_load!("<insert local path>")
# Method 3: Kino
# cell #1
# input = Kino.Input.file("Lorax Params")
#
# cell #2
# lora_only_params = Lorax.Params.kino_file_load!(input)
# merged_params = Lorax.Params.merge_params(lora_only_params, gpt2_params)
["dropout_28", "decoder.blocks.5.self_attention_dropout", "lora_11", "lora_31", "lora_10",
"decoder.blocks.11.self_attention_dropout", "dropout_9", "lora_28", "dropout_25",
"decoder.blocks.1.self_attention_dropout", "lora_13", "lora_20", "lora_7", "dropout_6", "lora_19",
"decoder.blocks.2.self_attention_dropout", "lora_5", "lora_30",
"decoder.blocks.7.self_attention_dropout", "lora_2", "lora_24", "dropout_1", "lora_6",
"dropout_10", "dropout_3", "dropout_15", "lora_27", "dropout_7", "lora_35", "lora_22",
"dropout_21", "decoder.blocks.3.self_attention_dropout", "lora_0", "dropout_27", "dropout_24",
"lora_9", "decoder.blocks.4.self_attention_dropout", "lora_32", "dropout_34",
"decoder.blocks.10.self_attention_dropout", "lora_23", "dropout_0",
"decoder.blocks.6.self_attention_dropout", "lora_12", "dropout_16", "lora_34", "dropout_19",
"lora_18", "lora_3", "lora_17", ...]
Merge Params
Axon expects a single map containing all of the parameter values, so although we've loaded the LoRA params, we still need to merge them with the original GPT-2 parameters.
merged_params = Map.merge(gpt2_params, lora_only_params)
:ok
:ok
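One note on Map.merge/2: when both maps contain the same key, the value from the second map wins, which is why lora_only_params is passed second — an adapter entry takes precedence over a base entry of the same name. The keys below are made up just to show the semantics:

# Map.merge/2 keeps the second map's value on key collisions.
Map.merge(%{"blocks.0" => :base, "shared" => :base}, %{"lora_0" => :adapter, "shared" => :adapter})
#=> %{"blocks.0" => :base, "lora_0" => :adapter, "shared" => :adapter}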
Inference Prepwork
The sampling method used here is non-deterministic, so the output is different every time you run this cell and the inference cell below. A top_p value of 0.6 may generate more coherent sentences but tends to repeat itself; values around 0.7 to 0.8 are a good sweet spot.
lora_model_info = %{model: lora_model, params: merged_params, spec: spec}
lora_generation_config =
  Bumblebee.configure(generation_config,
    max_new_tokens: 512,
    strategy: %{type: :multinomial_sampling, top_p: 0.7}
  )
serving =
  Bumblebee.Text.generation(lora_model_info, tokenizer, lora_generation_config,
    compile: [batch_size: 1, sequence_length: 512],
    stream: true,
    defn_options: [compiler: EXLA, lazy_transfers: :always]
  )
Kino.start_child({Nx.Serving, name: Llama, serving: serving})
{:ok, #PID<0.9144.0>}
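If you want reproducible output while debugging, you can swap multinomial sampling for greedy search, which always picks the most likely token (and therefore ignores top_p). A sketch of the alternative config, built from the same generation_config:

# Deterministic alternative: greedy decoding instead of multinomial sampling.
deterministic_config =
  Bumblebee.configure(generation_config,
    max_new_tokens: 512,
    strategy: %{type: :greedy_search}
  )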
Inference Results
The training data is formatted like this:
<title>Ideas for an Open Source Discord</title>
<author>WildYorkies</author>
I remember seeing on the Elixir subreddit [...]
<likes>3 likes</likes>
You can simulate a thread by kickstarting the text generation with a <title>text</title> string, and the LLM will predict what the general Elixir community would say.
Nx.Serving.batched_run(
  Llama,
  "<title>Elixir 2.0 is released! New features include</title>"
)
|> Enum.each(&IO.write/1)
<author>xjalasz</author>
Elixir 2.0 is released! New features include
This means that you can now deploy your Elixir 2 projects without having to use a tool like docker.
<likes>1 like</likes>
<author>jake</author>
As always, thanks for the help.
I am on a 10 day cruise to Toronto in July with the goal of finishing up Elixir 2.0 in less than one week.
<likes>1 like</likes>
<author>pianotato</author>
Thanks for all the support!
<likes>0 likes</likes>
<author>thekali</author>
Thanks for your time and effort.
<likes>0 likes</likes>
<author>mangai</author>
Thank you! I have no idea how long you have been on this blog.
<likes>1 like</likes>
<author>juicy_champion</author>
Thanks!
<likes>0 likes</likes>
<author>haha<likes>1 like</likes>
<author>thompson3</author>
Thank you for the support!
<likes>0 likes</likes>
<author>phrobotics</author>
Thanks for your time and effort.
<likes>0 likes</likes>
<author>yelp</author>
Thanks!
<likes>0 likes</likes>
<author>phrobotics</author>
Thanks!
<likes>0 likes</likes>
<author>jake</author>
Thanks for your time and effort.
<likes>0 likes</likes>
<author>jake</author>
Thanks!
<likes>0 likes</likes>
<author>juicy_champion</author>
Thanks!
<likes>0 likes</likes>
<author>juicy_champion</author>
Thanks!
<likes>0 likes</likes>
<author>gw
:ok
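The serving accepts any seed text, so you can try other thread titles; the title below is just a made-up example, and the output will vary from run to run.

# Hypothetical prompt; any <title>...</title> string kickstarts a new thread.
Nx.Serving.batched_run(Llama, "<title>Favorite Elixir libraries for data processing?</title>")
|> Enum.each(&IO.write/1)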