Test-Time Compute — backbone + scoring network for inference-time scaling.
Implements the "best-of-N" approach to test-time compute scaling: a backbone generates hidden representations, while a parallel scorer network evaluates each position. At inference time, multiple completions are generated and the scorer selects the best one.
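The best-of-N procedure described above can be sketched in plain Elixir, independent of any neural network. The module name BestOfNSketch and the generate_fn / score_fn callbacks are hypothetical stand-ins for illustration; they are not part of TestTimeCompute.

```elixir
defmodule BestOfNSketch do
  # Hypothetical helper, not part of TestTimeCompute.
  # generate_fn: (prompt -> completion), score_fn: (completion -> number)
  def pick(prompt, n, generate_fn, score_fn) when is_integer(n) and n > 0 do
    1..n
    # Draw N candidate completions for the same prompt.
    |> Enum.map(fn _ -> generate_fn.(prompt) end)
    # Keep the candidate with the highest score.
    |> Enum.max_by(score_fn)
  end
end

# Toy stand-ins: "generation" appends a random-length suffix,
# "scoring" simply prefers longer completions.
generate = fn prompt -> prompt <> String.duplicate("!", :rand.uniform(5)) end
score = fn completion -> String.length(completion) end

best = BestOfNSketch.pick("hello", 8, generate, score)
```

In the real module the scoring step is a learned network rather than a hand-written function, but the control flow (sample N, score each, keep the maximum) is the same.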
Architecture
Input [batch, seq_len, embed_dim]
|
Backbone (GQA + RoPE + SwiGLU, all timesteps)
|
[batch, seq_len, hidden_size]
|
+-- Scorer: dense(scorer_hidden) -> silu -> dense(1)
|
Axon.container(%{backbone: [batch, seq_len, hidden_size],
                 scores: [batch, seq_len, 1]})
Static Utility
select_best_of_n/2: given scores of shape [N, batch], returns the argmax over
the N candidates for each batch element, i.e. the index of the highest-scoring completion.
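The scorer emits per-position scores of shape [batch, seq_len, 1], while select_best_of_n/2 expects [N, batch]. The source does not say how the two are bridged, so the following is only one possible sketch: it assumes the score at the last position summarizes the whole sequence, and the ScorePooling module is a hypothetical helper.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule ScorePooling do
  # Hypothetical helper, not part of TestTimeCompute.
  # Reduce per-position scores [batch, seq_len, 1] to one scalar per
  # sequence by taking the last position (an assumption) -> [batch].
  def last_position(scores), do: scores[[.., -1, 0]]

  # Pool each candidate's scores, then stack along a new leading
  # axis so the result has shape [N, batch].
  def stack_candidates(per_candidate_scores) do
    per_candidate_scores
    |> Enum.map(&last_position/1)
    |> Nx.stack()
  end
end

# Two candidates (N = 2), batch = 2, seq_len = 3.
cand_a = Nx.tensor([[[0.1], [0.2], [0.3]], [[0.9], [0.8], [0.7]]])
cand_b = Nx.tensor([[[0.5], [0.5], [0.6]], [[0.1], [0.1], [0.2]]])

stacked = ScorePooling.stack_candidates([cand_a, cand_b])
```

Mean pooling over positions (Nx.mean/2 with axes: [1, 2]) would be an equally plausible choice; which one matches the trained scorer depends on how it was supervised.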
Usage
model = TestTimeCompute.build(
  embed_dim: 256,
  hidden_size: 256,
  num_layers: 4,
  scorer_hidden: 128
)
# At inference: generate N completions, score each, and pick the best.
# `scores` is an Nx tensor of shape [N, batch], one scalar score per candidate.
best_indices = TestTimeCompute.select_best_of_n(scores)
References
- "Scaling LLM Test-Time Compute Optimally" (Snell et al., 2024)
Summary
Functions
build/1 - Build a Test-Time Compute model with backbone and scorer.
output_size/1 - Get the output size of the model (hidden_size of backbone).
recommended_defaults/0 - Get recommended defaults.
select_best_of_n/2 - Select the best completion from N candidates using scores.
Types
@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_kv_heads, pos_integer()}
  | {:scorer_hidden, pos_integer()}
  | {:dropout, float()}
Options for build/1.
Functions
build(opts)
Build a Test-Time Compute model with backbone and scorer.
Options
:embed_dim - Input embedding dimension (required)
:hidden_size - Backbone hidden dimension (default: 256)
:num_layers - Number of backbone transformer layers (default: 4)
:num_heads - Number of attention heads (default: 4)
:num_kv_heads - Number of key/value heads for GQA (default: 2)
:scorer_hidden - Hidden dimension for scorer MLP (default: 128)
:dropout - Dropout rate (default: 0.1)
Returns
An Axon.container with keys :backbone and :scores.
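To make the container output concrete, here is a minimal Axon sketch of the scorer head attached to a backbone. It is not the module's implementation: the real backbone (GQA + RoPE + SwiGLU) is replaced by a single dense layer as a stand-in, the layer names are invented, the dimensions are shrunk for speed, and the Axon version pin is an assumption.

```elixir
Mix.install([{:axon, "~> 0.6.0"}])

embed_dim = 8
hidden_size = 8
scorer_hidden = 4

input = Axon.input("embeddings", shape: {nil, nil, embed_dim})

# Stand-in for the real backbone: one dense layer applied per position.
backbone = Axon.dense(input, hidden_size, name: "backbone_stub")

# Scorer head as in the diagram: dense(scorer_hidden) -> silu -> dense(1).
scores =
  backbone
  |> Axon.dense(scorer_hidden, activation: :silu, name: "scorer_hidden")
  |> Axon.dense(1, name: "scorer_out")

# Both branches are exposed through one container output.
model = Axon.container(%{backbone: backbone, scores: scores})

{init_fn, predict_fn} = Axon.build(model)
template = %{"embeddings" => Nx.template({2, 5, embed_dim}, :f32)}
params = init_fn.(template, %{})

x = Nx.broadcast(0.5, {2, 5, embed_dim})
%{backbone: hidden, scores: per_position} =
  predict_fn.(params, %{"embeddings" => x})
```

Here hidden has shape {2, 5, 8} and per_position has shape {2, 5, 1}, matching the [batch, seq_len, hidden_size] and [batch, seq_len, 1] entries of the container.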
@spec output_size(keyword()) :: pos_integer()
Get the output size of the model (hidden_size of backbone).
@spec recommended_defaults() :: keyword()
Get recommended defaults.
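The source does not show what recommended_defaults/0 returns. Assuming it mirrors the defaults listed under the build/1 options, merging caller options over those defaults could look like the sketch below; DefaultsSketch and resolve/1 are hypothetical names.

```elixir
defmodule DefaultsSketch do
  # Assumed to mirror the documented build/1 defaults.
  # :embed_dim is required and therefore has no default.
  def recommended_defaults do
    [
      hidden_size: 256,
      num_layers: 4,
      num_heads: 4,
      num_kv_heads: 2,
      scorer_hidden: 128,
      dropout: 0.1
    ]
  end

  # Caller options win over defaults; raises KeyError if :embed_dim is missing.
  def resolve(opts) do
    _ = Keyword.fetch!(opts, :embed_dim)
    Keyword.merge(recommended_defaults(), opts)
  end

  # The model's output size is the backbone hidden size.
  def output_size(opts), do: Keyword.get(opts, :hidden_size, 256)
end

opts = DefaultsSketch.resolve(embed_dim: 256, num_layers: 6)
```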
@spec select_best_of_n(Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Select the best completion from N candidates using scores.
Given scores of shape [N, batch], returns the index of the
highest-scoring candidate per batch element.
Parameters
scores - Tensor of shape [N, batch] with scalar scores per candidate
opts - Options (unused, reserved for future use)
Returns
Tensor of shape [batch] with the argmax index per batch element.
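The selection itself amounts to an argmax along the candidate axis. A minimal sketch with plain Nx, assuming the function is equivalent to Nx.argmax/2 over axis 0 (the source does not show its body):

```elixir
Mix.install([{:nx, "~> 0.7"}])

# scores: [N, batch] with N = 3 candidates and batch = 2.
scores =
  Nx.tensor([
    [0.1, 0.9],
    [0.7, 0.2],
    [0.3, 0.5]
  ])

# Reduce over the candidate axis (axis 0), leaving one index per batch element.
best = Nx.argmax(scores, axis: 0)
```

For this input, candidate 1 wins for the first batch element and candidate 0 for the second, so best holds the indices 1 and 0.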