Lorax (lorax v0.2.1)

Simple Low-Rank Adaptation (LoRA) implementation

LoRA model creation

To create a LoRA model, freeze an existing model and inject LoRA layers using Lorax.inject/2.

lora_model =
  model  # the existing Axon model to adapt
  |> Axon.freeze()
  |> Lorax.inject(%Lorax.Config{
    r: 2,
    alpha: 4,
    dropout: 0.05,
    target_key: false,
    target_query: false,
    target_value: true
  })

For more detailed guides, see:

  1. Finetuning LLMs with LoRA
  2. Running LLMs with LoRA
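For a sense of scale, the sketch below counts the trainable parameters that one adapter adds with the r: 2 and alpha: 4 settings from the example configuration above. B has shape d_out x r and A has shape r x d_in, so an adapter adds r * (d_in + d_out) parameters. The 768 x 768 weight is a hypothetical size chosen for illustration, LoraParams is a hypothetical module, and the alpha / r scaling follows the original LoRA paper's convention; whether Lorax applies exactly this factor is an assumption.

```elixir
defmodule LoraParams do
  # B is d_out x r and A is r x d_in, so one adapter adds r * (d_in + d_out).
  def adapter_params(d_in, d_out, r), do: r * (d_in + d_out)

  # Parameters in the frozen weight W, which are not updated during training.
  def frozen_params(d_in, d_out), do: d_in * d_out

  # Standard LoRA scaling factor applied to BAx (assumed, not confirmed for Lorax).
  def scaling(alpha, r), do: alpha / r
end

LoraParams.adapter_params(768, 768, 2)  # 3072
LoraParams.frozen_params(768, 768)      # 589824
LoraParams.scaling(4, 2)                # 2.0
```

In this example the adapter trains roughly 0.5% as many parameters as the weight it adapts.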

LoRA layers are implemented by injecting new nodes into the Axon struct. Each LoRA node holds the B and A matrices, takes the target layer's input x, and computes BAx. In addition, the LoRA node receives the target layer's output Wx as a second input and returns Wx + BAx. This isn't the standard way to implement LoRA, but it simplifies the injection process.
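To make the data flow concrete, here is a minimal plain-Elixir sketch of what a LoRA node computes, using nested lists instead of Nx tensors so that it runs without dependencies. LoraSketch, matvec and lora_node are hypothetical names, not part of Lorax, and the optional scale argument is an assumption borrowed from the usual alpha / r convention rather than something this page specifies.

```elixir
defmodule LoraSketch do
  # Multiply a matrix (a list of rows) by a vector (a list of numbers).
  def matvec(m, v) do
    Enum.map(m, fn row ->
      row |> Enum.zip(v) |> Enum.map(fn {a, b} -> a * b end) |> Enum.sum()
    end)
  end

  # Hypothetical LoRA node: receives x and the frozen output wx, returns
  # wx + scale * B(Ax). A is r x d_in, B is d_out x r. The optional
  # `scale` is an assumption (standard LoRA uses alpha / r).
  def lora_node(x, wx, a, b, scale \\ 1.0) do
    bax = matvec(b, matvec(a, x))
    Enum.zip_with(wx, bax, fn w, d -> w + scale * d end)
  end
end

w = [[1, 0, 0], [0, 1, 0]]   # frozen 2 x 3 weight W
x = [1.0, 2.0, 3.0]
wx = LoraSketch.matvec(w, x)  # [1.0, 2.0], computed by the target layer
a = [[1, 0, 0]]               # A: r x d_in with r = 1
b = [[0.5], [0.0]]            # B: d_out x r
LoraSketch.lora_node(x, wx, a, b)  # [1.5, 2.0]
```

With B set to all zeros (the usual LoRA initialization), the node returns Wx unchanged, so the injected model starts out behaving exactly like the frozen one.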

Injection Process

Beginning state

flowchart LR
  A[input id:0] --> B[target id:1]

Create an empty dummy node

flowchart LR
  A[input id:0] --> B[target id:1] --> C[dummy id:2]

Create the LoRA node with input ids = [0, 2] (the target's input and the dummy)

flowchart LR
  A[input id:0] --> B[target id:1] --> C[dummy id:2] --> E[lora id:3]
  A[input id:0] --> E[lora id:3]

The target takes the dummy's id, and the dummy node is thrown away

flowchart LR
  A[input id:0] --> C[target id:2]
  C[target id:2] --> E[lora id:3]
  A[input id:0] --> E[lora id:3]

The LoRA node takes the target's original id

flowchart LR
  A[input id:0] --> C[target id:2] --> E[lora id:1]
  A[input id:0] --> E[lora id:1]

The LoRA node and the target have now swapped places. Any downstream node that relied on node id:1 now receives Wx + BAx.
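The four steps above can be sketched in plain Elixir on a toy graph. The map-of-maps representation, the InjectSketch module and swap_in_lora/2 are hypothetical stand-ins for Axon's real node struct, which carries many more fields; the sketch only reproduces the id bookkeeping.

```elixir
defmodule InjectSketch do
  # Hypothetical graph representation: %{id => %{op: atom, parents: [id]}}.
  # It mirrors the injection steps only; it is not Axon's node struct.
  def swap_in_lora(nodes, target_id) do
    # The LoRA node must see the same input(s) as the target.
    target_inputs = Map.fetch!(nodes, target_id).parents

    # Create an empty dummy node that reads from the target.
    dummy_id = next_id(nodes)
    nodes = Map.put(nodes, dummy_id, %{op: :dummy, parents: [target_id]})

    # Create the LoRA node with inputs = target inputs ++ [dummy id].
    lora_id = next_id(nodes)
    nodes = Map.put(nodes, lora_id, %{op: :lora, parents: target_inputs ++ [dummy_id]})

    # The target takes the dummy's id; the dummy entry is overwritten.
    {target, nodes} = Map.pop(nodes, target_id)
    nodes = Map.put(nodes, dummy_id, target)

    # The LoRA node takes the target's original id, so every downstream
    # node that referenced target_id now reads the LoRA output.
    {lora, nodes} = Map.pop(nodes, lora_id)
    Map.put(nodes, target_id, lora)
  end

  defp next_id(nodes), do: Enum.max(Map.keys(nodes)) + 1
end

graph = %{
  0 => %{op: :input, parents: []},
  1 => %{op: :target, parents: [0]},
  2 => %{op: :downstream, parents: [1]}
}

InjectSketch.swap_in_lora(graph, 1)
# id 1 is now the LoRA node with parents [0, 3], id 3 is the target,
# and the downstream node (id 2) still points at id 1.
```

Because the LoRA node ends up under the target's original id, downstream nodes need no rewiring.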




inject(axon, config)

Returns a modified Axon model with LoRA nodes inserted according to the provided configuration.

The target_key, target_query, and target_value options are required if target_node_fn isn't specified.
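One way to picture what the three flags do is as a predicate on layer names. This sketch is hypothetical: TargetSketch and predicate/1 are not Lorax functions, a plain map stands in for %Lorax.Config{}, and matching on the last name segment (query, key, value) is an assumption based on layer names such as "decoder.blocks.11.self_attention.value" from the manual-targeting example.

```elixir
defmodule TargetSketch do
  # Hypothetical mapping from config flags to the last segment of a layer name.
  @suffixes [target_query: "query", target_key: "key", target_value: "value"]

  # Returns a predicate on layer names. Map.fetch!/2 raises KeyError when a
  # flag is missing, echoing that all three flags are required.
  def predicate(config) do
    wanted = for {flag, suffix} <- @suffixes, Map.fetch!(config, flag), do: suffix

    fn name ->
      shortname = name |> String.split(".") |> List.last()
      shortname in wanted
    end
  end
end

pred = TargetSketch.predicate(%{target_query: false, target_key: false, target_value: true})
pred.("decoder.blocks.11.self_attention.value")  # true
pred.("decoder.blocks.11.self_attention.query")  # false
```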


lora_model =
  model  # the existing Axon model to adapt
  |> Axon.freeze()
  |> Lorax.inject(%Lorax.Config{
    r: 2,
    alpha: 4,
    dropout: 0.05,
    target_key: true,
    target_query: true,
    target_value: true
  })

Targeting nodes manually

lora_model =
  model  # the existing Axon model to adapt
  |> Axon.freeze()
  |> Lorax.inject(%Lorax.Config{
    r: 2,
    alpha: 4,
    dropout: 0.05,
    target_node_fn: fn %Axon.Node{name: name_fn} ->
      # names are generated lazily, and look like "decoder.blocks.11.self_attention.value"
      # have to invoke the function to see what layer the node represents
      # https://github.com/elixir-nx/axon/blob/v0.6.0/lib/axon.ex#L3923
      name = name_fn.(nil, nil)
      shortname = String.split(name, ".") |> List.last()

      if shortname == "output" do