Tinkex.TrainingClient (Tinkex v0.3.4)

GenServer that coordinates training operations for a single model.

Requests are sent sequentially within the GenServer while polling is performed concurrently in background tasks. This keeps request ordering deterministic at the cost of blocking the GenServer during the send phase.
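Most functions below return {:ok, Task.t()} or {:error, %Tinkex.Error{}}, and the task itself yields a second {:ok, _} or {:error, _} tuple. A small helper can flatten the two levels. This is a minimal sketch: the module name AwaitHelper and the 60-second default timeout are assumptions for illustration, not part of Tinkex.

```elixir
defmodule AwaitHelper do
  @moduledoc "Hypothetical helper, not part of Tinkex."

  # Flattens `{:ok, task}` plus the task's own result into a single tuple.
  # The default timeout is an arbitrary illustrative value; Task.await/2
  # exits the caller (it does not return an error tuple) if it elapses.
  @spec await({:ok, Task.t()} | {:error, term()}, timeout()) ::
          {:ok, term()} | {:error, term()}
  def await(result, timeout \\ 60_000)
  def await({:ok, %Task{} = task}, timeout), do: Task.await(task, timeout)
  def await({:error, _reason} = error, _timeout), do: error
end
```

With such a helper, a call like TrainingClient.save_state(client, name) |> AwaitHelper.await() would yield a single result tuple whether the failure happened at send time or while polling.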

Use Tinkex.Types.ModelInput.from_text/2 to turn raw strings into tokenized ModelInput structs before constructing training data. Chat templates are not applied automatically; provide fully formatted text.

Queue State Observer

This client implements Tinkex.QueueStateObserver and automatically logs human-readable warnings when queue state changes indicate rate limiting or capacity issues:

[warning] Training is paused for model-xyz. Reason: concurrent training clients rate limit hit

Logs are debounced to once per 60 seconds per model to avoid spam.
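The debouncing described above can be pictured with a small pure function. This is only an illustrative sketch of per-model debouncing, not the code Tinkex uses: the module name DebounceSketch, the map-based state, and the check/3 function are all hypothetical.

```elixir
defmodule DebounceSketch do
  @moduledoc "Hypothetical illustration of per-key log debouncing; not Tinkex's implementation."

  # Minimum spacing between two logs for the same model, in milliseconds.
  @interval_ms 60_000

  # `last_logged` maps a model id to the time (ms) of its last emitted log.
  # Returns {:log, updated_map} when a log should be emitted and
  # {:skip, unchanged_map} when the previous log is still too recent.
  def check(last_logged, model_id, now_ms \\ System.monotonic_time(:millisecond)) do
    case Map.fetch(last_logged, model_id) do
      {:ok, last} when now_ms - last < @interval_ms ->
        {:skip, last_logged}

      _ ->
        {:log, Map.put(last_logged, model_id, now_ms)}
    end
  end
end
```

Keying the state by model id is what makes the limit apply per model: a warning for one model does not suppress a warning for another.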

Summary

Functions

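child_spec(init_arg)
Returns a specification to start this module under a supervisor.

create_sampling_client_async(client, model_path, opts \\ [])
Create a sampling client from this training client asynchronously.

decode(client, ids, opts \\ [])
Decode token IDs using this training client's tokenizer.

encode(client, text, opts \\ [])
Encode text using this training client's tokenizer.

forward(client, data, loss_fn, opts \\ [])
Run a forward-only pass over the provided data (inference without backward).

forward_backward(client, data, loss_fn, opts \\ [])
Run a forward-backward pass over the provided data.

forward_backward_custom(client, data, loss_fn, opts \\ [])
Compute forward/backward pass with a custom loss function.

get_info(client)
Fetch model metadata for the training client.

get_tokenizer(client, opts \\ [])
Get a tokenizer for this training client's model.

load_state(client, path, opts \\ [])
Load model weights from a checkpoint (without optimizer state).

load_state_with_optimizer(client, path, opts \\ [])
Load model weights and optimizer state from a checkpoint.

optim_step(client, adam_params, opts \\ [])
Perform an optimizer step.

save_state(client, name, opts \\ [])
Save model weights as a training checkpoint.

save_weights_and_get_sampling_client(client, opts \\ [])
Save weights for sampling and immediately create a SamplingClient.

save_weights_for_sampler(client, name, opts \\ [])
Save weights for downstream sampling.

unload_model(client)
Unload the active model and end the session.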

Types

t()

@type t() :: pid()

Functions

child_spec(init_arg)

Returns a specification to start this module under a supervisor.

See Supervisor.

create_sampling_client_async(client, model_path, opts \\ [])

@spec create_sampling_client_async(t(), String.t(), keyword()) :: Task.t()

Create a sampling client from this training client asynchronously.

Takes a model_path (checkpoint path) and returns a Task that resolves to a sampling client.

Examples

task = TrainingClient.create_sampling_client_async(training_pid, "tinker://run-1/weights/0001")
{:ok, sampling_pid} = Task.await(task)

decode(client, ids, opts \\ [])

@spec decode(t(), [integer()], keyword()) ::
  {:ok, String.t()} | {:error, Tinkex.Error.t()}

Decode token IDs using this training client's tokenizer.

Convenience wrapper around Tinkex.Tokenizer.decode/3 that automatically resolves the tokenizer from the training client's model info.

Examples

{:ok, text} = TrainingClient.decode(client, [1, 2, 3])

Options

  • :load_fun - Custom tokenizer loader function
  • :info_fun - Custom info fetcher for testing

encode(client, text, opts \\ [])

@spec encode(t(), String.t(), keyword()) ::
  {:ok, [integer()]} | {:error, Tinkex.Error.t()}

Encode text using this training client's tokenizer.

Convenience wrapper around Tinkex.Tokenizer.encode/3 that automatically resolves the tokenizer from the training client's model info.

Examples

{:ok, ids} = TrainingClient.encode(client, "Hello world")

Options

  • :load_fun - Custom tokenizer loader function
  • :info_fun - Custom info fetcher for testing

forward(client, data, loss_fn, opts \\ [])

@spec forward(t(), [map()], atom() | String.t(), keyword()) ::
  {:ok, Task.t()} | {:error, Tinkex.Error.t()}

Run a forward-only pass over the provided data (inference without backward).

Returns logprobs that can be converted to Nx tensors via TensorData.to_nx/1. Useful for custom loss computation where gradients are computed in Elixir/Nx.

Returns a Task.t() that yields {:ok, %ForwardBackwardOutput{}} or {:error, %Tinkex.Error{}}.

Examples

{:ok, task} = TrainingClient.forward(client, data, :cross_entropy)
{:ok, output} = Task.await(task)

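# Access logprobs from output.loss_fn_outputs.
# This pattern assumes exactly one datum; it raises MatchError for longer lists.
[%{"logprobs" => logprobs_data}] = output.loss_fn_outputs
# Schematic: wrap logprobs_data in a %TensorData{} struct before converting
# (the `...` is shorthand, not valid Elixir syntax).
tensor = TensorData.to_nx(%TensorData{...logprobs_data})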

forward_backward(client, data, loss_fn, opts \\ [])

@spec forward_backward(t(), [map()], atom() | String.t(), keyword()) ::
  {:ok, Task.t()} | {:error, Tinkex.Error.t()}

Run a forward-backward pass over the provided data.

Returns a Task.t() that yields {:ok, %ForwardBackwardOutput{}} or {:error, %Tinkex.Error{}}.
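A typical training step pairs forward_backward/4 with optim_step/3. The following is a minimal sketch under stated assumptions: TrainStepSketch and its run/4 function are hypothetical, the :cross_entropy loss name is taken from the forward/4 example, and the client module is passed as an argument only so the flow can be tried against a stub. The shapes of data and adam_params are defined by Tinkex and are passed through untouched.

```elixir
defmodule TrainStepSketch do
  @moduledoc "Hypothetical helper, not part of Tinkex."

  # Runs one forward-backward pass followed by one optimizer step.
  # Each call returns {:ok, task} once the request is accepted; the task
  # then yields the final {:ok, result} or {:error, reason}.
  # :infinity is used for brevity; a finite timeout is safer in practice.
  def run(client, data, adam_params, mod \\ Tinkex.TrainingClient) do
    with {:ok, fb_task} <- mod.forward_backward(client, data, :cross_entropy),
         {:ok, fb_output} <- Task.await(fb_task, :infinity),
         {:ok, optim_task} <- mod.optim_step(client, adam_params),
         {:ok, optim_response} <- Task.await(optim_task, :infinity) do
      {:ok, {fb_output, optim_response}}
    end
  end
end
```

Because `with` returns the first non-matching value unchanged, an {:error, %Tinkex.Error{}} from either the send phase or the awaited task is returned to the caller as-is.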

forward_backward_custom(client, data, loss_fn, opts \\ [])

@spec forward_backward_custom(
  t(),
  [Tinkex.Types.Datum.t()],
  loss_fn :: ([Tinkex.Types.Datum.t()], [Nx.Tensor.t()] ->
                {Nx.Tensor.t(), map()}),
  keyword()
) :: {:ok, Task.t()} | {:error, Tinkex.Error.t()}

Compute forward/backward pass with a custom loss function.

This mirrors the Python SDK: it performs a forward pass to obtain per-datum logprobs, computes a custom loss, converts the gradients into synthetic weights, and sends them back via forward_backward/4. The returned ForwardBackwardOutput is compatible with optim_step/2.

Parameters

  • client: TrainingClient pid
  • data: List of training data (Datum structs)
  • loss_fn: (data, logprobs_list) -> {loss_tensor, metrics_map}
    • logprobs_list is a list of Nx tensors, one per datum
  • opts: Options forwarded to the underlying forward/forward_backward requests

Returns

{:ok, Task.t()} that yields {:ok, ForwardBackwardOutput.t()} or {:error, Error.t()}

Examples

{:ok, task} = TrainingClient.forward_backward_custom(
  client, data, &my_loss_fn/2
)
{:ok, %ForwardBackwardOutput{} = output} = Task.await(task)
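For reference, a loss function matching the documented (data, logprobs_list) -> {loss_tensor, metrics_map} shape might look like the sketch below. It is an assumption-laden illustration: the module CustomLossSketch, the mean negative log-likelihood objective, and the "num_data" metric key are all hypothetical, and the value types Tinkex accepts in the metrics map are not specified here. It requires the :nx dependency.

```elixir
defmodule CustomLossSketch do
  @moduledoc "Hypothetical loss function, not part of Tinkex; requires :nx."

  # Mean negative log-likelihood over the batch: sum each datum's
  # logprobs, negate the sum, then average across data.
  # `_data` is unused here but is part of the documented callback shape.
  def mean_nll(_data, logprobs_list) do
    per_datum =
      Enum.map(logprobs_list, fn logprobs -> Nx.negate(Nx.sum(logprobs)) end)

    loss = per_datum |> Nx.stack() |> Nx.mean()

    # The metric key and value below are illustrative only.
    {loss, %{"num_data" => length(logprobs_list)}}
  end
end
```

It would be passed as a capture, for example &CustomLossSketch.mean_nll/2, in place of &my_loss_fn/2 above.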

get_info(client)

@spec get_info(t()) ::
  {:ok, Tinkex.Types.GetInfoResponse.t()} | {:error, Tinkex.Error.t()}

Fetch model metadata for the training client.

Used by tokenizer resolution to obtain model_data.tokenizer_id.

get_telemetry(client)

get_tokenizer(client, opts \\ [])

@spec get_tokenizer(
  t(),
  keyword()
) :: {:ok, Tinkex.Tokenizer.handle()} | {:error, Tinkex.Error.t()}

Get a tokenizer for this training client's model.

Fetches model info to determine the tokenizer ID, applies heuristics (e.g., Llama-3 gating workaround), and loads/caches the tokenizer.

Options

  • :load_fun - Custom tokenizer loader function (default: HuggingFace)
  • :info_fun - Custom info fetcher for testing

Examples

{:ok, _tokenizer} = TrainingClient.get_tokenizer(client)
{:ok, ids} = TrainingClient.encode(client, "Hello world")

Errors

Returns {:error, %Tinkex.Error{}} if:

  • Model info cannot be fetched
  • Tokenizer cannot be loaded

load_state(client, path, opts \\ [])

@spec load_state(t(), String.t(), keyword()) ::
  {:ok, Task.t()} | {:error, Tinkex.Error.t()}

Load model weights from a checkpoint (without optimizer state).

Returns a Task.t() that yields {:ok, %LoadWeightsResponse{}} or {:error, %Tinkex.Error{}}.

load_state_with_optimizer(client, path, opts \\ [])

@spec load_state_with_optimizer(t(), String.t(), keyword()) ::
  {:ok, Task.t()} | {:error, Tinkex.Error.t()}

Load model weights and optimizer state from a checkpoint.

Returns a Task.t() that yields {:ok, %LoadWeightsResponse{}} or {:error, %Tinkex.Error{}}.

optim_step(client, adam_params, opts \\ [])

@spec optim_step(t(), map(), keyword()) ::
  {:ok, Task.t()} | {:error, Tinkex.Error.t()}

Perform an optimizer step.

Returns a Task.t() that yields {:ok, %OptimStepResponse{}} or {:error, %Tinkex.Error{}}.

save_state(client, name, opts \\ [])

@spec save_state(t(), String.t(), keyword()) ::
  {:ok, Task.t()} | {:error, Tinkex.Error.t()}

Save model weights as a training checkpoint.

Returns a Task.t() that yields {:ok, %SaveWeightsResponse{}} or {:error, %Tinkex.Error{}}.

save_weights_and_get_sampling_client(client, opts \\ [])

@spec save_weights_and_get_sampling_client(
  t(),
  keyword()
) :: {:ok, Task.t()} | {:error, Tinkex.Error.t()}

Save weights for sampling and immediately create a SamplingClient.

Supports both persisted sampler checkpoints (path-based) and ephemeral sampling sessions (sampling_session_id-only responses).

save_weights_and_get_sampling_client_sync(client, opts \\ [])

@spec save_weights_and_get_sampling_client_sync(
  t(),
  keyword()
) :: {:ok, pid()} | {:error, Tinkex.Error.t()}

Synchronous helper for save_weights_and_get_sampling_client/2.

Waits for the sampler save and SamplingClient creation to complete, then returns {:ok, pid} directly instead of a Task.

save_weights_for_sampler(client, name, opts \\ [])

@spec save_weights_for_sampler(t(), String.t(), keyword()) ::
  {:ok, Task.t()} | {:error, Tinkex.Error.t()}

Save weights for downstream sampling.

Parameters

  • client - The TrainingClient pid
  • name - Name/path for the saved sampler weights (required)
  • opts - Additional options

Returns a Task.t() that yields {:ok, %SaveWeightsForSamplerResponse{}} or {:error, %Tinkex.Error{}}.

start_link(opts \\ [])

@spec start_link(keyword()) :: GenServer.on_start()

unload_model(client)

@spec unload_model(t()) ::
  {:ok, Tinkex.Types.UnloadModelResponse.t() | map()}
  | {:error, Tinkex.Error.t()}

Unload the active model and end the session.