Custom Loss Training
Tinkex supports custom loss training that mirrors the Python SDK’s forward_backward_custom_async. It performs a forward pass to obtain per-datum logprobs, runs your Nx loss function, computes gradients, sends them back as synthetic weights, and returns a ForwardBackwardOutput that can be passed directly into optim_step/2.
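To make the "runs your Nx loss function, computes gradients" step concrete, here is a minimal local sketch that differentiates a loss with respect to logprobs using `Nx.Defn.value_and_grad/1`. The logprob values are made up, and this only illustrates the idea with plain Nx; it is not a claim about how Tinkex computes or encodes the gradients internally.

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Stand-in for the logprobs a forward pass would return for one datum
# (made-up values, purely for illustration).
logprobs = Nx.tensor([-0.5, -1.0, -1.5, -2.0])

# Mean negative log-likelihood, the same loss used in the workflow on this page.
loss_fun = fn lp -> Nx.negate(Nx.mean(lp)) end

# value_and_grad/1 returns a function that yields {loss, d(loss)/d(logprobs)}.
{loss, grad} = Nx.Defn.value_and_grad(loss_fun).(logprobs)

IO.inspect(loss, label: "loss")
IO.inspect(grad, label: "gradient w.r.t. logprobs")
```

Each logprob contributes equally to a mean, so every gradient entry has the same value here; a loss that weights tokens differently would produce a non-uniform gradient.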
Prerequisites
- TINKER_API_KEY exported
- Optional: TINKER_BASE_URL, TINKER_BASE_MODEL
- A training client (e.g., ServiceClient.create_lora_training_client/3)
Loss Function Signature
loss_fn :: (list(Datum.t()), [Nx.Tensor.t()] -> {Nx.Tensor.t(), map()})

- The second argument is a list of logprob tensors, one per datum (no flattening).
- Return a scalar loss tensor and a metrics map. Tensor metrics are converted via Nx.to_number/1.
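Because the logprobs arrive as one tensor per datum, a loss function for multi-datum batches has to reduce over a list rather than pattern-match a single tensor. The sketch below shows one way to do that with a token-weighted mean NLL. The module name `CustomLossSketch`, the function `mean_nll/2`, and the metric keys are hypothetical, and the input tensors are made-up stand-ins for real forward-pass output.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule CustomLossSketch do
  # Hypothetical helper matching the loss_fn shape: (data, logprobs_list) -> {loss, metrics}.
  def mean_nll(_data, logprobs_list) when is_list(logprobs_list) do
    # Sum the logprobs of each datum separately, then add the per-datum sums.
    total = logprobs_list |> Enum.map(&Nx.sum/1) |> Enum.reduce(&Nx.add/2)

    # Total number of tokens across all datums (tensors may differ in length).
    count = logprobs_list |> Enum.map(&Nx.size/1) |> Enum.sum()

    nll = Nx.negate(Nx.divide(total, count))
    {nll, %{"mean_nll" => Nx.to_number(nll), "token_count" => count}}
  end
end

# Two stand-in datums with different sequence lengths.
logprobs = [Nx.tensor([-0.5, -1.5]), Nx.tensor([-1.0, -2.0, -3.0])]
{loss, metrics} = CustomLossSketch.mean_nll([], logprobs)

IO.inspect(loss, label: "loss")
IO.inspect(metrics, label: "metrics")
```

A function like this could be passed as `&CustomLossSketch.mean_nll/2` wherever a `loss_fn` is expected. Averaging over tokens rather than over datums means longer sequences carry proportionally more weight, which may or may not be what a given experiment wants.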
Minimal Workflow
loss_fn = fn _data, [logprobs] ->
  # Matches a single-datum batch; mean negative log-likelihood over its tokens.
  nll = Nx.negate(Nx.mean(logprobs))
  {nll, %{"custom_perplexity" => Nx.exp(nll) |> Nx.to_number()}}
end

# Forward pass, custom loss, and gradient upload all happen inside this call.
{:ok, task} = Tinkex.TrainingClient.forward_backward_custom(training_client, data, loss_fn)
{:ok, %Tinkex.Types.ForwardBackwardOutput{} = out} = Task.await(task)
IO.inspect(out.metrics) # includes "loss" plus your custom metrics

# Apply the resulting gradients with an Adam step.
{:ok, adam} = Tinkex.Types.AdamParams.new(learning_rate: 1.0e-4)
{:ok, step_task} = Tinkex.TrainingClient.optim_step(training_client, adam)
{:ok, _resp} = Task.await(step_task)

The backend receives gradients as weights in a synthetic cross-entropy pass, so training actually occurs. The returned output is compatible with optim_step/2 and any downstream metric reducers.
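The reason a weighted cross-entropy pass can stand in for an arbitrary loss follows from the chain rule. The derivation below is a sketch under stated assumptions: it assumes the synthetic pass minimizes a weighted sum of negative logprobs and that the weights are the negated gradients of the custom loss, held constant. The exact sign convention and weighting used by the backend are not specified here.

```latex
% \ell_{i,t}(\theta): logprob of target token t in datum i
% L: custom loss, a function of all \ell_{i,t}
% g_{i,t}: gradient of L w.r.t. each logprob, evaluated at the current \theta
g_{i,t} = \frac{\partial L}{\partial \ell_{i,t}}

% Assumed synthetic weighted cross-entropy objective, with constant weights w_{i,t} = -g_{i,t}:
\tilde{L}(\theta) = -\sum_{i,t} w_{i,t}\, \ell_{i,t}(\theta)
                  = \sum_{i,t} g_{i,t}\, \ell_{i,t}(\theta)

% Its parameter gradient matches that of the custom loss by the chain rule:
\nabla_\theta \tilde{L}
  = \sum_{i,t} g_{i,t}\, \nabla_\theta \ell_{i,t}(\theta)
  = \sum_{i,t} \frac{\partial L}{\partial \ell_{i,t}}\, \nabla_\theta \ell_{i,t}(\theta)
  = \nabla_\theta L
```

Under these assumptions the synthetic objective and the custom loss differ in value but share the same parameter gradient at the current weights, which is all the optimizer step needs.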
End-to-End Example
Run the live example with your API key:
TINKER_API_KEY=... mix run examples/custom_loss_training.exs
This script:
- Builds a LoRA training client
- Prepares a datum with target_tokens
- Runs forward_backward_custom/4 with a user-defined loss
- Executes optim_step/2 on the resulting gradients