Finetuning LLMs with LoRA

Mix.install([
  {:bumblebee, "~> 0.5.3"},
  {:axon, "~> 0.6.1"},
  {:nx, "~> 0.7.1"},
  {:exla, "~> 0.7.1"},
  {:lorax, "~> 0.2.1"},
  {:req, "~> 0.4.0"},
  {:kino, "~> 0.12.3"}
])

Nx.default_backend(EXLA.Backend)

Introduction

This notebook shows how to train a LoRA adapter for GPT-2. Most of it is adapted from https://hexdocs.pm/bumblebee/fine_tuning.html; if you want to learn more about how this training setup works, see that livebook instead. The LoRA-specific details are further down below.

Hyperparameters

batch_size = 2
sequence_length = 256
r = 2
lora_alpha = 4
lora_dropout = 0.05

:ok
:ok
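
As a quick, self-contained sketch of what these numbers control (not part of the fine-tuning flow below, which Lorax handles for us): LoRA freezes the original weight matrices and learns a low-rank update. For each targeted projection it trains an A (d x r) and B (r x d) pair whose product, scaled by alpha / r, is added to the frozen weights. The 768 here is GPT-2's hidden size.

# Illustration only: the low-rank update LoRA adds to one frozen weight matrix.
d = 768
key = Nx.Random.key(0)

# A (d x r) and B (r x d) are the only trainable matrices per adapted projection.
{a, key} = Nx.Random.normal(key, 0.0, 1.0, shape: {d, r})
{b, _key} = Nx.Random.normal(key, 0.0, 1.0, shape: {r, d})

# At inference the frozen W is effectively replaced by W + (alpha / r) * A * B.
scaling = lora_alpha / r
delta_w = Nx.multiply(Nx.dot(a, b), scaling)

# delta_w has the full {768, 768} shape, but only d * r * 2 = 3_072 parameters
# are trained, versus 589_824 for the full matrix.
Nx.shape(delta_w)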

Load a model

{:ok, spec} = Bumblebee.load_spec({:hf, "gpt2"})
{:ok, model} = Bumblebee.load_model({:hf, "gpt2"}, spec: spec)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})

:ok

17:24:05.365 [info] TfrtCpuClient created.
:ok

Prepare a dataset

We'll be using some data scraped from the Elixir Forum. It's structured like this:

<title>Ideas for an Open Source Discord</title>

<author>WildYorkies</author>

I remember seeing on the Elixir ...
<likes>3 likes</likes>
text =
  Req.get!("https://raw.githubusercontent.com/wtedw/lorax/main/data/elixir-discussion.txt").body

:ok
:ok
tokenized_text = %{"input_ids" => input_ids} = Bumblebee.apply_tokenizer(tokenizer, text)
n_tokens = Nx.size(input_ids)
n_train = round(n_tokens * 0.9)
n_val = n_tokens - n_train

train_data =
  for {input_key, tokenized_values} <- tokenized_text, into: %{} do
    {input_key, Nx.slice_along_axis(tokenized_values, 0, n_train, axis: -1)}
  end

test_data =
  for {input_key, tokenized_values} <- tokenized_text, into: %{} do
    {input_key, Nx.slice_along_axis(tokenized_values, n_train, n_val, axis: -1)}
  end
%{
  "attention_mask" => #Nx.Tensor<
    u32[1][383105]
    EXLA.Backend<host:0, 0.2032894552.1868431376.195217>
    [
      [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]
    ]
  >,
  "input_ids" => #Nx.Tensor<
    u32[1][383105]
    EXLA.Backend<host:0, 0.2032894552.1868431376.195218>
    [
      [952, 30120, 25, 198, 40, 447, 247, 76, 2111, 284, 2050, 922, 7572, 329, 2615, 42652, 9643, 6725, 13, 198, 40, 481, 467, 4622, 503, 286, 7243, 994, 13, 1002, 345, 447, 247, 260, 655, 3599, 503, 314, 561, 1950, 326, 345, 467, 329, 262, 10314, 26, 42652, ...]
    ]
  >,
  "token_type_ids" => #Nx.Tensor<
    u32[1][383105]
    EXLA.Backend<host:0, 0.2032894552.1868431376.195219>
    [
      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]
    ]
  >
}
defmodule DataStream do
  def get_batch_stream(%{"input_ids" => input_ids} = data, batch_size, block_size, opts \\ []) do
    seed = Keyword.get(opts, :seed, 1337)

    Stream.resource(
      # initialization function
      fn ->
        Nx.Random.key(seed)
      end,
      # generation function
      fn key ->
        {_b, t} = Nx.shape(input_ids)

        data =
          for {k, v} <- data, into: %{} do
            {k, Nx.reshape(v, {t})}
          end

        # ix = list of random starting indices
        {ix, new_key} =
          Nx.Random.randint(key, 0, t - block_size, shape: {batch_size}, type: :u32)

        ix = Nx.to_list(ix)

        # x is map of sliced tensors
        x =
          for {k, tensor} <- data, into: %{} do
            batch_slice =
              ix
              |> Enum.map(fn i -> Nx.slice_along_axis(tensor, i, block_size, axis: -1) end)
              |> Nx.stack()

            {k, batch_slice}
          end

        # y represents all the predicted next tokens (input_ids shifted by 1)
        y =
          ix
          |> Enum.map(fn i ->
            data["input_ids"] |> Nx.slice_along_axis(i + 1, block_size, axis: -1)
          end)
          |> Nx.stack()
          |> Nx.flatten()

        out_data = {x, y}

        {[out_data], new_key}
      end,
      fn _ -> :ok end
    )
  end
end
{:module, DataStream, <<70, 79, 82, 49, 0, 0, 16, ...>>, {:get_batch_stream, 4}}

You can see what a single batch looks like by grabbing one from the stream:

train_batch_stream = DataStream.get_batch_stream(train_data, batch_size, sequence_length)
test_batch_stream = DataStream.get_batch_stream(test_data, batch_size, sequence_length)

[{x, y}] = train_batch_stream |> Enum.take(1)
[{x_val, y_val}] = test_batch_stream |> Enum.take(1)

Bumblebee.Tokenizer.decode(tokenizer, x["input_ids"]) |> IO.inspect()
IO.puts("=====")
Bumblebee.Tokenizer.decode(tokenizer, y) |> IO.inspect()

:ok
[" to how Greg’s Event Store works with its competing consumers model. Why would you want to do this? It allows handlers to run at different speeds, typically you have slow async handlers that can lag behind (e.g. sending emails, third party API requets). But you don’t want them to hold up read model projections to minimise query latency.\nAutonomous subscriptions allows you to add new handlers and replay all events from the beginning of time, or restart a handler to rebuild a projection. I’ve implemented a hybrid push/pull model for the Event Store subscriptions where appended events are published to subscribers, but they are buffered per subscriber and use back-pressure to ensure the subscriber isn’t overwhelmed. The subscription falls back to pulling events from the store when it gets too far behind, until caught up again.\nYou could use GenStage for this, but I would recommend using an individual flow pipeline per handler; not one flow for all handlers. Since GenStage's broadcast dispatcher can only go as fast as the slowest consumer. You also want to have any event handlers run from the event store, after the events have been atomically persisted. Appending events to the store should guarantee that a success reply is returned",
 " by a more fundamental principle - of developer happiness and a system that ‘just makes sense’. Of course the syntax is itself inspired by this, but it goes beyond syntax.\nHaving said that, I think José has also tried to stay true to Erlang and this certainly shows when using Elixir.\nWhen would I use Ruby? When I don’t need Elixir When I write a script for the server, or need to put a site up quick, or have a smaller project in mind I would use Ruby. Mainly for two reasons: I know it, and there is a huge community/set of libraries out there. Chances are if you want to do something someone already has in Ruby.\nThat may well change as I learn Elixir. I’m hoping it does actually, as my brain can’t hold too much information so sticking to one language would be preferential for me\n\n<likes>2 likes</likes>\n\n<author>gnat</author>\n\nAstonJ:\nmy brain can’t hold too much information so sticking to one language would be preferential for me\nThat was at least part of where I was coming from in raising the original question. I used to code"]
=====
" how Greg’s Event Store works with its competing consumers model. Why would you want to do this? It allows handlers to run at different speeds, typically you have slow async handlers that can lag behind (e.g. sending emails, third party API requets). But you don’t want them to hold up read model projections to minimise query latency.\nAutonomous subscriptions allows you to add new handlers and replay all events from the beginning of time, or restart a handler to rebuild a projection. I’ve implemented a hybrid push/pull model for the Event Store subscriptions where appended events are published to subscribers, but they are buffered per subscriber and use back-pressure to ensure the subscriber isn’t overwhelmed. The subscription falls back to pulling events from the store when it gets too far behind, until caught up again.\nYou could use GenStage for this, but I would recommend using an individual flow pipeline per handler; not one flow for all handlers. Since GenStage's broadcast dispatcher can only go as fast as the slowest consumer. You also want to have any event handlers run from the event store, after the events have been atomically persisted. Appending events to the store should guarantee that a success reply is returned only a more fundamental principle - of developer happiness and a system that ‘just makes sense’. Of course the syntax is itself inspired by this, but it goes beyond syntax.\nHaving said that, I think José has also tried to stay true to Erlang and this certainly shows when using Elixir.\nWhen would I use Ruby? When I don’t need Elixir When I write a script for the server, or need to put a site up quick, or have a smaller project in mind I would use Ruby. Mainly for two reasons: I know it, and there is a huge community/set of libraries out there. Chances are if you want to do something someone already has in Ruby.\nThat may well change as I learn Elixir. I’m hoping it does actually, as my brain can’t hold too much information so sticking to one language would be preferential for me\n\n<likes>2 likes</likes>\n\n<author>gnat</author>\n\nAstonJ:\nmy brain can’t hold too much information so sticking to one language would be preferential for me\nThat was at least part of where I was coming from in raising the original question. I used to code pretty"
:ok

Train the model

%{model: model, params: params} = model

model
#Axon<
  inputs: %{"attention_head_mask" => {12, 12}, "attention_mask" => {nil, nil}, "cache" => nil, "input_embeddings" => {nil, nil, 768}, "input_ids" => {nil, nil}, "position_ids" => {nil, nil}}
  outputs: "container_37"
  nodes: 859
>
[{input, _}] = Enum.take(train_batch_stream, 1)
Axon.get_output_shape(model, input)
%{
  cache: #Axon.None<...>,
  hidden_states: #Axon.None<...>,
  attentions: #Axon.None<...>,
  cross_attentions: #Axon.None<...>,
  logits: {2, 256, 50257}
}

We'll need to freeze the original parameters in our model before injecting the LoRA layers. That way only the LoRA nodes are trained.

lora_model =
  model
  |> Axon.freeze()
  |> Lorax.inject(%Lorax.Config{
    r: r,
    alpha: lora_alpha,
    dropout: lora_dropout,
    target_key: true,
    target_query: true,
    target_value: true
  })
#Axon<
  inputs: %{"attention_head_mask" => {12, 12}, "attention_mask" => {nil, nil}, "cache" => nil, "input_embeddings" => {nil, nil, 768}, "input_ids" => {nil, nil}, "position_ids" => {nil, nil}}
  outputs: "container_37"
  nodes: 895
>

We'll reshape the usual GPT-2 output so that we can use it with categorical_cross_entropy.

defmodule CommonTrain do
  import Nx.Defn

  defn custom_predict_fn(model_predict_fn, params, input) do
    %{prediction: preds} = out = model_predict_fn.(params, input)

    # Output of GPT2 model is a map containing logits and other tensors
    logits = preds.logits

    {b, t, c} = Nx.shape(logits)
    reshaped = Nx.reshape(logits, {b * t, c})
    %{out | prediction: reshaped}
  end

  def custom_loss_fn(y_true, y_pred) do
    Axon.Losses.categorical_cross_entropy(y_true, y_pred,
      from_logits: true,
      sparse: true,
      reduction: :mean
    )
  end
end

{init_fn, predict_fn} = Axon.build(lora_model, mode: :train)
custom_predict_fn = &CommonTrain.custom_predict_fn(predict_fn, &1, &2)
custom_loss_fn = &CommonTrain.custom_loss_fn(&1, &2)

lora_merged_params =
  {init_fn, custom_predict_fn}
  |> Axon.Loop.trainer(custom_loss_fn, Polaris.Optimizers.adam(learning_rate: 3.0e-4))
  |> Axon.Loop.run(train_batch_stream, params, epochs: 1, iterations: 300, compiler: EXLA)

:ok

17:42:55.140 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 250, loss: 3.6697373
:ok

Download LoRA params

When training, Axon returns all the parameters needed to run the model. If you want to download only the LoRA parameters, you'll need to filter for them before downloading.

# Method #1
lora_only =
  lora_merged_params
  |> Lorax.Params.filter(params)
  |> Lorax.Params.kino_download()

# Method #2
# serialized =
#   Lorax.Params.filter(lora_merged_params, params)
#   |> Nx.serialize()

# File.write!("/<insert path>/test.lorax", serialized)
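
To reuse the adapter later (e.g. in a fresh session), the file serialized in Method #2 can be read back with Nx.deserialize/1 and merged on top of the base GPT-2 params. This is only a sketch: it assumes the filtered params form a plain map whose keys don't collide with the base params, and the path below is still a placeholder.

# Sketch for restoring a saved adapter later
# lora_only_params =
#   File.read!("/<insert path>/test.lorax")
#   |> Nx.deserialize()

# restored_params = Map.merge(params, lora_only_params)

The restored map could then stand in for lora_merged_params in the generation section below.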

Testing out text generation

lora_model_info = %{model: lora_model, params: lora_merged_params, spec: spec}

lora_generation_config =
  Bumblebee.configure(generation_config,
    max_new_tokens: 256,
    strategy: %{type: :multinomial_sampling, top_p: 0.8}
  )

serving =
  Bumblebee.Text.generation(lora_model_info, tokenizer, lora_generation_config,
    compile: [batch_size: 1, sequence_length: 256],
    stream: true,
    defn_options: [compiler: EXLA, lazy_transfers: :always]
  )

Kino.start_child({Nx.Serving, name: Llama, serving: serving})
{:ok, #PID<0.1392.0>}

We'll kick-start the text generation with a <title>text</title> string. This is how the training data is formatted, so it prompts the model to continue in the style the LoRA layers learned. With only 300 iterations, the model has already learned to output text similar to our training data.

Nx.Serving.batched_run(Llama, "<title>Elixir 2.0 released") |> Enum.each(&IO.write/1)
</title>

</likes>0 likes</likes>

<author>jlarperendan</author>

The elixir.clipper gem is really very elegant, you can even use it with other compilers if you don't need it in your code base. It really helps here!
<author>ilax13</author>

The name of the library is good! It already provides a lot of benefit from Elixir, so I guess it's useful for intermediate project developers to have one with both Elixir and Elixir 2.0.1.

<likes>0 likes</likes>

<author>benaeutz</author>

Elixir 2.0 is really powerful. Even though there's some backend changes, Elixir 2.0 is one of the few that changes the behavior of main mode and does not change runtime dependencies. So, if you have multiple compilers that can add the language layer or break the runtime dependencies, then the big advantage of Elixir is the compiler ecosystem. In many cases, the exception to this is compiler breaking.
<likes>1 likes</likes>

<author>avnicom<likes>1 likes
:ok