Tinkex.Types.CustomLossOutput (Tinkex v0.3.4)


Structured output from custom loss computation with regularizers.

This type mirrors the Python SDK's metrics schema for API compatibility, providing comprehensive telemetry for research workflows.

Schema

%CustomLossOutput{
  loss_total: 2.847,
  base_loss: %{
    value: 2.5,
    grad_norm: 3.14,
    custom: %{"perplexity" => 12.18}
  },
  regularizers: %{
    "sparsity" => %RegularizerOutput{...},
    "entropy" => %RegularizerOutput{...}
  },
  regularizer_total: 0.347,
  total_grad_norm: 5.67
}

Loss Composition

The total loss is computed as:

loss_total = base_loss.value + Σ(weight_i × value_i)

Each regularizer's contribution is weight_i × value_i. In the schema example above, 2.847 = 2.5 + 0.347, that is, base_loss.value plus regularizer_total.
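The composition rule can be checked by hand with plain Elixir data. The sketch below does not use Tinkex at all: the regularizer maps, their `:name`, `:weight` and `:value` keys, and the numbers themselves are hypothetical, chosen so that the totals agree with the schema example up to floating-point rounding.

```elixir
# Hypothetical regularizer results (plain maps, not RegularizerOutput structs).
regularizers = [
  %{name: "sparsity", weight: 0.1, value: 2.0},
  %{name: "entropy", weight: 0.07, value: 2.1}
]

# Scalar base loss, corresponding to base_loss.value in the schema.
base_loss = 2.5

# Each regularizer contributes weight * value.
regularizer_total =
  regularizers
  |> Enum.map(fn r -> r.weight * r.value end)
  |> Enum.sum()

# Total loss is the base loss plus the weighted regularizer terms.
loss_total = base_loss + regularizer_total

IO.puts("regularizer_total = #{Float.round(regularizer_total, 3)}")
IO.puts("loss_total = #{Float.round(loss_total, 3)}")
```

Because the weights multiply the raw regularizer values before summation, changing a weight rescales only that regularizer's share of loss_total and leaves the base loss untouched.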

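Summary

Functions

build(base_loss_value, base_loss_metrics, regularizer_outputs, opts \\ [])
Build CustomLossOutput from computation results.

loss(custom_loss_output)
Get the primary loss value (for backward compatibility).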

Types

base_loss_metrics()

@type base_loss_metrics() :: %{
  value: float(),
  grad_norm: float() | nil,
  custom: %{required(String.t()) => number()}
}

t()

@type t() :: %Tinkex.Types.CustomLossOutput{
  base_loss: base_loss_metrics() | nil,
  loss_total: float(),
  regularizer_total: float() | nil,
  regularizers: %{required(String.t()) => Tinkex.Types.RegularizerOutput.t()},
  total_grad_norm: float() | nil
}

Functions

build(base_loss_value, base_loss_metrics, regularizer_outputs, opts \\ [])

@spec build(
  base_loss_value :: float(),
  base_loss_metrics :: map() | nil,
  regularizer_outputs :: [Tinkex.Types.RegularizerOutput.t()],
  opts :: keyword()
) :: t()

Build CustomLossOutput from computation results.

Parameters

  • base_loss_value - The primary loss value
  • base_loss_metrics - Map of custom metrics from the base loss function, or nil
  • regularizer_outputs - List of RegularizerOutput structs
  • opts - Optional keyword list; recognized keys are :base_grad_norm and :total_grad_norm

Examples

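alias Tinkex.Types.CustomLossOutput

# regularizer_outputs is a list of RegularizerOutput structs computed earlier
CustomLossOutput.build(2.5, %{"nll" => 2.5}, regularizer_outputs,
  base_grad_norm: 3.14,
  total_grad_norm: 5.67
)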
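To make the relationship between the arguments and the resulting fields concrete, here is a minimal stand-in written from the schema and the loss composition rule alone. It is not Tinkex's implementation: `BuildSketch` is a hypothetical module, it returns a plain map rather than a `%CustomLossOutput{}` struct, and the regularizer outputs are modelled as plain maps with assumed `:name`, `:weight` and `:value` keys.

```elixir
defmodule BuildSketch do
  # Illustration only; NOT the library's code. Mirrors the documented
  # schema using plain maps so it runs without Tinkex installed.
  def build(base_loss_value, base_loss_metrics, regularizer_outputs, opts \\ []) do
    # Sum of weight * value over all regularizers.
    regularizer_total =
      regularizer_outputs
      |> Enum.map(fn r -> r.weight * r.value end)
      |> Enum.sum()

    %{
      loss_total: base_loss_value + regularizer_total,
      base_loss: %{
        value: base_loss_value,
        # Absent options fall back to nil, matching the nullable fields in t().
        grad_norm: Keyword.get(opts, :base_grad_norm),
        custom: base_loss_metrics || %{}
      },
      # Index regularizers by their (assumed) name key.
      regularizers: Map.new(regularizer_outputs, fn r -> {r.name, r} end),
      regularizer_total: regularizer_total,
      total_grad_norm: Keyword.get(opts, :total_grad_norm)
    }
  end
end

output =
  BuildSketch.build(
    2.5,
    %{"nll" => 2.5},
    [%{name: "sparsity", weight: 0.1, value: 2.0}],
    base_grad_norm: 3.14
  )

IO.inspect(output.loss_total, label: "loss_total")
IO.inspect(output.total_grad_norm, label: "total_grad_norm")
```

In this sketch, omitting :total_grad_norm leaves that field as nil, which is consistent with the `float() | nil` fields in the type specification; how the real function treats missing options is not stated on this page.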

loss(custom_loss_output)

@spec loss(t()) :: float()

Get the primary loss value (for backward compatibility).

Equivalent to accessing output.loss_total.