GRPO: Group Relative Policy Optimization.
Implements GRPO from DeepSeek's RLHF methodology. GRPO simplifies RLHF by:
- Sampling G completions per prompt
- Ranking completions within each group
- Using group-relative advantages (no critic network needed)
Key Innovation: Group-Relative Advantage
Standard PPO requires a value function (critic) to compute advantages:
A(s,a) = R - V(s)   # Advantage = Return - Value estimate
GRPO eliminates the critic by using group-relative normalization:
For each prompt, sample G responses with rewards r_1, ..., r_G
Normalize: A_i = (r_i - mean(r)) / std(r)
This works because:
- Within a group, the mean reward is a natural baseline
- Standard deviation normalizes the scale
- No learned value function needed
Algorithm
For each prompt x:
1. Sample G responses: y_1, ..., y_G ~ pi(y|x)
2. Get rewards: r_i = Reward(x, y_i) for i = 1..G
3. Compute advantages: A_i = (r_i - mean(r)) / (std(r) + eps)
4. Policy gradient: L = -sum(A_i * log pi(y_i|x))
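A toy walk-through of these four steps for a single prompt, written directly in Nx rather than with the module's helpers; the rewards and log-probabilities are made-up numbers for illustration:
# Made-up rewards and per-sequence log-probs for G = 4 sampled responses
rewards = Nx.tensor([0.1, 0.9, 0.4, 0.6])
log_probs = Nx.tensor([-12.0, -9.5, -11.0, -10.2])

# Step 3: group-relative advantages
mean = Nx.mean(rewards)
centered = Nx.subtract(rewards, mean)
std = Nx.sqrt(Nx.mean(Nx.multiply(centered, centered)))
advantages = Nx.divide(centered, Nx.add(std, 1.0e-8))

# Step 4: policy gradient loss for this group
loss = Nx.negate(Nx.mean(Nx.multiply(advantages, log_probs)))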
Architecture
Prompt x [batch]
|
v
+------------------+
| Sample G times | -> G responses per prompt
+------------------+
|
v
+------------------+
| Reward Model | -> Rewards r_1, ..., r_G
+------------------+
|
v
+------------------+
| Group Normalize | -> Advantages A_1, ..., A_G
+------------------+
|
v
+------------------+
| Policy Gradient | -> Update policy
+------------------+
Usage
# Build a GRPO policy
policy = GRPO.build(
  hidden_size: 512,
  num_layers: 6,
  vocab_size: 32000
)

# Compute group-relative advantages
advantages = GRPO.compute_advantages(group_rewards)

# Compute policy gradient loss
loss = GRPO.loss(log_probs, advantages)

Reference
- "DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning" (2025)
- "DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models" (2024)
Summary
Functions
Build a GRPO policy model.
Compute group-relative advantages from rewards.
Compute per-token log probabilities from logits.
Get the default group size for GRPO.
Compute the GRPO policy gradient loss.
Get recommended defaults for GRPO.
Types
@type build_opt() ::
  {:dropout, float()}
  | {:hidden_size, pos_integer()}
  | {:max_seq_len, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:vocab_size, pos_integer()}
Options for build/1.
Functions
Build a GRPO policy model.
Uses the same architecture as DPO (decoder-only transformer).
Options
:hidden_size - Hidden dimension (default: 512)
:num_layers - Number of transformer layers (default: 6)
:num_heads - Number of attention heads (default: 8)
:vocab_size - Vocabulary size (default: 32000)
:dropout - Dropout rate (default: 0.1)
:max_seq_len - Maximum sequence length (default: 2048)
Returns
An Axon model that outputs logits over the vocabulary.
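A minimal sketch of wiring the built model into Axon's init/predict functions. The input name "input_ids", the {1, 128} sequence shape, and the :s64 index type are assumptions for illustration, not something build/1 documents here:
policy = GRPO.build(hidden_size: 512, num_layers: 6, vocab_size: 32000)

# Compile the Axon graph into init/predict functions
{init_fn, predict_fn} = Axon.build(policy)

# Hypothetical input name and shape; check the model's actual input layer
params = init_fn.(%{"input_ids" => Nx.template({1, 128}, :s64)}, %{})

token_ids = Nx.broadcast(Nx.tensor(0, type: :s64), {1, 128})
logits = predict_fn.(params, %{"input_ids" => token_ids})   # [1, 128, vocab_size]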
@spec compute_advantages( Nx.Tensor.t(), keyword() ) :: Nx.Tensor.t()
Compute group-relative advantages from rewards.
For each group of G responses to the same prompt, compute:
A_i = (r_i - mean(r)) / (std(r) + eps)
Parameters
rewards - Reward tensor [batch, group_size] or [batch * group_size]
opts - Options:
  :group_size - Number of responses per prompt (default: inferred or 8)
  :eps - Epsilon for numerical stability (default: 1e-8)
Returns
Advantages tensor with same shape as rewards.
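An equivalent hand-rolled computation in Nx, shown only to make the formula concrete; it may differ from the module's internals, e.g. in how the standard deviation is estimated:
rewards = Nx.tensor([[1.0, 0.0, 0.5, 2.0],
                     [0.2, 0.2, 0.9, 0.1]])                        # [batch, group_size]

mean = Nx.mean(rewards, axes: [1], keep_axes: true)                # per-group baseline
centered = Nx.subtract(rewards, mean)
std = Nx.sqrt(Nx.mean(Nx.multiply(centered, centered), axes: [1], keep_axes: true))
advantages = Nx.divide(centered, Nx.add(std, 1.0e-8))              # same shape as rewards
Calling GRPO.compute_advantages(rewards, group_size: 4) should produce the same values, assuming it uses the population standard deviation as sketched here.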
@spec compute_logprobs(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t() | nil) :: Nx.Tensor.t()
Compute per-token log probabilities from logits.
Same as DPO.compute_logprobs/3.
Parameters
logits - Model output logits [batch, seq_len, vocab_size]
targets - Target token indices [batch, seq_len]
mask - Optional mask for padding [batch, seq_len]
Returns
Per-sequence log probability sums [batch].
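For reference, one way to express this computation directly in Nx; a sketch with toy inputs, not necessarily how DPO.compute_logprobs/3 is implemented:
# Toy inputs matching the documented shapes
logits  = Nx.iota({2, 4, 8}, type: :f32)                           # [batch, seq_len, vocab_size]
targets = Nx.tensor([[1, 2, 3, 0], [4, 5, 6, 0]])                  # [batch, seq_len]
mask    = Nx.tensor([[1, 1, 1, 0], [1, 1, 1, 0]])                  # 0 marks padding

# Numerically stable log-softmax over the vocabulary
max = Nx.reduce_max(logits, axes: [2], keep_axes: true)
log_z = Nx.add(Nx.log(Nx.sum(Nx.exp(Nx.subtract(logits, max)), axes: [2], keep_axes: true)), max)
log_probs = Nx.subtract(logits, log_z)

# Pick each target token's log-prob, zero out padding, sum over the sequence
token_logp = Nx.squeeze(Nx.take_along_axis(log_probs, Nx.new_axis(targets, 2), axis: 2), axes: [2])
seq_logp = Nx.sum(Nx.multiply(token_logp, mask), axes: [1])        # [batch]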
@spec default_group_size() :: pos_integer()
Get the default group size for GRPO.
@spec loss(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Compute the GRPO policy gradient loss.
Parameters
log_probs - Per-sequence log probabilities [batch] or [batch, group_size]
advantages - Group-relative advantages [batch] or [batch, group_size]
opts - Options:
  :clip_range - PPO-style clipping range (optional, default: nil for no clipping)
  :old_log_probs - Old policy log probs for PPO clipping (required if clip_range set)
Returns
Scalar loss value.
Formula
Without clipping:
L = -mean(A * log_pi)
With clipping (PPO-style):
ratio = exp(log_pi - old_log_pi)
L = -mean(min(ratio * A, clip(ratio, 1-eps, 1+eps) * A))
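The same two variants written out in Nx, as a sketch of the objective rather than the exact body of loss/3; log_probs, advantages, and old_log_probs are the documented arguments, and eps stands in for the :clip_range value:
# Without clipping
loss = Nx.negate(Nx.mean(Nx.multiply(advantages, log_probs)))

# With PPO-style clipping (requires :old_log_probs)
ratio = Nx.exp(Nx.subtract(log_probs, old_log_probs))
clipped = Nx.clip(ratio, 1.0 - eps, 1.0 + eps)
loss =
  Nx.min(Nx.multiply(ratio, advantages), Nx.multiply(clipped, advantages))
  |> Nx.mean()
  |> Nx.negate()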
@spec recommended_defaults() :: keyword()
Get recommended defaults for GRPO.