Tinkex.Regularizer behaviour (Tinkex v0.3.1)
Behaviour for implementing regularizers.
Regularizers can be implemented as:
- Anonymous functions matching the callback spec
- Modules implementing this behaviour
- Tasks for async operations
Implementing a Regularizer Module
defmodule MyRegularizers.L1Sparsity do
  @behaviour Tinkex.Regularizer

  @impl true
  def compute(_data, logprobs, _opts) do
    l1 = Nx.sum(Nx.abs(logprobs))
    {l1, %{"l1_value" => Nx.to_number(l1)}}
  end

  @impl true
  def name, do: "l1_sparsity"
end
Using as Anonymous Function
regularizer_spec = %RegularizerSpec{
  fn: fn _data, logprobs ->
    {Nx.sum(Nx.abs(logprobs)), %{}}
  end,
  weight: 0.01,
  name: "l1"
}
Using with RegularizerSpec
Regularizers are typically used via RegularizerSpec structs passed to
the regularizer pipeline:
%RegularizerSpec{
  fn: &MyRegularizers.L1Sparsity.compute/3,
  weight: 0.01,
  name: MyRegularizers.L1Sparsity.name()
}
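The weight field suggests that the pipeline scales each regularizer's loss before adding it to the training loss. The sketch below assumes exactly that (total = base loss + Σ weight × loss); it is an illustration, not the library's pipeline. To stay dependency-free, plain maps stand in for %RegularizerSpec{} structs, plain numbers and lists stand in for Nx tensors, and both WeightedSumSketch and the "<name>_contribution" metric key are hypothetical.

```elixir
defmodule WeightedSumSketch do
  # Hypothetical helper, not part of Tinkex. Assumption: each spec adds
  # weight * loss to the running total. Specs are plain maps with :fn,
  # :weight and :name keys; losses are plain numbers instead of Nx scalars.
  def total_loss(base_loss, specs, data, logprobs) do
    Enum.reduce(specs, {base_loss, %{}}, fn spec, {total, metrics} ->
      # Call the arity-2 regularizer function stored under :fn.
      {loss, _reg_metrics} = Map.fetch!(spec, :fn).(data, logprobs)
      contribution = spec.weight * loss
      {total + contribution, Map.put(metrics, "#{spec.name}_contribution", contribution)}
    end)
  end
end

# Toy L1 regularizer over a plain list of "logprobs".
l1 = %{
  fn: fn _data, logprobs -> {logprobs |> Enum.map(&abs/1) |> Enum.sum(), %{}} end,
  weight: 0.5,
  name: "l1"
}

WeightedSumSketch.total_loss(1.0, [l1], [], [-1.0, -3.0])
# → {3.0, %{"l1_contribution" => 2.0}}
```

Here the L1 term is 4.0, so with weight 0.5 it adds 2.0 to the base loss of 1.0.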
Summary
Callbacks
compute(data, logprobs, opts): Compute the regularizer loss and metrics.
name(): Return the regularizer name for telemetry and logging.
Functions
execute(fn_or_module, data, logprobs, opts): Execute a regularizer (function or module) and return results.
Callbacks
@callback compute(
  data :: [Tinkex.Types.Datum.t()],
  logprobs :: Nx.Tensor.t(),
  opts :: keyword()
) :: {Nx.Tensor.t(), %{required(String.t()) => number()}}
Compute the regularizer loss and metrics.
Parameters
- data: List of training Datum structs
- logprobs: Nx tensor of log probabilities from forward pass
- opts: Optional keyword configuration
Returns
Tuple of {loss_tensor, metrics_map} where:
- loss_tensor: Scalar Nx tensor representing the regularizer loss
- metrics_map: Map of string keys to numeric values for telemetry
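To make the return contract concrete, here is a small checker for the {loss_tensor, metrics_map} shape. ResultShapeSketch and valid_result?/1 are hypothetical names, and a plain number stands in for the scalar Nx tensor so the snippet runs without Nx. It shows the easy mistake of using atom keys where the spec requires string keys.

```elixir
defmodule ResultShapeSketch do
  # Hypothetical checker, not part of Tinkex. It tests the shape compute/3 is
  # documented to return: {loss, %{String.t() => number()}}.
  # Assumption: a plain number stands in for the scalar Nx tensor.
  def valid_result?({loss, metrics}) when is_number(loss) and is_map(metrics) do
    # Enumerating a map yields {key, value} tuples.
    Enum.all?(metrics, fn {key, value} -> is_binary(key) and is_number(value) end)
  end

  def valid_result?(_other), do: false
end

ResultShapeSketch.valid_result?({0.42, %{"l1_value" => 0.42}})
# → true
ResultShapeSketch.valid_result?({0.42, %{l1_value: 0.42}})
# → false, because :l1_value is an atom key rather than a string
```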
@callback name() :: String.t()
Return the regularizer name for telemetry and logging.
This callback is optional. If not implemented, the name should be provided via the RegularizerSpec.
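Because name/0 is optional, code that needs a name has to check whether the module exports it. The sketch below is one way to do that, under an assumed precedence (an explicit spec name wins, otherwise the module's name/0 is used, otherwise it raises); NameResolutionSketch, NamedReg and UnnamedReg are hypothetical. Code.ensure_loaded?/1 is called first because function_exported?/3 only reports on modules that are already loaded.

```elixir
defmodule NameResolutionSketch do
  # Hypothetical helper, not part of Tinkex. Assumed precedence: an explicit
  # spec name wins; otherwise use the module's optional name/0 callback.
  def resolve_name(_fn_or_module, spec_name) when is_binary(spec_name), do: spec_name

  def resolve_name(module, nil) when is_atom(module) do
    # Load the module before asking whether it exports name/0.
    if Code.ensure_loaded?(module) and function_exported?(module, :name, 0) do
      module.name()
    else
      raise ArgumentError,
            "#{inspect(module)} does not implement name/0; set a name in the spec"
    end
  end

  def resolve_name(_fun, nil),
    do: raise(ArgumentError, "anonymous regularizers need a name in the spec")
end

# Hypothetical modules: one with the optional callback, one without.
defmodule NamedReg do
  def name, do: "named"
end

defmodule UnnamedReg do
end

NameResolutionSketch.resolve_name(NamedReg, nil)
# → "named"
NameResolutionSketch.resolve_name(NamedReg, "l1")
# → "l1"
```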
Functions
@spec execute(
  fn_or_module :: function() | module(),
  data :: [Tinkex.Types.Datum.t()],
  logprobs :: Nx.Tensor.t(),
  opts :: keyword()
) :: {Nx.Tensor.t(), %{required(String.t()) => number()}}
Execute a regularizer (function or module) and return results.
Handles both anonymous functions and behaviour-implementing modules.
Parameters
- fn_or_module: Either a function or a module implementing the behaviour
- data: List of training Datum structs
- logprobs: Nx tensor of log probabilities
- opts: Optional keyword configuration
Returns
Tuple of {loss_tensor, metrics_map}
Examples
alias Tinkex.Regularizer

# With anonymous function (arity 2)
Regularizer.execute(
  fn _data, logprobs -> {Nx.sum(logprobs), %{}} end,
  data,
  logprobs
)

# With module (MyRegularizer implements Tinkex.Regularizer)
Regularizer.execute(MyRegularizer, data, logprobs, timeout: 5000)
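The examples use an arity-2 function and a module, while the RegularizerSpec example passes the arity-3 capture &MyRegularizers.L1Sparsity.compute/3. A minimal dispatch sketch that accepts all three forms is shown below; it is modelled on the description of execute/4 and is not the library's code. ExecuteSketch and SumRegularizer are hypothetical, plain lists and integers replace Nx tensors, and SumRegularizer omits @behaviour because Tinkex is not loaded here.

```elixir
defmodule ExecuteSketch do
  # Hypothetical dispatcher modelled on the execute/4 description; it is not
  # the library's implementation. Plain lists and numbers replace Nx tensors.
  def execute(fn_or_module, data, logprobs, opts \\ [])

  # Anonymous function of arity 2: called with (data, logprobs); opts are dropped.
  def execute(fun, data, logprobs, _opts) when is_function(fun, 2),
    do: fun.(data, logprobs)

  # Function of arity 3 (e.g. &Mod.compute/3): opts are passed through.
  def execute(fun, data, logprobs, opts) when is_function(fun, 3),
    do: fun.(data, logprobs, opts)

  # Module implementing the behaviour: dispatch to its compute/3 callback.
  def execute(module, data, logprobs, opts) when is_atom(module),
    do: module.compute(data, logprobs, opts)
end

# Stand-in for a behaviour module (no @behaviour, since Tinkex is not loaded).
defmodule SumRegularizer do
  def compute(_data, logprobs, opts) do
    scale = Keyword.get(opts, :scale, 1)
    {scale * Enum.sum(logprobs), %{"n" => length(logprobs)}}
  end
end

ExecuteSketch.execute(fn _data, lp -> {Enum.sum(lp), %{}} end, [], [1, 2, 3])
# → {6, %{}}
ExecuteSketch.execute(&SumRegularizer.compute/3, [], [1, 2, 3])
# → {6, %{"n" => 3}}
ExecuteSketch.execute(SumRegularizer, [], [1, 2, 3], scale: 2)
# → {12, %{"n" => 3}}
```

The is_function/2 arity guards are what let the 2-argument form and &Module.compute/3 captures share one entry point; a function value never satisfies is_atom/1, so the module clause cannot shadow them.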