Edifice.Meta.GRPO (Edifice v0.2.0)


GRPO: Group Relative Policy Optimization.

Implements GRPO from DeepSeek's RLHF methodology. GRPO simplifies RLHF by:

  1. Sampling G completions per prompt
  2. Scoring each completion with a reward model
  3. Using group-relative advantages (no critic network needed)

Key Innovation: Group-Relative Advantage

Standard PPO requires a value function (critic) to compute advantages:

A(s,a) = R - V(s)  # Advantage = Return - Value estimate

GRPO eliminates the critic by using group-relative normalization:

For each prompt, sample G responses with rewards r_1, ..., r_G
Normalize: A_i = (r_i - mean(r)) / std(r)

This works because:

  • Within a group, the mean reward is a natural baseline
  • Standard deviation normalizes the scale
  • No learned value function needed
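
As a concrete illustration, here is the normalization applied to a single group of rewards using Nx directly (a minimal sketch; compute_advantages/2, documented below, wraps this logic):

# One group of G = 4 rewards for a single prompt
rewards = Nx.tensor([1.0, 0.0, 0.5, 0.5])
mean = Nx.mean(rewards)                    # 0.5
std = Nx.standard_deviation(rewards)       # ~0.354
advantages = Nx.divide(Nx.subtract(rewards, mean), Nx.add(std, 1.0e-8))
# => roughly [1.41, -1.41, 0.0, 0.0]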

Algorithm

For each prompt x:
    1. Sample G responses: y_1, ..., y_G ~ pi(y|x)
    2. Get rewards: r_1, ..., r_G = Reward(x, y_i)
    3. Compute advantages: A_i = (r_i - mean(r)) / (std(r) + eps)
    4. Policy gradient: L = -sum(A_i * log pi(y_i|x))
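
Expressed against this module's API, one update step looks roughly like the sketch below. Note that sample_responses/3 and reward_model/2 are hypothetical placeholders for your own sampling loop and reward model; they are not Edifice functions.

# Steps 1-2: sampling and scoring are application-specific
# (sample_responses/3 and reward_model/2 are hypothetical helpers)
responses = sample_responses(policy, prompts, group_size: 8)
rewards = reward_model(prompts, responses)

# Step 3: group-relative advantages
advantages = GRPO.compute_advantages(rewards, group_size: 8)

# Step 4: policy gradient loss, given per-sequence log probabilities of the
# sampled responses under the current policy (see compute_logprobs/3)
loss = GRPO.loss(log_probs, advantages)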

Architecture

Prompt x [batch]
      |
      v
+------------------+
| Sample G times   |  -> G responses per prompt
+------------------+
      |
      v
+------------------+
| Reward Model     |  -> Rewards r_1, ..., r_G
+------------------+
      |
      v
+------------------+
| Group Normalize  |  -> Advantages A_1, ..., A_G
+------------------+
      |
      v
+------------------+
| Policy Gradient  |  -> Update policy
+------------------+

Usage

alias Edifice.Meta.GRPO

# Build a GRPO policy
policy = GRPO.build(
  hidden_size: 512,
  num_layers: 6,
  vocab_size: 32000
)

# Compute group-relative advantages for a batch of sampled completions
advantages = GRPO.compute_advantages(group_rewards)

# Compute the policy gradient loss from per-sequence log probabilities
loss = GRPO.loss(log_probs, advantages)

Reference

  • DeepSeek-R1 Technical Report (2025)
  • "DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models" (2024)

Summary

Types

build_opt()

Options for build/1.

Functions

build(opts \\ [])

Build a GRPO policy model.

compute_advantages(rewards, opts \\ [])

Compute group-relative advantages from rewards.

compute_logprobs(logits, targets, mask \\ nil)

Compute per-token log probabilities from logits.

default_group_size()

Get the default group size for GRPO.

loss(log_probs, advantages, opts \\ [])

Compute the GRPO policy gradient loss.

Get recommended defaults for GRPO.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:hidden_size, pos_integer()}
  | {:max_seq_len, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:vocab_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a GRPO policy model.

Uses the same architecture as DPO (decoder-only transformer).

Options

  • :hidden_size - Hidden dimension (default: 512)
  • :num_layers - Number of transformer layers (default: 6)
  • :num_heads - Number of attention heads (default: 8)
  • :vocab_size - Vocabulary size (default: 32000)
  • :dropout - Dropout rate (default: 0.1)
  • :max_seq_len - Maximum sequence length (default: 2048)

Returns

An Axon model that outputs logits over the vocabulary.
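
For example (a sketch; the option values are illustrative, and Axon.build/2 is the standard Axon call for turning the returned model into init/predict functions):

model = GRPO.build(hidden_size: 256, num_layers: 4, vocab_size: 32000)
{init_fn, predict_fn} = Axon.build(model)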

compute_advantages(rewards, opts \\ [])

@spec compute_advantages(
  Nx.Tensor.t(),
  keyword()
) :: Nx.Tensor.t()

Compute group-relative advantages from rewards.

For each group of G responses to the same prompt, compute:

A_i = (r_i - mean(r)) / (std(r) + eps)

Parameters

  • rewards - Reward tensor [batch, group_size] or [batch * group_size]
  • opts - Options:
    • :group_size - Number of responses per prompt (default: inferred or 8)
    • :eps - Epsilon for numerical stability (default: 1e-8)

Returns

Advantages tensor with same shape as rewards.
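
Example (illustrative values; per the formula above, a zero-variance group collapses to near-zero advantages):

rewards = Nx.tensor([
  [1.0, 0.0, 0.5, 0.5],   # varied rewards -> above/below-mean samples get +/- advantages
  [2.0, 2.0, 2.0, 2.0]    # identical rewards -> advantages collapse to ~0
])

advantages = GRPO.compute_advantages(rewards, group_size: 4)  # same shape as rewards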

compute_logprobs(logits, targets, mask \\ nil)

@spec compute_logprobs(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t() | nil) ::
  Nx.Tensor.t()

Compute per-token log probabilities from logits.

Same as DPO.compute_logprobs/3.

Parameters

  • logits - Model output logits [batch, seq_len, vocab_size]
  • targets - Target token indices [batch, seq_len]
  • mask - Optional mask for padding [batch, seq_len]

Returns

Per-sequence log probability sums [batch].
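
Example with toy shapes (the values are placeholders; only the shapes matter here):

logits = Nx.broadcast(0.0, {2, 4, 8})              # [batch, seq_len, vocab_size]
targets = Nx.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])  # [batch, seq_len]
mask = Nx.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])     # 0 marks padding
log_probs = GRPO.compute_logprobs(logits, targets, mask)  # shape {2}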

default_group_size()

@spec default_group_size() :: pos_integer()

Get the default group size for GRPO.

loss(log_probs, advantages, opts \\ [])

@spec loss(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()

Compute the GRPO policy gradient loss.

Parameters

  • log_probs - Per-sequence log probabilities [batch] or [batch, group_size]
  • advantages - Group-relative advantages [batch] or [batch, group_size]
  • opts - Options:
    • :clip_range - PPO-style clipping range (optional, default: nil for no clipping)
    • :old_log_probs - Old policy log probs for PPO clipping (required if clip_range set)

Returns

Scalar loss value.

Formula

Without clipping:

L = -mean(A * log_pi)

With clipping (PPO-style):

ratio = exp(log_pi - old_log_pi)
L = -mean(min(ratio * A, clip(ratio, 1-eps, 1+eps) * A))
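
A minimal Nx sketch of the clipped branch (an illustration of the formula only, not necessarily the library's internal implementation; log_pi, old_log_pi, advantages, and eps stand for the current and old log probabilities, the group-relative advantages, and the :clip_range value):

# Probability ratio between the current and old policy
ratio = Nx.exp(Nx.subtract(log_pi, old_log_pi))
clipped_ratio = Nx.clip(ratio, 1.0 - eps, 1.0 + eps)

# Negated mean of the element-wise minimum of the unclipped and clipped terms
loss =
  Nx.min(Nx.multiply(ratio, advantages), Nx.multiply(clipped_ratio, advantages))
  |> Nx.mean()
  |> Nx.negate()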