Tinkex.Training.CustomLoss (Tinkex v0.3.4)
Custom loss training helpers.
Mirrors Python forward_backward_custom behavior by computing gradients
with respect to per-datum logprobs and constructing the synthetic dataset
used to send gradients back to the server.
Summary
Functions
Build synthetic data for linearized loss using negative gradients as weights.
Compute gradients of the custom loss with respect to each logprobs tensor.
Extract logprob tensors while preserving the per-datum structure.
Functions
@spec build_linear_loss_data([Tinkex.Types.Datum.t()], [Nx.Tensor.t()]) :: [Tinkex.Types.Datum.t()]
Build synthetic data for linearized loss using negative gradients as weights.
Each returned datum includes:
- model_input: copied from the source datum
- target_tokens: copied from the source datum
- weights: set to -gradient
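The construction above can be sketched in plain Elixir. This is a hypothetical illustration, not the library's implementation: Datum is modeled as a bare map and gradients as lists of floats so the sketch stays dependency-free, whereas the real module negates Nx tensors.

```elixir
defmodule LinearLossSketch do
  # Hypothetical sketch: build one synthetic datum per (datum, gradient)
  # pair, reusing model_input and target_tokens and using the negated
  # gradient as the per-token weights.
  def build_linear_loss_data(data, grads) do
    Enum.zip_with(data, grads, fn datum, grad ->
      %{
        model_input: datum.model_input,
        target_tokens: datum.target_tokens,
        # the negative gradient becomes the weight vector
        weights: Enum.map(grad, &(-&1))
      }
    end)
  end
end
```

Pairing each datum with its gradient via Enum.zip_with/3 preserves the per-datum order, which is what lets the server match weights back to inputs.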
@spec compute_gradients(list(), [Nx.Tensor.t()], (list(), [Nx.Tensor.t()] -> {Nx.Tensor.t(), map()})) :: {:ok, {[Nx.Tensor.t()], map()}} | {:error, term()}
Compute gradients of the custom loss with respect to each logprobs tensor.
Uses Nx.Defn.grad/2 to differentiate the user-provided loss function.
Returns gradients in the same order as the input logprobs.
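To make "gradient with respect to each logprobs entry" concrete, the sketch below approximates that quantity with central finite differences over plain lists. This is only an illustration of what Nx.Defn.grad/2 computes; the real function differentiates the loss exactly via automatic differentiation on Nx tensors.

```elixir
defmodule GradSketch do
  # Hypothetical sketch: numerically approximate d(loss)/d(logprob_i)
  # for each entry i, using central finite differences.
  def grad(loss_fun, logprobs, eps \\ 1.0e-6) do
    for i <- 0..(length(logprobs) - 1) do
      plus = List.update_at(logprobs, i, &(&1 + eps))
      minus = List.update_at(logprobs, i, &(&1 - eps))
      (loss_fun.(plus) - loss_fun.(minus)) / (2 * eps)
    end
  end
end
```

For a sum-of-squares loss the result approaches the analytic gradient 2x per entry, matching what autodiff would return in the same order as the input logprobs.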
@spec extract_per_datum_logprobs(Tinkex.Types.ForwardBackwardOutput.t() | [Tinkex.Types.ForwardBackwardOutput.t()]) :: {:ok, [Nx.Tensor.t()]} | {:error, term()}
Extract logprob tensors while preserving the per-datum structure.
Accepts either a single ForwardBackwardOutput or a list of them (when
forward responses were chunked) and returns a list of Nx tensors, one per
datum, in the original order.
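The single-or-list normalization can be sketched as follows. This is a hypothetical illustration: ForwardBackwardOutput is modeled as a plain map with a logprobs field holding one entry per datum, and the :ok/:error wrapping of the real function is omitted.

```elixir
defmodule ExtractSketch do
  # Hypothetical sketch: accept either one output or a chunked list of
  # outputs, and flatten to one logprobs entry per datum in order.
  def extract_per_datum_logprobs(outputs) when is_list(outputs) do
    Enum.flat_map(outputs, & &1.logprobs)
  end

  def extract_per_datum_logprobs(%{} = output) do
    extract_per_datum_logprobs([output])
  end
end
```

Wrapping the single-output case into a one-element list keeps a single flattening code path, and Enum.flat_map/2 preserves both chunk order and within-chunk datum order.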