Lorax (lorax v0.1.0)
Simple Low-Rank Adaptation (LoRA) implementation
LoRA model creation
To create a LoRA model, freeze an existing model and inject LoRA layers using Lorax.inject/2.
lora_model =
  model
  |> Axon.freeze()
  |> Lorax.inject(%Lorax.Config{
    r: 2,
    alpha: 4,
    dropout: 0.05,
    target_key: true,
    target_query: true,
    target_value: true
  })
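In standard LoRA terms, r is the rank of the low-rank update, alpha scales it (the update is conventionally multiplied by alpha / r), and dropout is the dropout probability applied to the LoRA path. The target_key, target_query, and target_value flags mark the attention key, query, and value projections for injection.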
For more detailed guides, see the guides bundled with this package.
LoRA layers are implemented by injecting new nodes into the Axon struct. These LoRA nodes represent the B and A matrices. Each node takes an input x and computes BAx. Furthermore, the LoRA node receives Wx as an input and computes Wx + BAx. This isn't the standard implementation, but it simplifies the injection process.
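Concretely, the value a LoRA node yields can be sketched in plain Nx. This is a rough illustration rather than Lorax's actual internals; the module name, shape conventions, and the alpha / r scaling are assumptions taken from the standard LoRA formulation.

# x:  {batch, k}  original layer input
# wx: {batch, d}  output of the frozen target layer, i.e. Wx
# a:  {r, k} and b: {d, r}  trainable low-rank matrices
defmodule LoraSketch do
  def forward(wx, x, a, b, scaling \\ 1.0) do
    bax =
      x
      |> Nx.dot(Nx.transpose(a))   # Ax  -> {batch, r}
      |> Nx.dot(Nx.transpose(b))   # BAx -> {batch, d}
      |> Nx.multiply(scaling)      # conventionally alpha / r

    Nx.add(wx, bax)
  end
end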
Injection Process
1. Beginning state: the target node (id: 1) computes Wx for its downstream consumers.
2. Create an empty dummy node.
3. Create a LoRA node with input ids = [0, 2], i.e. the original input and the dummy.
4. The target takes the dummy's id; the dummy node is thrown away.
5. The LoRA node takes the target's original id.
6. The LoRA node and the target are now swapped: any downstream node that relied on node id: 1 will now receive Wx + BAx.
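The swap can be pictured with plain maps (a toy sketch, not Lorax's real node structs; the ids match the steps above):

# Downstream nodes reference their parents by id, so exchanging ids
# rewires the graph without touching any downstream node.
input  = %{id: 0, parents: []}      # produces x
target = %{id: 1, parents: [0]}     # produces Wx
dummy  = %{id: 2, parents: []}      # empty placeholder
lora   = %{id: 3, parents: [0, 2]}  # reads x and, via slot 2, Wx

# target takes dummy's id; the dummy node is thrown away
target = %{target | id: 2}

# lora takes target's original id; consumers of id 1 now get Wx + BAx
lora = %{lora | id: 1}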
Summary
Functions
Returns a modified Axon model with LoRA nodes inserted according to the provided configuration.
Functions
inject(axon, config)
Returns a modified Axon model with LoRA nodes inserted according to the provided configuration.
target_key, target_query, and target_value are required if target_node_fn isn't specified.
Examples
lora_model =
  model
  |> Axon.freeze()
  |> Lorax.inject(%Lorax.Config{
    r: 2,
    alpha: 4,
    dropout: 0.05,
    target_key: true,
    target_query: true,
    target_value: true
  })
Targeting nodes manually
lora_model =
  model
  |> Axon.freeze()
  |> Lorax.inject(%Lorax.Config{
    r: 2,
    alpha: 4,
    dropout: 0.05,
    target_node_fn: fn %Axon.Node{name: name_fn} ->
      # names are generated lazily and look like "decoder.blocks.11.self_attention.value",
      # so the function must be invoked to see which layer the node represents
      # https://github.com/elixir-nx/axon/blob/v0.6.0/lib/axon.ex#L3923
      name = name_fn.(nil, nil)
      shortname = String.split(name, ".") |> List.last()

      shortname == "output"
    end
  })