Tinkex.TrainingClient (Tinkex v0.3.4)
GenServer that coordinates training operations for a single model.
Requests are sent sequentially within the GenServer while polling is performed concurrently in background tasks. This keeps request ordering deterministic at the cost of blocking the GenServer during the send phase.
Use Tinkex.Types.ModelInput.from_text/2 to turn raw strings into
tokenized ModelInput structs before constructing training data. Chat
templates are not applied automatically; provide fully formatted text.
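A minimal sketch of preparing input this way. The second argument to from_text/2 is assumed here to be a tokenizer handle obtained from get_tokenizer/2; check the ModelInput docs for the exact signature.

```elixir
# Resolve the tokenizer for this client's model, then tokenize
# fully formatted text (no chat template is applied for you).
{:ok, tokenizer} = Tinkex.TrainingClient.get_tokenizer(client)
{:ok, model_input} = Tinkex.Types.ModelInput.from_text("Hello world", tokenizer)
```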
Queue State Observer
This client implements Tinkex.QueueStateObserver and automatically logs
human-readable warnings when queue state changes indicate rate limiting
or capacity issues:
[warning] Training is paused for model-xyz. Reason: concurrent training clients rate limit hit
Logs are debounced to once per 60 seconds per model to avoid spam.
Summary
Functions
Returns a specification to start this module under a supervisor.
Create a sampling client from this training client asynchronously.
Decode token IDs using this training client's tokenizer.
Encode text using this training client's tokenizer.
Run a forward-only pass over the provided data (inference without backward).
Run a forward-backward pass over the provided data.
Compute forward/backward pass with a custom loss function.
Fetch model metadata for the training client.
Get a tokenizer for this training client's model.
Load model weights from a checkpoint (without optimizer state).
Load model weights and optimizer state from a checkpoint.
Perform an optimizer step.
Save model weights as a training checkpoint.
Save weights for sampling and immediately create a SamplingClient.
Synchronous helper for save_weights_and_get_sampling_client/2.
Save weights for downstream sampling.
Unload the active model and end the session.
Types
@type t() :: pid()
Functions
Returns a specification to start this module under a supervisor.
See Supervisor.
Create a sampling client from this training client asynchronously.
Takes a model_path (checkpoint path) and returns a Task that resolves to a sampling client.
Examples
task = TrainingClient.create_sampling_client_async(training_pid, "tinker://run-1/weights/0001")
{:ok, sampling_pid} = Task.await(task)
@spec decode(t(), [integer()], keyword()) :: {:ok, String.t()} | {:error, Tinkex.Error.t()}
Decode token IDs using this training client's tokenizer.
Convenience wrapper around Tinkex.Tokenizer.decode/3 that automatically
resolves the tokenizer from the training client's model info.
Examples
{:ok, text} = TrainingClient.decode(client, [1, 2, 3])
Options
:load_fun - Custom tokenizer loader function
:info_fun - Custom info fetcher for testing
@spec encode(t(), String.t(), keyword()) :: {:ok, [integer()]} | {:error, Tinkex.Error.t()}
Encode text using this training client's tokenizer.
Convenience wrapper around Tinkex.Tokenizer.encode/3 that automatically
resolves the tokenizer from the training client's model info.
Examples
{:ok, ids} = TrainingClient.encode(client, "Hello world")
Options
:load_fun - Custom tokenizer loader function
:info_fun - Custom info fetcher for testing
@spec forward(t(), [map()], atom() | String.t(), keyword()) :: {:ok, Task.t()} | {:error, Tinkex.Error.t()}
Run a forward-only pass over the provided data (inference without backward).
Returns logprobs that can be converted to Nx tensors via TensorData.to_nx/1.
Useful for custom loss computation where gradients are computed in Elixir/Nx.
Returns a Task.t() that yields {:ok, %ForwardBackwardOutput{}} or
{:error, %Tinkex.Error{}}.
Examples
{:ok, task} = TrainingClient.forward(client, data, :cross_entropy)
{:ok, output} = Task.await(task)
# Access logprobs from output.loss_fn_outputs
[%{"logprobs" => logprobs_data}] = output.loss_fn_outputs
tensor = TensorData.to_nx(logprobs_data)
@spec forward_backward(t(), [map()], atom() | String.t(), keyword()) :: {:ok, Task.t()} | {:error, Tinkex.Error.t()}
Run a forward-backward pass over the provided data.
Returns a Task.t() that yields {:ok, %ForwardBackwardOutput{}} or
{:error, %Tinkex.Error{}}.
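Examples
A sketch of a typical call, reusing the :cross_entropy loss shown for forward/4:

```elixir
# Send a forward-backward request; the Task resolves once the
# background polling observes a result.
{:ok, task} = TrainingClient.forward_backward(client, data, :cross_entropy)
{:ok, %ForwardBackwardOutput{} = output} = Task.await(task)
```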
@spec forward_backward_custom( t(), [Tinkex.Types.Datum.t()], loss_fn :: ([Tinkex.Types.Datum.t()], [Nx.Tensor.t()] -> {Nx.Tensor.t(), map()}), keyword() ) :: {:ok, Task.t()} | {:error, Tinkex.Error.t()}
Compute forward/backward pass with a custom loss function.
This mirrors the Python SDK: performs a forward pass to obtain per-datum
logprobs, computes a custom loss, turns gradients into synthetic weights,
and sends them back via forward_backward/4. The returned
ForwardBackwardOutput is compatible with optim_step/2.
Parameters
- client: TrainingClient pid
- data: List of training data (Datum structs)
- loss_fn: (data, logprobs_list) -> {loss_tensor, metrics_map}; logprobs_list is a list of Nx tensors, one per datum
- opts: Options forwarded to the underlying forward/forward_backward requests
Returns
{:ok, Task.t()} that yields {:ok, ForwardBackwardOutput.t()} or {:error, Error.t()}
Examples
{:ok, task} = TrainingClient.forward_backward_custom(
client, data, &my_loss_fn/2
)
{:ok, %ForwardBackwardOutput{} = output} = Task.await(task)
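A hypothetical loss function matching the documented signature. Negative mean logprob is used purely for illustration; substitute your own objective.

```elixir
# Takes the datum list and one logprob tensor per datum.
# Returns {scalar loss tensor, metrics map}.
def my_loss_fn(_data, logprobs_list) do
  loss =
    logprobs_list
    |> Enum.map(&Nx.mean/1)
    |> Nx.stack()
    |> Nx.mean()
    |> Nx.negate()

  {loss, %{"mean_nll" => Nx.to_number(loss)}}
end
```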
@spec get_info(t()) :: {:ok, Tinkex.Types.GetInfoResponse.t()} | {:error, Tinkex.Error.t()}
Fetch model metadata for the training client.
Used by tokenizer resolution to obtain model_data.tokenizer_id.
@spec get_tokenizer( t(), keyword() ) :: {:ok, Tinkex.Tokenizer.handle()} | {:error, Tinkex.Error.t()}
Get a tokenizer for this training client's model.
Fetches model info to determine the tokenizer ID, applies heuristics (e.g., Llama-3 gating workaround), and loads/caches the tokenizer.
Options
:load_fun - Custom tokenizer loader function (default: HuggingFace)
:info_fun - Custom info fetcher for testing
Examples
{:ok, _tokenizer} = TrainingClient.get_tokenizer(client)
{:ok, ids} = TrainingClient.encode(client, "Hello world")
Errors
Returns {:error, %Tinkex.Error{}} if:
- Model info cannot be fetched
- Tokenizer cannot be loaded
@spec load_state(t(), String.t(), keyword()) :: {:ok, Task.t()} | {:error, Tinkex.Error.t()}
Load model weights from a checkpoint (without optimizer state).
Returns a Task.t() that yields {:ok, %LoadWeightsResponse{}} or
{:error, %Tinkex.Error{}}.
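Examples
A sketch of restoring weights from an earlier checkpoint, using the checkpoint path format shown for create_sampling_client_async/2:

```elixir
# Load weights only; optimizer state is left untouched.
{:ok, task} = TrainingClient.load_state(client, "tinker://run-1/weights/0001")
{:ok, %LoadWeightsResponse{}} = Task.await(task)
```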
@spec load_state_with_optimizer(t(), String.t(), keyword()) :: {:ok, Task.t()} | {:error, Tinkex.Error.t()}
Load model weights and optimizer state from a checkpoint.
Returns a Task.t() that yields {:ok, %LoadWeightsResponse{}} or
{:error, %Tinkex.Error{}}.
@spec optim_step(t(), map(), keyword()) :: {:ok, Task.t()} | {:error, Tinkex.Error.t()}
Perform an optimizer step.
Returns a Task.t() that yields {:ok, %OptimStepResponse{}} or
{:error, %Tinkex.Error{}}.
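Examples
A minimal sketch of stepping the optimizer after a forward-backward pass. The :learning_rate key in the hyperparameter map is an assumption for illustration; consult the optimizer types for the exact keys.

```elixir
# Apply accumulated gradients with the given hyperparameters.
{:ok, task} = TrainingClient.optim_step(client, %{learning_rate: 1.0e-5})
{:ok, %OptimStepResponse{}} = Task.await(task)
```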
@spec save_state(t(), String.t(), keyword()) :: {:ok, Task.t()} | {:error, Tinkex.Error.t()}
Save model weights as a training checkpoint.
Returns a Task.t() that yields {:ok, %SaveWeightsResponse{}} or
{:error, %Tinkex.Error{}}.
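Examples
A sketch of saving a training checkpoint; the checkpoint name below is a hypothetical example.

```elixir
# Persist current model weights as a training checkpoint.
{:ok, task} = TrainingClient.save_state(client, "checkpoint-0001")
{:ok, %SaveWeightsResponse{}} = Task.await(task)
```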
@spec save_weights_and_get_sampling_client( t(), keyword() ) :: {:ok, Task.t()} | {:error, Tinkex.Error.t()}
Save weights for sampling and immediately create a SamplingClient.
Supports both persisted sampler checkpoints (path-based) and ephemeral sampling sessions (sampling_session_id-only responses).
@spec save_weights_and_get_sampling_client_sync( t(), keyword() ) :: {:ok, pid()} | {:error, Tinkex.Error.t()}
Synchronous helper for save_weights_and_get_sampling_client/2.
Waits for sampler save + SamplingClient creation and returns the pid directly.
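Examples
A sketch of the synchronous flow, which blocks until the sampler weights are saved and the SamplingClient is started:

```elixir
{:ok, sampling_pid} =
  TrainingClient.save_weights_and_get_sampling_client_sync(client)
```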
@spec save_weights_for_sampler(t(), String.t(), keyword()) :: {:ok, Task.t()} | {:error, Tinkex.Error.t()}
Save weights for downstream sampling.
Parameters
client - The TrainingClient pid
name - Name/path for the saved sampler weights (required)
opts - Additional options
Returns a Task.t() that yields {:ok, %SaveWeightsForSamplerResponse{}} or
{:error, %Tinkex.Error{}}.
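Examples
A sketch of saving weights for downstream sampling; the name below is a hypothetical example.

```elixir
# Save sampler-format weights under the given name.
{:ok, task} = TrainingClient.save_weights_for_sampler(client, "sampler-0001")
{:ok, %SaveWeightsForSamplerResponse{}} = Task.await(task)
```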
@spec start_link(keyword()) :: GenServer.on_start()
@spec unload_model(t()) :: {:ok, Tinkex.Types.UnloadModelResponse.t() | map()} | {:error, Tinkex.Error.t()}
Unload the active model and end the session.