Edifice.Meta.DPO (Edifice v0.2.0)


DPO: Direct Preference Optimization.

Implements DPO from "Direct Preference Optimization: Your Language Model is Secretly a Reward Model" (Rafailov et al., NeurIPS 2023). DPO eliminates the need for a separate reward model in RLHF by directly optimizing the policy from preference pairs.

Key Innovation: Implicit Reward Model

Standard RLHF requires:

  1. Train a reward model on preference data
  2. Use RL (typically PPO) to optimize the policy against the reward model

DPO starts from the closed-form solution of the KL-regularized RLHF objective: for a given reward function r(x,y), the optimal policy is:

pi*(y|x) = (1/Z(x)) * pi_ref(y|x) * exp(r(x,y)/beta)

Inverting this gives the implicit reward:

r(x,y) = beta * log(pi*(y|x)/pi_ref(y|x)) + beta*log(Z(x))

Substituting into the Bradley-Terry preference model, p(y_w > y_l | x) = sigmoid(r(x, y_w) - r(x, y_l)), cancels the intractable log Z(x) term; maximizing the likelihood of the observed preferences then yields the DPO loss (a minimal Nx sketch follows the symbol definitions below):

L_DPO = -log(sigmoid(beta * (log(pi(y_w|x)/pi_ref(y_w|x))
                            - log(pi(y_l|x)/pi_ref(y_l|x)))))

Where:

  • y_w is the preferred (winning) response
  • y_l is the dispreferred (losing) response
  • pi is the policy being trained
  • pi_ref is the frozen reference policy
  • beta controls KL regularization strength
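
The loss is just an elementwise computation over batched log probabilities. A minimal, self-contained Nx sketch of the same formula (illustrative only; the module and function names below are not part of Edifice, and DPO.loss/5 documented later is the library's version):

defmodule DPOLossSketch do
  import Nx.Defn

  # Batch-mean DPO loss, written directly from the formula above.
  # All four arguments are per-sequence log-prob sums of shape [batch].
  defn loss(pi_logp_w, pi_logp_l, ref_logp_w, ref_logp_l, opts \\ []) do
    opts = keyword!(opts, beta: 0.1)

    # Implicit reward margin: how much more the policy favors y_w over y_l
    # than the frozen reference does.
    margin = (pi_logp_w - ref_logp_w) - (pi_logp_l - ref_logp_l)

    Nx.mean(-Nx.log(Nx.sigmoid(opts[:beta] * margin)))
  end
end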

Architecture

DPO wraps any sequence model (typically a decoder-only transformer):

Input prompt x
      |
      +-------+-------+
      |               |
      v               v
+------------+   +------------+
| Policy     |   | Reference  |  (frozen copy)
| pi(y|x)    |   | pi_ref(y|x)|
+------------+   +------------+
      |               |
      v               v
log_probs_pi    log_probs_ref
      |               |
      +-------+-------+
              |
              v
      DPO Loss Computation

Usage

# Build a DPO-wrapped policy (a decoder-only transformer backbone)
policy = DPO.build(
  hidden_size: 512,
  num_layers: 6,
  vocab_size: 32000
)

# Compute DPO loss
loss = DPO.loss(
  policy_logprobs_chosen,
  policy_logprobs_rejected,
  ref_logprobs_chosen,
  ref_logprobs_rejected,
  beta: 0.1
)
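
Putting these together, one loss computation for a preference batch looks roughly like the following sketch. The tensors chosen_ids, rejected_ids, chosen_mask, rejected_mask, the trainable policy_params, and the frozen ref_params are illustrative names assumed to come from your own data pipeline and training loop:

# Forward passes through the same graph with trainable and frozen parameters
logits_pi_w = Axon.predict(policy, policy_params, chosen_ids)
logits_pi_l = Axon.predict(policy, policy_params, rejected_ids)
logits_ref_w = Axon.predict(policy, ref_params, chosen_ids)
logits_ref_l = Axon.predict(policy, ref_params, rejected_ids)

# Reduce per-token logits to per-sequence log-prob sums.
# (In practice, shift targets for next-token prediction and mask out prompt tokens.)
policy_logprobs_chosen = DPO.compute_logprobs(logits_pi_w, chosen_ids, chosen_mask)
policy_logprobs_rejected = DPO.compute_logprobs(logits_pi_l, rejected_ids, rejected_mask)
ref_logprobs_chosen = DPO.compute_logprobs(logits_ref_w, chosen_ids, chosen_mask)
ref_logprobs_rejected = DPO.compute_logprobs(logits_ref_l, rejected_ids, rejected_mask)

loss =
  DPO.loss(policy_logprobs_chosen, policy_logprobs_rejected,
           ref_logprobs_chosen, ref_logprobs_rejected, beta: 0.1)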

Reference

  • Paper: "Direct Preference Optimization: Your Language Model is Secretly a Reward Model"
  • Authors: Rafailov et al.
  • arXiv: https://arxiv.org/abs/2305.18290
  • NeurIPS 2023

Summary

Types

Options for build/1.

Functions

Build a DPO policy model.

Compute per-token log probabilities from logits.

Get the default beta parameter for DPO.

Get recommended defaults for DPO.

Compute the DPO loss given policy and reference log probabilities.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:hidden_size, pos_integer()}
  | {:max_seq_len, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:vocab_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a DPO policy model.

The model wraps a decoder-only backbone for language modeling. During training, you'll need to maintain a frozen copy as the reference.

Options

  • :hidden_size - Hidden dimension (default: 512)
  • :num_layers - Number of transformer layers (default: 6)
  • :num_heads - Number of attention heads (default: 8)
  • :vocab_size - Vocabulary size (default: 32000)
  • :dropout - Dropout rate (default: 0.1)
  • :max_seq_len - Maximum sequence length (default: 2048)

Returns

An Axon model that outputs logits over the vocabulary.
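
For example (a sketch; the option values are illustrative, not recommendations):

model = DPO.build(hidden_size: 256, num_layers: 4, num_heads: 8, vocab_size: 16000)
# `model` is a regular Axon graph; compile it as usual to obtain init/predict functions.
{init_fn, predict_fn} = Axon.build(model)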

compute_logprobs(logits, targets, mask \\ nil)

@spec compute_logprobs(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t() | nil) ::
  Nx.Tensor.t()

Compute per-token log probabilities from logits.

Parameters

  • logits - Model output logits [batch, seq_len, vocab_size]
  • targets - Target token indices [batch, seq_len]
  • mask - Optional mask for padding [batch, seq_len] (1 for valid, 0 for padding)

Returns

Per-sequence log probability sums [batch].
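
Conceptually this is a masked gather-and-sum over the log-softmax of the logits. A hand-rolled Nx equivalent might look like the following (a sketch of the idea assuming a mask is supplied, not the library's implementation):

defmodule LogprobSketch do
  import Nx.Defn

  # logits: [batch, seq_len, vocab], targets: [batch, seq_len], mask: [batch, seq_len]
  defn sequence_logprobs(logits, targets, mask) do
    # Numerically stable log-softmax over the vocabulary axis.
    max_logit = Nx.reduce_max(logits, axes: [-1], keep_axes: true)
    log_z = max_logit + Nx.log(Nx.sum(Nx.exp(logits - max_logit), axes: [-1], keep_axes: true))
    log_probs = logits - log_z

    # Pick the log-probability of each target token: [batch, seq_len].
    token_logprobs =
      log_probs
      |> Nx.take_along_axis(Nx.new_axis(targets, -1), axis: 2)
      |> Nx.squeeze(axes: [2])

    # Zero out padding positions, then sum over the sequence -> [batch].
    Nx.sum(token_logprobs * mask, axes: [-1])
  end
end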

default_beta()

@spec default_beta() :: float()

Get the default beta parameter for DPO.

loss(policy_logprobs_chosen, policy_logprobs_rejected, ref_logprobs_chosen, ref_logprobs_rejected, opts \\ [])

Compute the DPO loss given policy and reference log probabilities.

Parameters

  • policy_logprobs_chosen - Log probs from policy for chosen responses [batch]
  • policy_logprobs_rejected - Log probs from policy for rejected responses [batch]
  • ref_logprobs_chosen - Log probs from reference for chosen responses [batch]
  • ref_logprobs_rejected - Log probs from reference for rejected responses [batch]
  • opts - Options including :beta (default: 0.1)

Returns

Scalar loss value.

Formula

L = -log(sigmoid(beta * ((log_pi_w - log_ref_w) - (log_pi_l - log_ref_l))))
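
For instance, with beta = 0.1 and a margin of (log_pi_w - log_ref_w) - (log_pi_l - log_ref_l) = 2.0, the loss is -log(sigmoid(0.2)), roughly 0.60; a zero margin gives -log(0.5), roughly 0.69, and the loss falls toward 0 as the policy widens the margin on the chosen response relative to the reference.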