Fine-tuning

Mix.install([
  {:bumblebee, "~> 0.5.0"},
  {:nx, "~> 0.7.0"},
  {:exla, "~> 0.7.0"},
  {:axon, "~> 0.6.1"},
  {:explorer, "~> 0.7.0"}
])

Nx.default_backend(EXLA.Backend)

Introduction

Fine-tuning is the process of specializing the parameters of a pre-trained model for a specific task. Large language models such as BERT are trained on a generic language-modeling task, which makes them powerful at extracting features from text. Despite their power, you often still need to train them on a downstream task.

This example demonstrates how to use Bumblebee and Axon to fine-tune a pre-trained BERT model to classify Yelp reviews into classes of 1-5 stars. This example is based on Fine-tune a pretrained model from Hugging Face.

You'll need to first download the Yelp Reviews dataset (download).

Once downloaded, extract it to a directory of your choosing and you're ready to go!

Load a model

We'll start by loading a pre-trained model and tokenizer; however, we'll initialize the model to have an untrained sequence classification head.

Reviews in the dataset can have anywhere from 1 to 5 stars, which means we need 5 labels in our sequence classification head. We can change the default configuration by loading the model spec with Bumblebee.load_spec/2 and making changes to spec properties with Bumblebee.configure/2.

The pre-trained model we'll be using is bert-base-cased; however, you can use any of the supported models from the Hugging Face Hub.

{:ok, spec} =
  Bumblebee.load_spec({:hf, "google-bert/bert-base-cased"},
    architecture: :for_sequence_classification
  )

spec = Bumblebee.configure(spec, num_labels: 5)

{:ok, model} = Bumblebee.load_model({:hf, "google-bert/bert-base-cased"}, spec: spec)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google-bert/bert-base-cased"})

14:41:33.314 [info] TfrtCpuClient created.

14:41:33.820 [debug] the following parameters were missing:

  * sequence_classification_head.output.kernel
  * sequence_classification_head.output.bias


14:41:33.820 [debug] the following PyTorch parameters were unused:

  * cls.predictions.bias
  * cls.predictions.decoder.weight
  * cls.predictions.transform.LayerNorm.beta
  * cls.predictions.transform.LayerNorm.gamma
  * cls.predictions.transform.dense.bias
  * cls.predictions.transform.dense.weight
  * cls.seq_relationship.bias
  * cls.seq_relationship.weight
{:ok,
 %Bumblebee.Text.BertTokenizer{
   tokenizer: #Tokenizers.Tokenizer<[
     vocab_size: 28996,
     continuing_subword_prefix: "##",
     max_input_chars_per_word: 100,
     model_type: "bpe",
     unk_token: "[UNK]"
   ]>,
   special_tokens: %{cls: "[CLS]", mask: "[MASK]", pad: "[PAD]", sep: "[SEP]", unk: "[UNK]"}
 }}

Prepare a dataset

With the model and tokenizer downloaded and ready to go, you need to prepare the dataset. The downloaded dataset is a CSV file. You can use the Explorer library to quickly load the CSV into a DataFrame.

Once the data is loaded, you need to convert the raw text to tokens and the raw labels to tensors. Additionally, you need to convert the DataFrame to a Stream of tuples of the form {tokenized, labels}, which is what Axon's training API expects.

defmodule Yelp do
  def load(path, tokenizer, opts \\ []) do
    path
    |> Explorer.DataFrame.from_csv!(header: false)
    |> Explorer.DataFrame.rename(["label", "text"])
    |> stream()
    |> tokenize_and_batch(tokenizer, opts[:batch_size], opts[:sequence_length])
  end

  def stream(df) do
    xs = df["text"]
    ys = df["label"]

    xs
    |> Explorer.Series.to_enum()
    |> Stream.zip(Explorer.Series.to_enum(ys))
  end

  def tokenize_and_batch(stream, tokenizer, batch_size, sequence_length) do
    tokenizer = Bumblebee.configure(tokenizer, length: sequence_length)

    stream
    |> Stream.chunk_every(batch_size)
    |> Stream.map(fn batch ->
      {text, labels} = Enum.unzip(batch)
      tokenized = Bumblebee.apply_tokenizer(tokenizer, text)
      {tokenized, Nx.stack(labels)}
    end)
  end
end
{:module, Yelp, <<70, 79, 82, 49, 0, 0, 13, ...>>, {:tokenize_and_batch, 4}}
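
To see what the zip, chunk, and unzip steps in Yelp.stream/1 and Yelp.tokenize_and_batch/4 do without loading the CSV, here is a minimal sketch on made-up data. The texts and labels lists are hypothetical stand-ins for the DataFrame columns, and tokenization is left out entirely:

```elixir
# Hypothetical stand-ins for the "text" and "label" columns.
texts = ["great", "awful", "fine", "okay", "superb"]
labels = [5, 1, 3, 3, 5]

batches =
  texts
  # Pair each text with its label, like Yelp.stream/1 does.
  |> Stream.zip(labels)
  # Group the pairs into batches of 2.
  |> Stream.chunk_every(2)
  # Split each batch back into {texts, labels}.
  |> Enum.map(fn batch -> Enum.unzip(batch) end)

IO.inspect(batches)
```

Note that Stream.chunk_every/2 keeps the final, smaller chunk, so the last batch here holds a single example. If you want every batch to have the same shape, Stream.chunk_every(stream, batch_size, batch_size, :discard) drops the leftover instead.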

Now you can use the Yelp.load/3 function to load a training set and a testing set:

batch_size = 32
sequence_length = 64

train_data =
  Yelp.load("~/yelp/yelp_review_full_csv/train.csv", tokenizer,
    batch_size: batch_size,
    sequence_length: sequence_length
  )

test_data =
  Yelp.load("~/yelp/yelp_review_full_csv/test.csv", tokenizer,
    batch_size: batch_size,
    sequence_length: sequence_length
  )
#Stream<[
  enum: #Stream<[
    enum: #Function<73.124013645/2 in Stream.zip_with/2>,
    funs: [#Function<3.124013645/1 in Stream.chunk_while/4>]
  ]>,
  funs: [#Function<48.124013645/1 in Stream.map/2>]
]>

You can see what a single batch looks like by taking one from the stream:

Enum.take(train_data, 1)
[
  {%{
     "attention_mask" => #Nx.Tensor<
       s64[32][64]
       EXLA.Backend<host:0, 0.801663575.1558315030.66545>
       [
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...],
         ...
       ]
     >,
     "input_ids" => #Nx.Tensor<
       s64[32][64]
       EXLA.Backend<host:0, 0.801663575.1558315030.66544>
       [
         [101, 173, 1197, 119, 2284, 2953, 3272, 1917, 178, 1440, 1111, 1107, 170, 1704, 22351, 119, 1119, 112, 188, 3505, 1105, 3123, 1106, 2037, 1106, 1443, 1217, 10063, 4404, 132, 1119, 112, 188, 1579, 1113, 1159, 1107, 3195, 1117, 4420, 132, 1119, 112, 188, 6559, 1114, ...],
         ...
       ]
     >,
     "token_type_ids" => #Nx.Tensor<
       s64[32][64]
       EXLA.Backend<host:0, 0.801663575.1558315030.66546>
       [
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...],
         ...
       ]
     >
   },
   #Nx.Tensor<
     s64[32]
     EXLA.Backend<host:0, 0.801663575.1558315028.66360>
     [5, 2, 4, 4, 1, 5, 5, 1, 2, 3, 1, 1, 4, 2, 5, 5, 5, 5, 5, 5, 4, 3, 2, 5, 1, 1, 1, 2, 2, 4, 2, 2]
   >}
]
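
Each row has exactly 64 entries because the tokenizer was configured with length: sequence_length, so shorter reviews are padded and longer ones are truncated. The following plain-Elixir sketch illustrates the idea on hand-written token ids. It is not how Bumblebee implements it: pad_or_truncate is a hypothetical helper, the ids are made up, the pad id of 0 is an assumption, and real tokenizers also take care of special tokens such as [SEP] when truncating.

```elixir
# Assumed id of the padding token (illustration only).
pad_id = 0

# Hypothetical helper: force a list of token ids to exactly `target` entries
# and build the matching attention mask (1 = real token, 0 = padding).
pad_or_truncate = fn ids, target ->
  kept = Enum.take(ids, target)
  pad_count = target - length(kept)

  %{
    "input_ids" => kept ++ List.duplicate(pad_id, pad_count),
    "attention_mask" => List.duplicate(1, length(kept)) ++ List.duplicate(0, pad_count)
  }
end

short = pad_or_truncate.([101, 7592, 102], 5)
long = pad_or_truncate.([101, 1, 2, 3, 4, 5, 102], 5)

IO.inspect(short)
IO.inspect(long)
```

The short sequence gains two padding entries that the attention mask marks with 0, while the long one is cut to five ids and its mask is all 1s.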

The dataset is rather large for CPU training, so we'll just train on a small subset (250 training batches and 50 testing batches):

train_data = Enum.take(train_data, 250)
test_data = Enum.take(test_data, 50)
:ok
:ok

Train the model

Now we can go about training the model! First, we need to extract the Axon model and parameters from the Bumblebee model map:

%{model: model, params: params} = model

model
#Axon<
  inputs: %{"attention_head_mask" => {12, 12}, "attention_mask" => {nil, nil}, "input_ids" => {nil, nil}, "position_ids" => {nil, nil}, "token_type_ids" => {nil, nil}}
  outputs: "container_37"
  nodes: 790
>

The Axon model actually outputs a map with :logits, :hidden_states, and :attentions. You can see this by using Axon.get_output_shape/2 with an input. This function symbolically executes the graph and returns the resulting shapes:

[{input, _}] = Enum.take(train_data, 1)
Axon.get_output_shape(model, input)
%{attentions: #Axon.None<...>, hidden_states: #Axon.None<...>, logits: {32, 5}}

For training, we only care about the :logits key, so we'll extract that by attaching an Axon.nx/2 layer to the model:

logits_model = Axon.nx(model, & &1.logits)
#Axon<
  inputs: %{"attention_head_mask" => {12, 12}, "attention_mask" => {nil, nil}, "input_ids" => {nil, nil}, "position_ids" => {nil, nil}, "token_type_ids" => {nil, nil}}
  outputs: "nx_0"
  nodes: 791
>

Now we can declare our training loop. You can construct Axon training loops using the Axon.Loop.trainer/4 factory function with a model, loss function, optimizer, and options. We'll also pass log: 1 so that metrics are logged to standard output more frequently:

loss =
  &Axon.Losses.categorical_cross_entropy(&1, &2,
    reduction: :mean,
    from_logits: true,
    sparse: true
  )

optimizer = Polaris.Optimizers.adam(learning_rate: 5.0e-5)

loop = Axon.Loop.trainer(logits_model, loss, optimizer, log: 1)
#Axon.Loop<
  metrics: %{
    "loss" => {#Function<11.3813108/3 in Axon.Metrics.running_average/1>,
     #Function<41.3316493/2 in :erl_eval.expr/6>}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<27.14409478/1 in Axon.Loop.log/3>,
       #Function<6.14409478/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.14409478/1 in Axon.Loop.log/3>,
       #Function<64.14409478/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>
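
To make the sparse: true and from_logits: true options more concrete, here is a plain-Elixir sketch of the quantity computed for a single example: the negative log-softmax of the logit at the label's index. This is an illustration with made-up numbers, not Axon's implementation, and sparse_cross_entropy is a hypothetical helper name.

```elixir
# Cross-entropy for one example, given raw logits and an integer class index.
sparse_cross_entropy = fn logits, label ->
  # Log of the softmax normalizer: log(sum(exp(logits))).
  log_sum_exp = logits |> Enum.map(&:math.exp/1) |> Enum.sum() |> :math.log()

  # -log(softmax(logits)[label]) simplifies to log_sum_exp - logits[label].
  log_sum_exp - Enum.at(logits, label)
end

# Made-up logits for a 3-class problem.
logits = [2.0, 1.0, 0.1]

# The label points at the largest logit, so the loss is small.
loss_when_right = sparse_cross_entropy.(logits, 0)

# The label points at the smallest logit, so the loss is large.
loss_when_wrong = sparse_cross_entropy.(logits, 2)

# reduction: :mean corresponds to averaging the per-example losses.
mean_loss = (loss_when_right + loss_when_wrong) / 2

IO.inspect({loss_when_right, loss_when_wrong, mean_loss})
```

Because a sparse label is used as an index into the logits, a head with 5 outputs has valid indices 0 to 4. The labels printed in the batch above run from 1 to 5, so you may want to check whether they should be shifted down by one before training.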

The call to trainer just returns a data structure. In Axon, we manipulate this data structure to control different parts of the loop. For example, you can attach metrics:

accuracy = &Axon.Metrics.accuracy(&1, &2, from_logits: true, sparse: true)

loop = Axon.Loop.metric(loop, accuracy, "accuracy")
#Axon.Loop<
  metrics: %{
    "accuracy" => {#Function<11.3813108/3 in Axon.Metrics.running_average/1>,
     #Function<41.3316493/2 in :erl_eval.expr/6>},
    "loss" => {#Function<11.3813108/3 in Axon.Metrics.running_average/1>,
     #Function<41.3316493/2 in :erl_eval.expr/6>}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<27.14409478/1 in Axon.Loop.log/3>,
       #Function<6.14409478/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.14409478/1 in Axon.Loop.log/3>,
       #Function<64.14409478/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>
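
As a rough picture of what an accuracy metric over logits and integer labels measures, the sketch below picks the index of the largest logit in each row and compares it with the label. It uses plain lists and made-up numbers rather than tensors; argmax is a hypothetical helper, and this is not how Axon.Metrics.accuracy is implemented.

```elixir
# Hypothetical helper: index of the largest value in a list.
argmax = fn row ->
  row
  |> Enum.with_index()
  |> Enum.max_by(fn {value, _index} -> value end)
  |> elem(1)
end

# Made-up logits for three examples and three classes.
logits = [
  [0.1, 2.0, 0.3],
  [1.5, 0.2, 0.1],
  [0.0, 0.1, 3.0]
]

# Integer class labels, one per example.
labels = [1, 0, 1]

predictions = Enum.map(logits, argmax)

correct =
  predictions
  |> Enum.zip(labels)
  |> Enum.count(fn {predicted, label} -> predicted == label end)

accuracy = correct / length(labels)

IO.inspect({predictions, accuracy})
```

Two of the three predictions match their labels here, so the accuracy is two thirds.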

You can also attach event handlers to do certain things, such as serializing the loop state at regular intervals so you don't lose your progress:

loop = Axon.Loop.checkpoint(loop, event: :epoch_completed)
#Axon.Loop<
  metrics: %{
    "accuracy" => {#Function<11.3813108/3 in Axon.Metrics.running_average/1>,
     #Function<41.3316493/2 in :erl_eval.expr/6>},
    "loss" => {#Function<11.3813108/3 in Axon.Metrics.running_average/1>,
     #Function<41.3316493/2 in :erl_eval.expr/6>}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<17.14409478/1 in Axon.Loop.checkpoint/2>,
       #Function<6.14409478/2 in Axon.Loop.build_filter_fn/1>},
      {#Function<27.14409478/1 in Axon.Loop.log/3>,
       #Function<6.14409478/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.14409478/1 in Axon.Loop.log/3>,
       #Function<64.14409478/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>

To run the loop, you just need to call Axon.Loop.run/4. Axon.Loop.run/4 takes a loop, input data, and any initial state (in this case the initial parameters). You can think of Axon.Loop.run/4 as an Enum.reduce/3. Enum.reduce/3 takes data, an accumulator, and a function, which correspond to the input data, the initial state, and the loop data structure passed to Axon.Loop.run/4.
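
The Enum.reduce/3 analogy can be sketched with a toy example in plain Elixir. Nothing here touches Axon: data, initial_state, and step_fn are hypothetical stand-ins for the input data, the initial state, and the loop, and the "training step" just counts batches and sums their values.

```elixir
# Plays the role of the input data (one number per "batch").
data = [1.0, 2.0, 3.0, 4.0]

# Plays the role of the initial state.
initial_state = %{step: 0, total: 0.0}

# Plays the role of the loop: how one batch updates the state.
step_fn = fn batch, state ->
  %{state | step: state.step + 1, total: state.total + batch}
end

# The state is threaded through every batch and returned at the end.
final_state = Enum.reduce(data, initial_state, step_fn)

IO.inspect(final_state)
```

In the same way, the value returned by the training loop is the state left over after the last batch has been processed.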

You'll commonly see loops written out in long chains using Elixir's |> operator, like this:

trained_model_state =
  logits_model
  |> Axon.Loop.trainer(loss, optimizer, log: 1)
  |> Axon.Loop.metric(accuracy, "accuracy")
  |> Axon.Loop.checkpoint(event: :epoch_completed)
  |> Axon.Loop.run(train_data, params, epochs: 3, compiler: EXLA, strict?: false)

:ok

02:46:02.170 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 249, accuracy: 0.3462500 loss: 1.2216607
Epoch: 1, Batch: 249, accuracy: 0.5186251 loss: 1.0558304
Epoch: 2, Batch: 249, accuracy: 0.6236249 loss: 0.9317472
:ok

Evaluating the model

The training loop returns the final model state after training over your dataset for the given number of epochs. Axon uses the same Axon.Loop API to create evaluation loops as well. You can create one with the Axon.Loop.evaluator/1 factory, instrument it with metrics, and run it on your data with your trained model state:

logits_model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(accuracy, "accuracy")
|> Axon.Loop.run(test_data, trained_model_state, compiler: EXLA)
Batch: 49, accuracy: 0.3675000
%{
  0 => %{
    "accuracy" => #Nx.Tensor<
      f32
      EXLA.Backend<host:0, 0.446408219.3911319572.169449>
      0.36750003695487976
    >
  }
}