Forward Inference
This guide covers forward-only inference using the TrainingClient.forward/4 API. Forward inference runs a model's forward pass without computing gradients, returning logprobs that can be converted to Nx tensors for custom analysis and loss computation.
Overview
Forward inference differs from the full training loop in a key way:
- forward_backward/4: Computes both the forward pass (logits → loss) and the backward pass (gradients)
- forward/4: Computes only the forward pass, returning logprobs without gradients
The forward-only API is useful when you need model outputs but don't need the backend to compute gradients. You might compute custom losses in Elixir/Nx, perform model evaluation, or analyze token probabilities.
When to Use Forward Inference
Use forward/4 instead of forward_backward/4 when you need:
- Custom loss computation: Compute losses in Elixir/Nx where gradients will be calculated locally
- Model evaluation: Calculate perplexity, accuracy, or other metrics without training
- Token analysis: Analyze probability distributions over tokens
- Regularizer development: Build custom regularizers that need logprobs but compute their own gradients
- Inference-only workflows: Get model predictions without updating weights
- Performance profiling: Measure forward pass latency without backward overhead
For standard training with built-in loss functions, use forward_backward/4 as shown in the training loop guide.
Quick Start
{:ok, _} = Application.ensure_all_started(:tinkex)
config = Tinkex.Config.new(
api_key: System.fetch_env!("TINKER_API_KEY"),
base_url: System.get_env("TINKER_BASE_URL")
)
{:ok, service} = Tinkex.ServiceClient.start_link(config: config)
{:ok, training} = Tinkex.ServiceClient.create_lora_training_client(service, "meta-llama/Llama-3.1-8B",
lora_config: %Tinkex.Types.LoraConfig{rank: 16}
)
{:ok, model_input} = Tinkex.Types.ModelInput.from_text(
"The capital of France is",
model_name: "meta-llama/Llama-3.1-8B",
training_client: training
)
# Build datum with target tokens
tokens = model_input.chunks |> hd() |> Map.get(:tokens)
datum = Tinkex.Types.Datum.new(%{
model_input: model_input,
loss_fn_inputs: %{
target_tokens: %Tinkex.Types.TensorData{
data: tokens,
dtype: :int64,
shape: [length(tokens)]
},
weights: %Tinkex.Types.TensorData{
data: List.duplicate(1.0, length(tokens)),
dtype: :float32,
shape: [length(tokens)]
}
}
})
# Run forward pass (inference only, no backward)
{:ok, task} = Tinkex.TrainingClient.forward(training, [datum], :cross_entropy)
{:ok, output} = Task.await(task, 60_000)
IO.inspect(output.metrics, label: "metrics")
Setting Up TrainingClient for Forward Inference
The setup process is identical to a standard training workflow:
1. Create a Service Client
config = Tinkex.Config.new(
api_key: System.fetch_env!("TINKER_API_KEY")
)
{:ok, service} = Tinkex.ServiceClient.start_link(config: config)
2. Create a Training Client
{:ok, training} = Tinkex.ServiceClient.create_lora_training_client(service, "meta-llama/Llama-3.1-8B",
lora_config: %Tinkex.Types.LoraConfig{rank: 16}
)
Even though you're doing inference, you still use a TrainingClient because the forward pass operates on the same infrastructure as training.
3. Prepare Input Data
# Option 1: From text (automatic tokenization)
{:ok, model_input} = Tinkex.Types.ModelInput.from_text(
"Your prompt here",
model_name: "meta-llama/Llama-3.1-8B",
training_client: training
)
# Option 2: From token IDs directly
model_input = %Tinkex.Types.ModelInput{
chunks: [
%{
tokens: [1, 450, 6864, 315, 9822, 374],
chunk_index: 0
}
]
}
Using TrainingClient.forward/4
The forward/4 function signature is:
@spec forward(t(), [map()], atom() | String.t(), keyword()) ::
{:ok, Task.t()} | {:error, Error.t()}
Parameters
- client: The TrainingClient pid
- data: List of Datum structs containing model inputs and loss function inputs
- loss_fn: Loss function name (e.g., :cross_entropy); determines the logprobs format
- opts: Optional keyword list for configuration
Return Value
Always returns {:ok, task} where the task yields:
- {:ok, %ForwardBackwardOutput{}} on success
- {:error, %Tinkex.Error{}} on failure
Example Usage
# Build datum with input and target
datum = Tinkex.Types.Datum.new(%{
model_input: model_input,
loss_fn_inputs: %{
target_tokens: to_tensor(target_tokens, :int64),
weights: to_tensor(weights, :float32)
}
})
# Run forward pass
{:ok, task} = Tinkex.TrainingClient.forward(training, [datum], :cross_entropy)
# Await result
{:ok, output} = Task.await(task, 60_000)
Options
The opts keyword list supports:
- :loss_fn_config - Configuration for the loss function (map)
- :timeout - HTTP request timeout in milliseconds
- :http_timeout - Alias for :timeout
- :telemetry_metadata - Additional telemetry metadata (map)
- :await_timeout - Maximum time to wait for task completion
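For illustration, these options are passed as an ordinary Elixir keyword list; the values below are placeholders, not recommendations:

```elixir
# Illustrative option values; tune to your workload
opts = [
  loss_fn_config: %{},             # loss-function configuration (map)
  timeout: 120_000,                # HTTP request timeout (ms)
  telemetry_metadata: %{run: 1},   # additional telemetry metadata (map)
  await_timeout: 180_000           # max wait for task completion (ms)
]

Keyword.get(opts, :timeout)
# => 120_000
```

Because these are standard keyword lists, Keyword.merge/2 works for layering project defaults under per-call overrides.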
Understanding Logprobs Output
The forward pass returns a ForwardBackwardOutput struct containing logprobs in the loss_fn_outputs field.
Output Structure
%Tinkex.Types.ForwardBackwardOutput{
metrics: %{
"total_loss" => 2.456,
"mean_loss" => 2.456,
# ... other metrics
},
loss_fn_outputs: [
%{
"logprobs" => %{
"data" => [...], # Flat list of float32 values
"dtype" => "float32",
"shape" => [seq_len, vocab_size]
}
}
],
loss_fn_output_type: "cross_entropy"
}
Accessing Logprobs
{:ok, output} = Task.await(task, 60_000)
# Extract first output
[first_output | _] = output.loss_fn_outputs
# Get logprobs structure
%{"logprobs" => logprobs_data} = first_output
# logprobs_data contains:
# - data: flat list of log probabilities
# - dtype: tensor data type (usually "float32")
# - shape: [sequence_length, vocabulary_size]
Logprobs Format by Loss Function
Different loss functions return logprobs in different formats:
:cross_entropy
%{
"logprobs" => %{
"data" => [0.1, 0.2, ...], # Log probabilities
"dtype" => "float32",
"shape" => [seq_len, vocab_size]
}
}
The shape indicates [sequence_length, vocabulary_size], where:
- sequence_length: Number of tokens in the input
- vocabulary_size: Size of the model's token vocabulary (e.g., 128,256 for Llama-3.1)
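To make that layout concrete, here is a small self-contained sketch in plain Elixir (no Nx) that pulls the logprob assigned to each target token out of a flat data list, assuming the row-major ordering Nx uses when reshaping flat lists. The numbers are toy values, not real model output:

```elixir
# Toy logprobs: seq_len = 2, vocab_size = 3, stored row-major as in TensorData.data
data = [-0.1, -2.0, -3.0,   # position 0: logprobs over the 3-token vocab
        -1.5, -0.2, -2.5]   # position 1
vocab_size = 3
target_tokens = [0, 1]      # token id expected at each position

per_token_logprobs =
  target_tokens
  |> Enum.with_index()
  |> Enum.map(fn {token_id, pos} ->
    # Row-major indexing: element (pos, token_id) lives at pos * vocab_size + token_id
    Enum.at(data, pos * vocab_size + token_id)
  end)

# per_token_logprobs == [-0.1, -0.2]
```

With real output you would take logprobs_data["data"] and read vocab_size from logprobs_data["shape"] instead of the toy values; on Nx tensors, Nx.take_along_axis/3 does the same job without manual index math.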
Working with Nx Tensors
The logprobs can be converted to Nx tensors for numerical operations using TensorData.to_nx/1.
Converting to Nx
alias Tinkex.Types.TensorData
# Extract logprobs from output
[%{"logprobs" => logprobs}] = output.loss_fn_outputs
# Create TensorData struct
tensor_data = %TensorData{
data: logprobs["data"],
dtype: parse_dtype(logprobs["dtype"]),
shape: logprobs["shape"]
}
# Convert to Nx tensor
nx_tensor = TensorData.to_nx(tensor_data)
# Now you can use Nx operations
mean_logprob = Nx.mean(nx_tensor)
max_logprob = Nx.reduce_max(nx_tensor)
Data Type Conversion
The dtype field in the API response is a string. Convert it to an Nx-compatible atom:
defp parse_dtype("float32"), do: :float32
defp parse_dtype("float64"), do: :float64
defp parse_dtype("int64"), do: :int64
defp parse_dtype("int32"), do: :int32
defp parse_dtype(atom) when is_atom(atom), do: atom
defp parse_dtype(_), do: :float32 # fallback
Nx Operations on Logprobs
Once converted to an Nx tensor, you can perform various operations:
# Basic statistics
mean = Nx.mean(nx_tensor) |> Nx.to_number()
variance = Nx.variance(nx_tensor) |> Nx.to_number()
min_val = Nx.reduce_min(nx_tensor) |> Nx.to_number()
max_val = Nx.reduce_max(nx_tensor) |> Nx.to_number()
# Reshape for per-token analysis
# If shape is [seq_len, vocab_size]
{seq_len, vocab_size} = Nx.shape(nx_tensor)
# Get logprobs for the first position (a [1, vocab_size] slice)
first_position = Nx.slice_along_axis(nx_tensor, 0, 1, axis: 0)
# Softmax over the vocab axis to get a probability distribution per position
# (use Nx.divide/2 — Kernel operators like / only work on tensors inside defn)
exp = Nx.exp(nx_tensor)
probs = Nx.divide(exp, Nx.sum(exp, axes: [1], keep_axes: true))
# Find most likely token at each position
most_likely = Nx.argmax(nx_tensor, axis: 1)
Converting Between Tinkex Types and Nx
TensorData → Nx
# From Tinkex TensorData to Nx tensor
tensor_data = %Tinkex.Types.TensorData{
data: [1.0, 2.0, 3.0, 4.0],
dtype: :float32,
shape: [2, 2]
}
nx_tensor = Tinkex.Types.TensorData.to_nx(tensor_data)
# => #Nx.Tensor<
# f32[2][2]
# [
# [1.0, 2.0],
# [3.0, 4.0]
# ]
# >
Nx → TensorData
# From Nx tensor to Tinkex TensorData
nx_tensor = Nx.tensor([[1, 2], [3, 4]], type: :s64)
tensor_data = %Tinkex.Types.TensorData{
data: Nx.to_flat_list(nx_tensor),
dtype: nx_type_to_tinkex(Nx.type(nx_tensor)),
shape: Tuple.to_list(Nx.shape(nx_tensor))
}
defp nx_type_to_tinkex({:s, 64}), do: :int64
defp nx_type_to_tinkex({:s, 32}), do: :int32
defp nx_type_to_tinkex({:f, 32}), do: :float32
defp nx_type_to_tinkex({:f, 64}), do: :float64
Building Datum with Nx Tensors
# Helper to convert list to TensorData
defp to_tensor(data, dtype) when is_list(data) do
%Tinkex.Types.TensorData{
data: data,
dtype: dtype,
shape: [length(data)]
}
end
# Or from Nx tensor directly
defp nx_to_tensor_data(nx_tensor) do
%Tinkex.Types.TensorData{
data: Nx.to_flat_list(nx_tensor),
dtype: nx_type_to_tinkex(Nx.type(nx_tensor)),
shape: Tuple.to_list(Nx.shape(nx_tensor))
}
end
# Use in datum construction
target_tensor = Nx.tensor(target_tokens, type: :s64)
weights_tensor = Nx.broadcast(1.0, {length(target_tokens)})
datum = Tinkex.Types.Datum.new(%{
model_input: model_input,
loss_fn_inputs: %{
target_tokens: nx_to_tensor_data(target_tensor),
weights: nx_to_tensor_data(weights_tensor)
}
})
Use Cases
1. Model Evaluation
Calculate perplexity on a validation set without training:
defmodule ModelEvaluator do
def evaluate_perplexity(training_client, validation_data) do
results = Enum.map(validation_data, fn {text, _label} ->
{:ok, model_input} = Tinkex.Types.ModelInput.from_text(
text,
model_name: "meta-llama/Llama-3.1-8B",
training_client: training_client
)
tokens = get_tokens(model_input)
datum = build_datum(model_input, tokens)
{:ok, task} = Tinkex.TrainingClient.forward(
training_client,
[datum],
:cross_entropy
)
{:ok, output} = Task.await(task, 60_000)
output.metrics["mean_loss"]
end)
# Perplexity = exp(average loss)
avg_loss = Enum.sum(results) / length(results)
perplexity = :math.exp(avg_loss)
%{
perplexity: perplexity,
average_loss: avg_loss,
sample_count: length(results)
}
end
defp get_tokens(%{chunks: [chunk | _]}), do: Map.get(chunk, :tokens) || Map.get(chunk, "tokens")
defp build_datum(model_input, tokens) do
Tinkex.Types.Datum.new(%{
model_input: model_input,
loss_fn_inputs: %{
target_tokens: to_tensor(tokens, :int64),
weights: to_tensor(List.duplicate(1.0, length(tokens)), :float32)
}
})
end
defp to_tensor(data, dtype) do
%Tinkex.Types.TensorData{
data: data,
dtype: dtype,
shape: [length(data)]
}
end
end
# Usage
perplexity_report = ModelEvaluator.evaluate_perplexity(training, validation_set)
IO.inspect(perplexity_report)
# => %{perplexity: 12.34, average_loss: 2.513, sample_count: 100}
2. Token Probability Analysis
Analyze the probability distribution for specific tokens:
defmodule TokenAnalyzer do
alias Tinkex.Types.TensorData
def analyze_token_probabilities(training_client, text, _target_word) do
{:ok, model_input} = Tinkex.Types.ModelInput.from_text(
text,
model_name: "meta-llama/Llama-3.1-8B",
training_client: training_client
)
tokens = get_tokens(model_input)
datum = build_datum(model_input, tokens)
{:ok, task} = Tinkex.TrainingClient.forward(
training_client,
[datum],
:cross_entropy
)
{:ok, output} = Task.await(task, 60_000)
# Extract logprobs and convert to Nx
[%{"logprobs" => logprobs_data}] = output.loss_fn_outputs
tensor_data = %TensorData{
data: logprobs_data["data"],
dtype: :float32,
shape: logprobs_data["shape"]
}
logprobs = TensorData.to_nx(tensor_data)
# Get probability distribution for each position
# logprobs shape: [seq_len, vocab_size]
probs = Nx.exp(logprobs)
# Find most likely tokens at each position
most_likely_indices = Nx.argmax(logprobs, axis: 1)
%{
sequence_length: elem(Nx.shape(logprobs), 0),
vocab_size: elem(Nx.shape(logprobs), 1),
most_likely_tokens: Nx.to_flat_list(most_likely_indices),
average_confidence: Nx.mean(Nx.reduce_max(probs, axes: [1])) |> Nx.to_number()
}
end
defp get_tokens(%{chunks: [chunk | _]}), do: Map.get(chunk, :tokens) || Map.get(chunk, "tokens")
defp build_datum(model_input, tokens) do
Tinkex.Types.Datum.new(%{
model_input: model_input,
loss_fn_inputs: %{
target_tokens: %TensorData{data: tokens, dtype: :int64, shape: [length(tokens)]},
weights: %TensorData{data: List.duplicate(1.0, length(tokens)), dtype: :float32, shape: [length(tokens)]}
}
})
end
end
# Usage
analysis = TokenAnalyzer.analyze_token_probabilities(
training,
"The capital of France is Paris",
"Paris"
)
IO.inspect(analysis)
3. Custom Loss Computation
Compute a custom loss in Nx with your own gradient logic:
defmodule CustomLoss do
def compute_with_regularization(training_client, data, lambda \\ 0.01) do
# Get logprobs from forward pass
{:ok, task} = Tinkex.TrainingClient.forward(
training_client,
data,
:cross_entropy
)
{:ok, output} = Task.await(task, 60_000)
# Extract logprobs
[%{"logprobs" => logprobs_data}] = output.loss_fn_outputs
logprobs = Tinkex.Types.TensorData.to_nx(%Tinkex.Types.TensorData{
data: logprobs_data["data"],
dtype: :float32,
shape: logprobs_data["shape"]
})
# Compute base cross-entropy loss
base_loss = output.metrics["mean_loss"]
# Add custom L2 regularization in Nx
l2_penalty = Nx.multiply(lambda, Nx.divide(Nx.sum(Nx.pow(logprobs, 2)), Nx.size(logprobs)))
l2_value = Nx.to_number(l2_penalty)
total_loss = base_loss + l2_value
%{
base_loss: base_loss,
l2_penalty: l2_value,
total_loss: total_loss,
logprobs_shape: Nx.shape(logprobs)
}
end
end
# Usage
loss_report = CustomLoss.compute_with_regularization(training, [datum], 0.001)
IO.inspect(loss_report)
4. Perplexity Calculation
Calculate perplexity for language model evaluation:
defmodule Perplexity do
def calculate(training_client, text_samples) do
losses = Enum.map(text_samples, fn text ->
{:ok, model_input} = Tinkex.Types.ModelInput.from_text(
text,
model_name: "meta-llama/Llama-3.1-8B",
training_client: training_client
)
tokens = get_first_chunk_tokens(model_input)
datum = %Tinkex.Types.Datum{
model_input: model_input,
loss_fn_inputs: %{
target_tokens: to_tensor(tokens, :int64),
weights: to_tensor(List.duplicate(1.0, length(tokens)), :float32)
}
}
{:ok, task} = Tinkex.TrainingClient.forward(training_client, [datum], :cross_entropy)
{:ok, output} = Task.await(task, 60_000)
output.metrics["mean_loss"]
end)
avg_loss = Enum.sum(losses) / length(losses)
perplexity = :math.exp(avg_loss)
%{
perplexity: perplexity,
average_loss: avg_loss,
num_samples: length(text_samples)
}
end
defp get_first_chunk_tokens(%{chunks: [chunk | _]}) do
Map.get(chunk, :tokens) || Map.get(chunk, "tokens") || []
end
defp to_tensor(data, dtype) when is_list(data) do
%Tinkex.Types.TensorData{
data: data,
dtype: dtype,
shape: [length(data)]
}
end
end
# Usage
samples = [
"The quick brown fox jumps over the lazy dog",
"Machine learning models require large datasets",
"Natural language processing is fascinating"
]
result = Perplexity.calculate(training, samples)
IO.puts("Perplexity: #{result.perplexity}")
Performance Considerations
1. Batching and Chunking
The forward/4 function automatically chunks large batches to avoid overwhelming the backend:
- Max chunk size: 128 examples
- Max token count: 500,000 numbers per chunk
Large datasets are automatically split and processed sequentially:
# This gets chunked automatically
large_batch = Enum.map(1..1000, fn i ->
build_datum("Sample text #{i}")
end)
{:ok, task} = Tinkex.TrainingClient.forward(training, large_batch, :cross_entropy)
2. Async vs Sync
The forward/4 function returns a Task.t(), allowing async workflows:
# Start multiple forward passes in parallel
tasks = Enum.map(batches, fn batch ->
{:ok, task} = Tinkex.TrainingClient.forward(training, batch, :cross_entropy)
task
end)
# Await all results
results = Enum.map(tasks, &Task.await(&1, 60_000))
3. Memory Management
Logprobs tensors can be large (sequence_length × vocab_size):
- Llama-3.1-8B vocabulary size: 128,256
- Sequence length 512: ~512 × 128,256 × 4 bytes = ~250 MB per forward pass
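The ~250 MB figure above is simple arithmetic; a quick sanity check (4 bytes per element is the float32 assumption):

```elixir
# float32 logprobs tensor: seq_len * vocab_size elements, 4 bytes each
seq_len = 512
vocab_size = 128_256
bytes_per_elem = 4
total_mb = seq_len * vocab_size * bytes_per_elem / 1_048_576
# ≈ 250.5 MB
```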
Consider:
# Process in smaller batches
batches = Enum.chunk_every(large_dataset, 10)
results = Enum.map(batches, fn batch ->
{:ok, task} = Tinkex.TrainingClient.forward(training, batch, :cross_entropy)
{:ok, output} = Task.await(task, 60_000)
# Extract only what you need
metrics = output.metrics
# Let logprobs be garbage collected
metrics
end)
4. Timeout Configuration
Adjust timeouts based on data size:
# For large sequences or batches
{:ok, task} = Tinkex.TrainingClient.forward(
training,
large_batch,
:cross_entropy,
timeout: 120_000, # 2 minutes for HTTP request
await_timeout: 180_000 # 3 minutes for task completion
)
{:ok, output} = Task.await(task, 180_000)
5. EXLA Backend
Nx operations on logprobs can leverage EXLA for GPU acceleration:
# Set EXLA as default backend
Nx.default_backend(EXLA.Backend)
# Now all Nx operations use GPU if available
logprobs_tensor = TensorData.to_nx(tensor_data)
mean = Nx.mean(logprobs_tensor) # Runs on GPU
Complete Example
Here's a complete example that demonstrates forward inference for model evaluation:
defmodule ForwardInferenceExample do
alias Tinkex.Types.{TensorData, Datum, ModelInput}
def run do
# 1. Setup
{:ok, _} = Application.ensure_all_started(:tinkex)
config = Tinkex.Config.new(
api_key: System.fetch_env!("TINKER_API_KEY")
)
{:ok, service} = Tinkex.ServiceClient.start_link(config: config)
{:ok, training} = Tinkex.ServiceClient.create_lora_training_client(
service,
"meta-llama/Llama-3.1-8B",
lora_config: %Tinkex.Types.LoraConfig{rank: 16}
)
# 2. Prepare test data
test_prompts = [
"The capital of France is",
"Machine learning is",
"The quick brown fox"
]
# 3. Run forward inference on each
results = Enum.map(test_prompts, fn prompt ->
analyze_prompt(training, prompt)
end)
# 4. Display results
Enum.each(Enum.zip(test_prompts, results), fn {prompt, result} ->
IO.puts("\nPrompt: #{prompt}")
IO.puts("Loss: #{Float.round(result.loss, 4)}")
IO.puts("Perplexity: #{Float.round(result.perplexity, 2)}")
IO.puts("Tokens: #{result.token_count}")
end)
# 5. Calculate overall statistics
avg_loss = Enum.sum(Enum.map(results, & &1.loss)) / length(results)
avg_perplexity = :math.exp(avg_loss)
IO.puts("\n=== Overall Statistics ===")
IO.puts("Average Loss: #{Float.round(avg_loss, 4)}")
IO.puts("Average Perplexity: #{Float.round(avg_perplexity, 2)}")
end
defp analyze_prompt(training_client, prompt) do
# Tokenize
{:ok, model_input} = ModelInput.from_text(
prompt,
model_name: "meta-llama/Llama-3.1-8B",
training_client: training_client
)
# Get tokens
tokens = get_first_chunk_tokens(model_input)
token_count = length(tokens)
# Build datum
datum = Datum.new(%{
model_input: model_input,
loss_fn_inputs: %{
target_tokens: to_tensor(tokens, :int64),
weights: to_tensor(List.duplicate(1.0, token_count), :float32)
}
})
# Forward pass
{:ok, task} = Tinkex.TrainingClient.forward(
training_client,
[datum],
:cross_entropy
)
{:ok, output} = Task.await(task, 60_000)
# Extract metrics
loss = output.metrics["mean_loss"]
perplexity = :math.exp(loss)
# Optionally analyze logprobs
logprobs_stats = analyze_logprobs(output.loss_fn_outputs)
%{
loss: loss,
perplexity: perplexity,
token_count: token_count,
logprobs_stats: logprobs_stats
}
end
defp analyze_logprobs([%{"logprobs" => logprobs_data}]) do
tensor_data = %TensorData{
data: logprobs_data["data"],
dtype: :float32,
shape: logprobs_data["shape"]
}
tensor = TensorData.to_nx(tensor_data)
%{
shape: Nx.shape(tensor),
mean: Nx.mean(tensor) |> Nx.to_number(),
min: Nx.reduce_min(tensor) |> Nx.to_number(),
max: Nx.reduce_max(tensor) |> Nx.to_number()
}
end
defp analyze_logprobs(_), do: %{}
defp get_first_chunk_tokens(%{chunks: [chunk | _]}) do
Map.get(chunk, :tokens) || Map.get(chunk, "tokens") || []
end
defp to_tensor(data, dtype) when is_list(data) do
%TensorData{data: data, dtype: dtype, shape: [length(data)]}
end
end
# Run the example
ForwardInferenceExample.run()
Troubleshooting
Issue: Logprobs not in expected format
Problem: loss_fn_outputs doesn't contain logprobs structure
Solution: Ensure you're using the correct loss function. Cross-entropy returns logprobs:
# Correct
{:ok, task} = Tinkex.TrainingClient.forward(training, data, :cross_entropy)
# Some loss functions may return different output structures
Issue: Nx tensor shape mismatch
Problem: Error when converting TensorData to Nx
Solution: Verify the shape matches the data length:
# Check data consistency
data_len = length(logprobs_data["data"])
shape = logprobs_data["shape"]
expected_len = Enum.reduce(shape, 1, &*/2)
if data_len != expected_len do
IO.puts("Warning: data length #{data_len} != expected #{expected_len}")
end
Issue: Out of memory errors
Problem: Large logprobs tensors consume too much memory
Solution: Process in smaller batches and extract only needed values:
# Instead of keeping full tensors
results = Enum.map(large_dataset, fn datum ->
{:ok, task} = Tinkex.TrainingClient.forward(training, [datum], :cross_entropy)
{:ok, output} = Task.await(task, 60_000)
# Extract only the metric, discard logprobs
output.metrics["mean_loss"]
end)
Issue: Timeout during forward pass
Problem: Task times out on large batches
Solution: Increase timeout or reduce batch size:
# Increase timeout
{:ok, task} = Tinkex.TrainingClient.forward(
training,
data,
:cross_entropy,
timeout: 300_000
)
{:ok, output} = Task.await(task, 300_000)
# Or reduce batch size
smaller_batches = Enum.chunk_every(data, 10)
See Also
- Training Loop Guide - Full forward-backward training
- Getting Started - Initial setup and configuration
- API Reference - Complete API documentation
- Troubleshooting - Common issues and solutions