Tinkex.Regularizer.Pipeline (Tinkex v0.3.4)

Orchestrates regularizer composition and computes structured loss output.

The pipeline runs the base loss function and each regularizer, then computes the composed total loss and, optionally, gradient norms.

Composition Formula

loss_total = base_loss + Σ(weight_i × regularizer_i)
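
As a rough illustration of the formula, the sketch below composes a total from plain floats. `CompositionSketch` and its `{weight, value}` pairs are hypothetical stand-ins, not part of Tinkex; the real pipeline works on Nx tensors and regularizer specs.

```elixir
defmodule CompositionSketch do
  # Hypothetical helper: combines a base loss with weighted regularizer values.
  # `terms` is a list of {weight, value} pairs of plain numbers.
  def total(base_loss, terms) do
    regularizer_total =
      terms
      |> Enum.map(fn {weight, value} -> weight * value end)
      |> Enum.sum()

    %{
      loss_total: base_loss + regularizer_total,
      regularizer_total: regularizer_total
    }
  end
end

# 2.0 + (0.5 * 4.0 + 0.25 * 8.0)
result = CompositionSketch.total(2.0, [{0.5, 4.0}, {0.25, 8.0}])
IO.inspect(result)
# => %{loss_total: 6.0, regularizer_total: 4.0}
```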

Execution Flow

  1. Validates inputs (base_loss_fn, regularizer specs)
  2. Executes base loss function
  3. Executes regularizers (optionally in parallel)
  4. Computes gradient norms (if tracking enabled)
  5. Builds structured CustomLossOutput
  6. Emits telemetry events
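
Step 3 could look roughly like the following, assuming regularizers are run with `Task.async_stream/3` under a per-task timeout. How the pipeline actually schedules them is not stated here; `ParallelSketch` and the plain-map spec shape are hypothetical, and the data is plain lists rather than Datum structs and tensors.

```elixir
defmodule ParallelSketch do
  # Hypothetical spec shape: %{name: string, weight: number, fn: (data, logprobs -> number)}
  def run(specs, data, logprobs, opts \\ []) do
    timeout = Keyword.get(opts, :timeout, 30_000)

    specs
    |> Task.async_stream(
      fn %{name: name, weight: weight, fn: fun} ->
        {name, weight * fun.(data, logprobs)}
      end,
      timeout: timeout,
      # Kill a slow task and report {:exit, :timeout} instead of exiting the caller.
      on_timeout: :kill_task
    )
    |> Enum.reduce_while({:ok, %{}}, fn
      {:ok, {name, contribution}}, {:ok, acc} ->
        {:cont, {:ok, Map.put(acc, name, contribution)}}

      {:exit, reason}, _acc ->
        {:halt, {:error, reason}}
    end)
  end
end

specs = [
  %{name: "sum", weight: 0.5, fn: fn data, _logprobs -> Enum.sum(data) end},
  %{name: "count", weight: 2.0, fn: fn _data, logprobs -> length(logprobs) * 1.0 end}
]

{:ok, contributions} = ParallelSketch.run(specs, [1.0, 3.0], [0.1, 0.2, 0.3])
IO.inspect(contributions)
```

Stopping at the first failed task keeps the error path simple; a real implementation might instead collect every failure before returning.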

Telemetry Events

  • [:tinkex, :custom_loss, :start] - Before computation
  • [:tinkex, :custom_loss, :stop] - After successful computation
  • [:tinkex, :custom_loss, :exception] - On failure
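
A handler for these events might be attached as sketched below. `CustomLossLogger` is a hypothetical module. Because the measurements and metadata each event carries are not listed here, the handler only inspects whatever it receives. The `Code.ensure_loaded?/1` guard is there so the snippet also runs where the `:telemetry` library is not installed.

```elixir
defmodule CustomLossLogger do
  @events [
    [:tinkex, :custom_loss, :start],
    [:tinkex, :custom_loss, :stop],
    [:tinkex, :custom_loss, :exception]
  ]

  def events, do: @events

  # Attaches one handler for all three events when :telemetry is available.
  def attach do
    if Code.ensure_loaded?(:telemetry) do
      # apply/3 avoids a compile-time warning when :telemetry is absent;
      # in a project that depends on it, call :telemetry.attach_many/4 directly.
      apply(:telemetry, :attach_many, [
        "custom-loss-logger",
        @events,
        &__MODULE__.handle_event/4,
        nil
      ])
    else
      {:error, :telemetry_not_loaded}
    end
  end

  # Standard four-argument telemetry handler callback.
  def handle_event(event, measurements, _metadata, _config) do
    IO.puts(format(event, measurements))
  end

  def format(event, measurements) do
    name = Enum.map_join(event, ".", &Atom.to_string/1)
    "#{name} #{inspect(measurements)}"
  end
end

CustomLossLogger.attach()
```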

Examples

alias Tinkex.Regularizer.Pipeline

# Base loss only
{:ok, output} = Pipeline.compute(data, logprobs, &my_loss/2)

# With regularizers (RegularizerSpec is assumed to be aliased as well)
{:ok, output} = Pipeline.compute(data, logprobs, &my_loss/2,
  regularizers: [
    %RegularizerSpec{fn: &l1_reg/2, weight: 0.01, name: "l1"},
    %RegularizerSpec{fn: &entropy_reg/2, weight: 0.001, name: "entropy"}
  ],
  track_grad_norms: true,
  parallel: true
)

Functions

compute(data, logprobs, base_loss_fn, opts \\ [])

@spec compute(
  [Tinkex.Types.Datum.t()],
  Nx.Tensor.t(),
  base_loss_fn :: function(),
  keyword()
) :: {:ok, Tinkex.Types.CustomLossOutput.t()} | {:error, term()}

Compute composed loss from base loss and regularizers.

Parameters

  • data - List of training Datum structs
  • logprobs - Nx tensor of log probabilities
  • base_loss_fn - Required function (data, logprobs) -> {loss, metrics}
  • opts - Configuration options

Options

  • :regularizers - List of RegularizerSpec (default: [])
  • :track_grad_norms - Compute gradient norms (default: false)
  • :parallel - Run regularizers in parallel (default: true)
  • :timeout - Execution timeout (default: 30_000)
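
The defaults above can be pictured with a small sketch based on `Keyword.validate!/2` (Elixir 1.13+). This is only an illustration: `OptionsSketch` is hypothetical, and whether `compute/4` rejects unknown option keys the way this sketch does is not stated here.

```elixir
defmodule OptionsSketch do
  # Defaults as listed in the Options section.
  @defaults [regularizers: [], track_grad_norms: false, parallel: true, timeout: 30_000]

  # Merges caller options over the defaults; raises ArgumentError on unknown keys.
  def normalize(opts) do
    opts
    |> Keyword.validate!(@defaults)
    |> Map.new()
  end
end

IO.inspect(OptionsSketch.normalize(track_grad_norms: true))
```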

Returns

  • {:ok, CustomLossOutput.t()} on success
  • {:error, {:pipeline_failed, exception}} on failure
  • {:error, term()} for regularizer failures
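
A caller might branch on these return shapes as in the sketch below. `ResultSketch` is hypothetical, and a plain map stands in for `CustomLossOutput`. The `is_exception/1` guard (Elixir 1.11+) covers the assumption that the second element of `{:pipeline_failed, exception}` is an exception struct.

```elixir
defmodule ResultSketch do
  def describe({:ok, output}) do
    "total loss: #{output.loss_total}"
  end

  def describe({:error, {:pipeline_failed, exception}}) when is_exception(exception) do
    "pipeline failed: #{Exception.message(exception)}"
  end

  # Any other error term, such as one returned by a failing regularizer.
  def describe({:error, reason}) do
    "error: #{inspect(reason)}"
  end
end

IO.puts(ResultSketch.describe({:ok, %{loss_total: 1.25}}))
IO.puts(ResultSketch.describe({:error, {:pipeline_failed, %RuntimeError{message: "boom"}}}))
IO.puts(ResultSketch.describe({:error, :timeout}))
```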

Examples

{:ok, output} = Pipeline.compute(data, logprobs, base_loss_fn,
  regularizers: regularizers,
  track_grad_norms: true
)

output.loss_total          # Total composed loss
output.regularizer_total   # Sum of regularizer contributions
output.regularizers["l1"]  # Individual regularizer metrics