Edifice.Meta.TestTimeCompute (Edifice v0.2.0)


Test-Time Compute: backbone + scoring network for inference-time scaling.

Implements the "best-of-N" approach to test-time compute scaling: a backbone produces hidden representations, while a parallel scorer network maps the hidden state at each position to a scalar score. At inference time, multiple completions are generated, each one is scored, and the highest-scoring completion is selected with select_best_of_n/2.

Architecture

Input [batch, seq_len, embed_dim]
      |
Backbone (GQA + RoPE + SwiGLU, all timesteps)
      |
[batch, seq_len, hidden_size]
      |
+-- Scorer: dense(scorer_hidden) -> silu -> dense(1)
      |
Axon.container(%{backbone: [batch, seq_len, hidden_size],
                  scores: [batch, seq_len, 1]})
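The scorer branch in the diagram applies the same small MLP to the hidden vector at every position. The sketch below makes that computation concrete without Nx or Axon. It is a minimal illustration under stated assumptions: the module ScorerHeadSketch, its functions, the list-based [in][out] weight layout, and the toy weights are all hypothetical and not part of Edifice's API. The actual module builds its scorer inside the Axon graph returned by build/1, with learned parameters.

```elixir
defmodule ScorerHeadSketch do
  # Hypothetical, dependency-free sketch of a scorer head of the form
  # dense(scorer_hidden) -> silu -> dense(1). Vectors are lists of floats
  # and weight matrices are lists of rows shaped [in][out].

  # SiLU (swish): x * sigmoid(x) == x / (1 + e^-x).
  # Naive form: :math.exp/1 raises for very negative x (overflow).
  def silu(x), do: x / (1.0 + :math.exp(-x))

  # Dense layer: y_j = b_j + sum_i x_i * w_ij.
  def dense(x, w, b) do
    x
    |> Enum.zip(w)
    |> Enum.reduce(b, fn {xi, row}, acc ->
      Enum.zip_with(acc, row, fn a, wij -> a + xi * wij end)
    end)
  end

  # Scalar score for one position's hidden vector h (length hidden_size).
  def score_position(h, %{w1: w1, b1: b1, w2: w2, b2: b2}) do
    h
    |> dense(w1, b1)
    |> Enum.map(&silu/1)
    |> dense(w2, b2)
    |> hd()
  end

  # Applies the same head to every position: [seq_len][hidden_size] -> [seq_len].
  # (The documented model keeps a trailing axis of size 1: [batch, seq_len, 1].)
  def score_sequence(hidden_states, params) do
    Enum.map(hidden_states, &score_position(&1, params))
  end
end

# Toy weights (hidden_size = 2, scorer_hidden = 2), chosen only for illustration.
params = %{w1: [[1.0, 0.0], [0.0, 1.0]], b1: [0.0, 0.0], w2: [[1.0], [1.0]], b2: [0.0]}
IO.inspect(ScorerHeadSketch.score_sequence([[0.0, 0.0], [1.0, 0.0]], params), label: "scores")
```

Because the head is applied position by position, every timestep of the backbone output receives its own score, which is why the :scores output keeps the seq_len axis.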

Static Utility

select_best_of_n/2: given scores of shape [N, batch], returns the argmax over the N candidates for each batch element, i.e. the index of the highest-scoring completion.
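The selection rule can be illustrated with plain Elixir lists. BestOfNSketch is a hypothetical module that mirrors the documented contract (scores of shape [N, batch] in, one argmax index per batch element out); it is not the library's implementation, which operates on Nx tensors.

```elixir
defmodule BestOfNSketch do
  # Hypothetical stand-in for an [N, batch] scores tensor: a list of N rows,
  # each row holding one scalar score per batch element.
  # Returns a list of `batch` indices into the N candidates.
  def select_best_of_n(scores) do
    scores
    # Transpose to one tuple of N candidate scores per batch element.
    |> Enum.zip()
    |> Enum.map(fn candidate_scores ->
      candidate_scores
      |> Tuple.to_list()
      |> Enum.with_index()
      # Enum.max_by/2 keeps the first maximal element, so ties go to the lowest index.
      |> Enum.max_by(fn {score, _index} -> score end)
      |> elem(1)
    end)
  end
end

scores = [
  [0.1, 0.9],  # candidate 0
  [0.7, 0.2],  # candidate 1
  [0.4, 0.5]   # candidate 2
]

IO.inspect(BestOfNSketch.select_best_of_n(scores), label: "best indices")
# prints best indices: [1, 0]
```

Each column is reduced independently, so different batch elements can pick different candidates.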

Usage

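alias Edifice.Meta.TestTimeCompute

model = TestTimeCompute.build(
  embed_dim: 256,
  hidden_size: 256,
  num_layers: 4,
  scorer_hidden: 128
)

# At inference: generate N completions, score each, and collect one scalar
# score per completion into an [N, batch] tensor (illustrative values: N = 3, batch = 2)
scores = Nx.tensor([[0.1, 0.9], [0.7, 0.2], [0.4, 0.5]])

# Index of the highest-scoring completion for each batch element
best_indices = TestTimeCompute.select_best_of_n(scores)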

References

  • "Scaling LLM Test-Time Compute Optimally" (Snell et al., 2024)

Summary

Types

Options for build/1.

Functions

Build a Test-Time Compute model with backbone and scorer.

Get the output size of the model (hidden_size of backbone).

Get recommended defaults.

Select the best completion from N candidates using scores.

Types

build_opt()

@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_kv_heads, pos_integer()}
  | {:scorer_hidden, pos_integer()}
  | {:dropout, float()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Test-Time Compute model with backbone and scorer.

Options

  • :embed_dim - Input embedding dimension (required)
  • :hidden_size - Backbone hidden dimension (default: 256)
  • :num_layers - Number of backbone transformer layers (default: 4)
  • :num_heads - Number of attention heads (default: 4)
  • :num_kv_heads - Number of key/value heads for GQA (default: 2)
  • :scorer_hidden - Hidden dimension for scorer MLP (default: 128)
  • :dropout - Dropout rate (default: 0.1)
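The :num_heads and :num_kv_heads options configure grouped-query attention (GQA), in which several query heads share one key/value head. The sketch below shows the common contiguous grouping, assuming num_heads is a multiple of num_kv_heads. The GQASketch module is hypothetical, and whether Edifice assigns query heads to KV heads in exactly this order is an assumption.

```elixir
defmodule GQASketch do
  # Hypothetical illustration of grouped-query attention (GQA) head sharing:
  # num_heads query heads share num_kv_heads key/value heads, and each
  # KV head serves a contiguous group of num_heads / num_kv_heads query heads.
  def kv_head_for(q_head, num_heads, num_kv_heads)
      when is_integer(q_head) and q_head >= 0 and q_head < num_heads and
             num_kv_heads > 0 and rem(num_heads, num_kv_heads) == 0 do
    group_size = div(num_heads, num_kv_heads)
    div(q_head, group_size)
  end

  # Full query-head -> KV-head assignment as {q_head, kv_head} pairs.
  def mapping(num_heads, num_kv_heads) do
    for q <- 0..(num_heads - 1), do: {q, kv_head_for(q, num_heads, num_kv_heads)}
  end
end

# With the documented defaults (num_heads: 4, num_kv_heads: 2), two query
# heads share each KV head.
IO.inspect(GQASketch.mapping(4, 2), label: "q_head -> kv_head")
# prints q_head -> kv_head: [{0, 0}, {1, 0}, {2, 1}, {3, 1}]
```

Setting num_kv_heads equal to num_heads recovers standard multi-head attention, and setting it to 1 gives multi-query attention; the default of 2 sits between the two and halves the number of key/value projections relative to num_heads: 4.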

Returns

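An Axon model whose output is an Axon.container with keys :backbone, of shape [batch, seq_len, hidden_size], and :scores, of shape [batch, seq_len, 1].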

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Returns the output size of the model: the backbone's :hidden_size, i.e. the last dimension of the :backbone output.

select_best_of_n(scores, opts \\ [])

@spec select_best_of_n(
  Nx.Tensor.t(),
  keyword()
) :: Nx.Tensor.t()

Select the best completion from N candidates using scores.

Given scores of shape [N, batch], returns the index of the highest-scoring candidate per batch element.

Parameters

  • scores - Tensor of shape [N, batch] holding one scalar score per candidate (row) and batch element (column)
  • opts - Options (unused, reserved for future use)

Returns

Tensor of shape [batch] with the argmax index per batch element.
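Once the indices are available, they can be used to pick the winning completion for each batch element. The sketch below assumes, hypothetically, that completions are stored in the same [N, batch] layout as the scores; GatherBestSketch and gather_best/2 are illustrative names, not part of Edifice.

```elixir
defmodule GatherBestSketch do
  # Hypothetical helper: candidates are laid out like the scores tensor,
  # as N rows of `batch` completions each (an [N, batch] nested list).
  # best_indices has one entry per batch element (the [batch] output of
  # select_best_of_n/2, converted to a list). Returns
  # candidates[best_indices[b]][b] for each batch element b.
  def gather_best(candidates, best_indices) do
    best_indices
    |> Enum.with_index()
    |> Enum.map(fn {n, b} -> candidates |> Enum.at(n) |> Enum.at(b) end)
  end
end

candidates = [
  ["prompt0/cand0", "prompt1/cand0"],
  ["prompt0/cand1", "prompt1/cand1"],
  ["prompt0/cand2", "prompt1/cand2"]
]

# Indices for a batch of 2, e.g. obtained from the returned tensor with Nx.to_flat_list/1
IO.inspect(GatherBestSketch.gather_best(candidates, [1, 0]), label: "chosen")
# prints chosen: ["prompt0/cand1", "prompt1/cand0"]
```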