Tinkex.TrainingClient.Operations (Tinkex v0.3.4)
Request building and sending operations for TrainingClient.
This module handles:
- Building and sending forward/backward requests
- Building and sending optimizer step requests
- Building and sending save/load weights requests
- Creating sampling clients from saved weights
- Handling response parsing and future extraction
Summary
Functions
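build_linear_loss_data_safe/2: Build linear loss data safely with validation.
ensure_model/6: Ensure a model exists, creating one if necessary.
handle_load_state_response/3: Handle load state response, polling if necessary.
handle_save_state_response/3: Handle save state response, polling if necessary.
handle_save_weights_response/3: Handle save weights for sampler response, polling if necessary.
merge_custom_metrics/2: Merge custom loss metrics into ForwardBackwardOutput.
Normalize save weights options, generating sampling_session_seq_id if needed.
poll_forward_custom_loss/3: Poll forward futures for custom loss computation.
send_backward_for_custom_loss/4: Send backward pass for custom loss computation.
send_forward_backward_request/5: Send a forward-backward request and return a future.
send_forward_request/5: Send a forward-only request and return a future.
send_load_state_request/5: Send a load state (weights) request.
send_optim_step_request/4: Send an optimizer step request and return a future.
send_save_state_request/4: Send a save state (weights) request.
send_save_weights_for_sampler_request/3: Send a save weights for sampler request.
start_sampling_client_from_save/4: Start a sampling client from save response data.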
Functions
@spec build_linear_loss_data_safe(list(), [Nx.Tensor.t()]) :: {:ok, list()} | {:error, Tinkex.Error.t()}
Build linear loss data safely with validation.
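One way to read "linear loss data": a surrogate loss that is linear in the model's log-probabilities, L = Σ g_i · logprob_i, has gradient exactly g_i with respect to each log-probability, so attaching a locally computed gradient to each datum lets a backward pass apply an arbitrary custom loss. The minimal sketch below assumes that reading and shows only the "safe" part: checking that every datum has a matching, well-formed gradient before pairing them. The module name LinearLossSketch, the error tuples, and the use of plain float lists in place of Nx tensors are hypothetical and not Tinkex's implementation.

```elixir
defmodule LinearLossSketch do
  # Hypothetical sketch: plain lists of floats stand in for Nx tensors so the
  # example has no dependencies.
  def build_safe(data, gradients) when is_list(data) and is_list(gradients) do
    cond do
      length(data) != length(gradients) ->
        {:error, {:length_mismatch, length(data), length(gradients)}}

      not Enum.all?(gradients, &numeric_list?/1) ->
        {:error, :invalid_gradient}

      true ->
        # Attach each gradient to its datum as the weights of a loss that is
        # linear in the log-probabilities (assumed shape, not Tinkex's).
        linear =
          data
          |> Enum.zip(gradients)
          |> Enum.map(fn {datum, grad} -> %{datum: datum, weights: grad} end)

        {:ok, linear}
    end
  end

  def build_safe(_data, _gradients), do: {:error, :not_a_list}

  defp numeric_list?(xs), do: is_list(xs) and Enum.all?(xs, &is_number/1)
end
```

Validating up front means a shape mismatch surfaces as an {:error, _} tuple on the client instead of a failed remote request.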
@spec ensure_model(keyword(), String.t(), integer(), map(), module(), map()) :: {:ok, String.t()} | {:error, Tinkex.Error.t()}
Ensure a model exists, creating one if necessary.
Returns {:ok, model_id} on success.
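The "create only if necessary" behaviour can be sketched as a reuse-or-create branch. This is an illustration under assumptions: the :model_id option key and the zero-arity create_fun callback are hypothetical, and the real function takes six arguments whose roles are not documented here.

```elixir
defmodule EnsureModelSketch do
  # Hypothetical: reuse a model id already present in opts; otherwise call
  # create_fun, which is expected to return {:ok, id} or {:error, reason}.
  def ensure_model(opts, create_fun) when is_list(opts) and is_function(create_fun, 0) do
    case Keyword.get(opts, :model_id) do
      id when is_binary(id) -> {:ok, id}
      _missing -> create_fun.()
    end
  end
end
```

For example, EnsureModelSketch.ensure_model([], fn -> {:ok, "new"} end) returns {:ok, "new"}, while an existing id short-circuits without calling the creator.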
@spec handle_load_state_response(map() | Tinkex.Types.LoadWeightsResponse.t(), map(), keyword()) :: {:ok, Tinkex.Types.LoadWeightsResponse.t() | map()} | {:error, Tinkex.Error.t()}
Handle load state response, polling if necessary.
@spec handle_save_state_response(map() | Tinkex.Types.SaveWeightsResponse.t(), map(), keyword()) :: {:ok, Tinkex.Types.SaveWeightsResponse.t() | map()} | {:error, Tinkex.Error.t()}
Handle save state response, polling if necessary.
@spec handle_save_weights_response(map() | Tinkex.Types.SaveWeightsForSamplerResponse.t(), map(), keyword()) :: {:ok, Tinkex.Types.SaveWeightsForSamplerResponse.t() | map()} | {:error, Tinkex.Error.t()}
Handle save weights for sampler response, polling if necessary.
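The three handle_*_response functions share a "polling if necessary" pattern: a response is either already final or a pending handle that must be polled. The sketch below assumes a pending response is a map carrying a "request_id" key and that a caller-supplied poll_fun returns :pending, {:done, result}, or {:error, reason}; these shapes, PollSketch, and the :max_attempts and :interval_ms options are all hypothetical.

```elixir
defmodule PollSketch do
  # Hypothetical response shapes:
  #   %{"request_id" => id}  -> still pending; poll until done
  #   anything else          -> already the final result
  def handle_response(response, poll_fun, opts \\ [])

  def handle_response(%{"request_id" => id}, poll_fun, opts) do
    attempts = Keyword.get(opts, :max_attempts, 10)
    interval = Keyword.get(opts, :interval_ms, 0)
    poll(id, poll_fun, attempts, interval)
  end

  def handle_response(result, _poll_fun, _opts), do: {:ok, result}

  # Give up once the attempt budget is exhausted.
  defp poll(_id, _poll_fun, 0, _interval), do: {:error, :timeout}

  defp poll(id, poll_fun, attempts, interval) do
    case poll_fun.(id) do
      {:done, result} ->
        {:ok, result}

      {:error, _reason} = err ->
        err

      :pending ->
        Process.sleep(interval)
        poll(id, poll_fun, attempts - 1, interval)
    end
  end
end
```

Bounding the number of attempts keeps a stuck request from blocking the caller forever; a production client would more likely use a deadline with backoff.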
@spec merge_custom_metrics(Tinkex.Types.ForwardBackwardOutput.t(), map()) :: Tinkex.Types.ForwardBackwardOutput.t()
Merge custom loss metrics into ForwardBackwardOutput.
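Merging custom metrics amounts to combining two maps on the output struct. The sketch below uses a hypothetical stand-in struct with a :metrics field rather than Tinkex.Types.ForwardBackwardOutput, and it assumes custom values win on key collisions; the real collision policy is not stated here.

```elixir
defmodule MergeMetricsSketch do
  defmodule Output do
    # Hypothetical stand-in for Tinkex.Types.ForwardBackwardOutput.
    defstruct loss_fn_outputs: [], metrics: %{}
  end

  # Custom metric values override existing ones on key collisions in this sketch.
  def merge_custom_metrics(%Output{} = output, custom) when is_map(custom) do
    %Output{output | metrics: Map.merge(output.metrics || %{}, custom)}
  end
end
```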
Normalize save weights options, generating sampling_session_seq_id if needed.
Returns {normalized_opts, new_counter}.
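Returning {normalized_opts, new_counter} suggests a pure function whose caller stores the updated counter. A minimal sketch, assuming a sequence id is generated only when the caller has not supplied one; SaveOptsSketch and the exact rule for "if needed" are hypothetical.

```elixir
defmodule SaveOptsSketch do
  # Hypothetical: assign the next sequence id only when opts lacks one, and
  # advance the counter only when an id was actually consumed.
  def normalize(opts, counter) when is_list(opts) and is_integer(counter) do
    if Keyword.has_key?(opts, :sampling_session_seq_id) do
      {opts, counter}
    else
      {Keyword.put(opts, :sampling_session_seq_id, counter), counter + 1}
    end
  end
end
```

Keeping the counter outside the function (for example in the TrainingClient's process state) avoids hidden mutable state and makes the id sequence easy to test.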
@spec poll_forward_custom_loss([map()], keyword(), map()) :: {:ok, [Tinkex.Types.ForwardBackwardOutput.t()]} | {:error, Tinkex.Error.t()}
Poll forward futures for custom loss computation.
@spec send_backward_for_custom_loss(list(), [integer()], keyword(), map()) :: {:ok, [Tinkex.Types.ForwardBackwardOutput.t()]} | {:error, Tinkex.Error.t()}
Send backward pass for custom loss computation.
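Taken together, poll_forward_custom_loss and send_backward_for_custom_loss suggest a three-step custom-loss flow: run a forward pass and wait for its results, compute the loss and its gradient locally, then send a backward pass driven by that gradient. The sketch below wires those steps with caller-supplied functions; CustomLossFlowSketch and the shapes of forward_fun, loss_fun, and backward_fun are hypothetical placeholders, not Tinkex calls.

```elixir
defmodule CustomLossFlowSketch do
  # Hypothetical orchestration of one custom-loss step:
  #   forward_fun.(data)         -> {:ok, logprobs}
  #   loss_fun.(logprobs)        -> {:ok, loss, grads}
  #   backward_fun.(data, grads) -> {:ok, outputs}
  # Any {:error, reason} short-circuits and is returned unchanged.
  def run(data, forward_fun, loss_fun, backward_fun) do
    with {:ok, logprobs} <- forward_fun.(data),
         {:ok, loss, grads} <- loss_fun.(logprobs),
         {:ok, outputs} <- backward_fun.(data, grads) do
      {:ok, %{loss: loss, outputs: outputs}}
    end
  end
end
```

For instance, with a loss of -Σ logprob the gradient with respect to every log-probability is -1.0, which is what the backward step would receive.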
@spec send_forward_backward_request(list(), atom() | String.t(), integer(), keyword(), map()) :: {:ok, map()} | {:error, Tinkex.Error.t()}
Send a forward-backward request and return a future.
@spec send_forward_request(list(), atom() | String.t(), integer(), keyword(), map()) :: {:ok, map()} | {:error, Tinkex.Error.t()}
Send a forward-only request and return a future.
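"Return a future" means the send call hands back a handle immediately and the result is collected later, so several requests can be in flight at once. The sketch below models the future as an Elixir Task purely for illustration; the specs above return a map, and FutureSketch and work_fun are hypothetical.

```elixir
defmodule FutureSketch do
  # Hypothetical: "sending" starts the work and returns a handle right away.
  def send_request(payload, work_fun) when is_function(work_fun, 1) do
    {:ok, Task.async(fn -> work_fun.(payload) end)}
  end

  # Block until the result is ready or the timeout (in ms) elapses.
  def await(task, timeout \\ 5_000), do: Task.await(task, timeout)
end
```

Usage: send two requests first, then await both; the work overlaps instead of running back to back. Task.await must be called from the process that created the task.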
@spec send_load_state_request(String.t(), boolean(), integer(), keyword(), map()) :: {:ok, map() | Tinkex.Types.LoadWeightsResponse.t()} | {:error, Tinkex.Error.t()}
Send a load state (weights) request.
@spec send_optim_step_request(map(), integer(), keyword(), map()) :: {:ok, map()} | {:error, Tinkex.Error.t()}
Send an optimizer step request and return a future.
@spec send_save_state_request(String.t(), integer(), keyword(), map()) :: {:ok, map() | Tinkex.Types.SaveWeightsResponse.t()} | {:error, Tinkex.Error.t()}
Send a save state (weights) request.
@spec send_save_weights_for_sampler_request(integer(), keyword(), map()) :: {:ok, map() | Tinkex.Types.SaveWeightsForSamplerResponse.t()} | {:error, Tinkex.Error.t()}
Send a save weights for sampler request.
@spec start_sampling_client_from_save(Tinkex.Types.SaveWeightsForSamplerResponse.t() | map(), integer(), keyword(), map()) :: {:ok, pid()} | {:error, Tinkex.Error.t() | any()}
Start a sampling client from save response data.
Handles both path-based and sampling_session_id-based responses.
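Handling both response kinds boils down to deciding which identifier to start the sampler from. The sketch below shows only that dispatch step, assuming atom keys :path and :sampling_session_id and that a path takes precedence when both are present; SamplerStartSketch and the returned tuples are hypothetical, and actually starting the client process is omitted.

```elixir
defmodule SamplerStartSketch do
  # Accepts either a struct or a plain map, mirroring the spec's
  # SaveWeightsForSamplerResponse.t() | map() input.
  def sampler_source(response) when is_map(response) do
    resp = if is_struct(response), do: Map.from_struct(response), else: response

    case {Map.get(resp, :path), Map.get(resp, :sampling_session_id)} do
      # Path-based response (preferred in this sketch when both are set).
      {path, _} when is_binary(path) -> {:ok, {:model_path, path}}
      # Session-id-based response.
      {nil, id} when is_binary(id) -> {:ok, {:sampling_session_id, id}}
      _ -> {:error, :no_path_or_session_id}
    end
  end
end
```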