Tinkex.Training.CustomLoss (Tinkex v0.3.4)


Custom loss training helpers.

Mirrors the Python client's forward_backward_custom behavior: computes gradients of a user-supplied loss with respect to each datum's logprobs tensor, then builds the synthetic dataset used to send those gradients back to the server.
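The three functions below compose into one round trip. A minimal sketch, assuming `outputs` is a `ForwardBackwardOutput` from a prior forward pass and `data` is the list of `Tinkex.Types.Datum` structs that produced it (both hypothetical bindings, not part of this module):

```elixir
alias Tinkex.Training.CustomLoss

# Toy loss for illustration: sum of each datum's mean logprob.
# Must return {scalar_loss_tensor, metrics_map}.
loss_fn = fn _data, logprobs_list ->
  loss =
    logprobs_list
    |> Enum.map(&Nx.mean/1)
    |> Nx.stack()
    |> Nx.sum()

  {loss, %{}}
end

with {:ok, logprobs_list} <- CustomLoss.extract_per_datum_logprobs(outputs),
     {:ok, {gradients, _metrics}} <-
       CustomLoss.compute_gradients(data, logprobs_list, loss_fn) do
  # Synthetic data with weights set to -gradient, ready to send back.
  CustomLoss.build_linear_loss_data(data, gradients)
end
```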

Summary

Functions

Build synthetic data for linearized loss using negative gradients as weights.

Compute gradients of the custom loss with respect to each logprobs tensor.

Extract logprob tensors while preserving the per-datum structure.

Functions

build_linear_loss_data(original_data, gradients)

@spec build_linear_loss_data([Tinkex.Types.Datum.t()], [Nx.Tensor.t()]) :: [
  Tinkex.Types.Datum.t()
]

Build synthetic data for linearized loss using negative gradients as weights.

Each returned datum includes:

  • the original model_input
  • target_tokens copied from the source datum
  • weights set to the elementwise negation of the corresponding gradient (-gradient)
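Conceptually, the function zips each source datum with its gradient. A hand-rolled equivalent might look like the sketch below; the exact `Datum` struct fields are assumed from the bullet list above:

```elixir
# Pair each datum with its gradient and negate the gradient to form
# the per-token weights of the synthetic datum.
Enum.zip_with(original_data, gradients, fn datum, grad ->
  %Tinkex.Types.Datum{
    model_input: datum.model_input,
    target_tokens: datum.target_tokens,
    weights: Nx.negate(grad)
  }
end)
```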

compute_gradients(data, logprobs_list, loss_fn)

@spec compute_gradients(
  list(),
  [Nx.Tensor.t()],
  (list(), [Nx.Tensor.t()] -> {Nx.Tensor.t(), map()})
) :: {:ok, {[Nx.Tensor.t()], map()}} | {:error, term()}

Compute gradients of the custom loss with respect to each logprobs tensor.

Uses Nx.Defn.grad/2 to differentiate the user-provided loss function. Returns gradients in the same order as the input logprobs.
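The loss function receives the data list and the per-datum logprobs and must return a `{scalar_loss, metrics}` tuple. A sketch of a maximum-likelihood loss (negative total logprob), assuming `data` and `logprobs_list` are already bound:

```elixir
# The loss must be a scalar Nx tensor so Nx.Defn.grad can
# differentiate it; `data` is available but unused here.
loss_fn = fn _data, logprobs_list ->
  total =
    logprobs_list
    |> Enum.map(&Nx.sum/1)
    |> Nx.stack()
    |> Nx.sum()

  {Nx.negate(total), %{"total_logprob" => total}}
end

{:ok, {gradients, _metrics}} =
  Tinkex.Training.CustomLoss.compute_gradients(data, logprobs_list, loss_fn)

# Each returned gradient has the same shape as its logprobs tensor,
# in the same order as logprobs_list.
```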

extract_per_datum_logprobs(outputs)

@spec extract_per_datum_logprobs(
  Tinkex.Types.ForwardBackwardOutput.t()
  | [Tinkex.Types.ForwardBackwardOutput.t()]
) :: {:ok, [Nx.Tensor.t()]} | {:error, term()}

Extract logprob tensors while preserving the per-datum structure.

Accepts either a single ForwardBackwardOutput or a list of them (when forward responses were chunked) and returns a list of Nx tensors, one per datum, in the original order.
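Both call shapes yield the same result structure. A sketch, where `output`, `chunk1`, and `chunk2` are hypothetical `ForwardBackwardOutput` values:

```elixir
alias Tinkex.Training.CustomLoss

# Single, unchunked response: one tensor per datum.
{:ok, logprobs_list} = CustomLoss.extract_per_datum_logprobs(output)

# Chunked responses: tensors from each chunk are flattened into a
# single list, preserving the original datum order.
{:ok, logprobs_list} = CustomLoss.extract_per_datum_logprobs([chunk1, chunk2])
```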