# `Edifice.Meta.DPO`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/meta/dpo.ex#L1)

DPO: Direct Preference Optimization.

Implements DPO from "Direct Preference Optimization: Your Language Model is
Secretly a Reward Model" (Rafailov et al., NeurIPS 2023). DPO eliminates
the need for a separate reward model in RLHF by directly optimizing the
policy from preference pairs.

## Key Innovation: Implicit Reward Model

Standard RLHF requires:
1. Train a reward model on preference data
2. Use RL (PPO) to optimize policy against reward model

DPO starts from the closed-form solution to the KL-regularized reward-maximization
objective: for a given reward function r(x,y), the optimal policy is:

```
pi*(y|x) = (1/Z(x)) * pi_ref(y|x) * exp(r(x,y)/beta)
```

Inverting this gives the implicit reward:

```
r(x,y) = beta * log(pi*(y|x)/pi_ref(y|x)) + beta * log(Z(x))
```
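
Under the Bradley-Terry model, the probability of preferring y_w over y_l
depends only on the difference in rewards, so the intractable log Z(x) term
cancels:

```
p(y_w > y_l | x) = sigmoid(r(x,y_w) - r(x,y_l))
                 = sigmoid(beta * (log(pi*(y_w|x)/pi_ref(y_w|x))
                                 - log(pi*(y_l|x)/pi_ref(y_l|x))))
```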

Minimizing the negative log-likelihood of the observed preferences under this
model yields the DPO loss:

```
L_DPO = -log(sigmoid(beta * (log(pi(y_w|x)/pi_ref(y_w|x))
                            - log(pi(y_l|x)/pi_ref(y_l|x)))))
```

Where:
- y_w is the preferred (winning) response
- y_l is the dispreferred (losing) response
- pi is the policy being trained
- pi_ref is the frozen reference policy
- beta controls KL regularization strength
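
For intuition, take beta = 0.1 and a pair where the policy has moved the chosen
response up by one nat and the rejected response down by one nat relative to
the reference:

```
chosen log-ratio   = log(pi(y_w|x)/pi_ref(y_w|x)) =  1.0
rejected log-ratio = log(pi(y_l|x)/pi_ref(y_l|x)) = -1.0
margin = 1.0 - (-1.0) = 2.0
L_DPO  = -log(sigmoid(0.1 * 2.0)) = -log(0.5498) ≈ 0.598
```

A zero margin gives -log(0.5) ≈ 0.693; as the policy widens the margin, the
loss decays toward zero.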

## Architecture

DPO wraps any sequence model (typically a decoder-only transformer):

```
Input prompt x
      |
      +-------+-------+
      |               |
      v               v
+------------+   +------------+
| Policy     |   | Reference  |  (frozen copy)
| pi(y|x)    |   | pi_ref(y|x)|
+------------+   +------------+
      |               |
      v               v
log_probs_pi    log_probs_ref
      |               |
      +-------+-------+
              |
              v
      DPO Loss Computation
```
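
In training-loop terms this looks roughly like the sketch below. Here
`predict_fn`, `policy_params`, `ref_params`, and the batch fields are
assumptions from your own setup (e.g. obtained via `Axon.build/1`), and target
shifting / prompt masking are elided:

```elixir
alias Edifice.Meta.DPO

# Policy forward passes on the chosen and rejected responses
pi_chosen =
  DPO.compute_logprobs(predict_fn.(policy_params, chosen_ids), chosen_ids, chosen_mask)

pi_rejected =
  DPO.compute_logprobs(predict_fn.(policy_params, rejected_ids), rejected_ids, rejected_mask)

# Frozen reference forward passes (same graph, parameters never updated)
ref_chosen =
  DPO.compute_logprobs(predict_fn.(ref_params, chosen_ids), chosen_ids, chosen_mask)

ref_rejected =
  DPO.compute_logprobs(predict_fn.(ref_params, rejected_ids), rejected_ids, rejected_mask)

loss = DPO.loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta: 0.1)
```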

## Usage

    # Build a DPO-wrapped policy
    policy = DPO.build(
      backbone: :decoder_only,
      hidden_size: 512,
      num_layers: 6,
      vocab_size: 32000
    )

    # Compute DPO loss
    loss = DPO.loss(
      policy_logprobs_chosen,
      policy_logprobs_rejected,
      ref_logprobs_chosen,
      ref_logprobs_rejected,
      beta: 0.1
    )

## Reference

- Paper: "Direct Preference Optimization: Your Language Model is Secretly a Reward Model"
- Authors: Rafailov et al.
- arXiv: https://arxiv.org/abs/2305.18290
- NeurIPS 2023

# `build_opt`

```elixir
@type build_opt() ::
  {:dropout, float()}
  | {:hidden_size, pos_integer()}
  | {:max_seq_len, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:vocab_size, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a DPO policy model.

The model wraps a decoder-only backbone for language modeling.
During training, you'll need to maintain a frozen copy as the reference.

## Options

  - `:hidden_size` - Hidden dimension (default: 512)
  - `:num_layers` - Number of transformer layers (default: 6)
  - `:num_heads` - Number of attention heads (default: 8)
  - `:vocab_size` - Vocabulary size (default: 32000)
  - `:dropout` - Dropout rate (default: 0.1)
  - `:max_seq_len` - Maximum sequence length (default: 2048)

## Returns

  An Axon model that outputs logits over the vocabulary.
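
In practice the same graph serves both the policy and the frozen reference;
only the parameter set passed at call time differs. A minimal sketch (the
input template and name depend on your data pipeline):

```elixir
policy = Edifice.Meta.DPO.build(hidden_size: 512, num_layers: 6, vocab_size: 32_000)

# Compile once; pass trainable params for the policy and a frozen copy of the
# starting (e.g. SFT) params for the reference.
{init_fn, predict_fn} = Axon.build(policy)
```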

# `compute_logprobs`

```elixir
@spec compute_logprobs(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t() | nil) ::
  Nx.Tensor.t()
```

Compute per-token log probabilities from logits.

## Parameters

  - `logits` - Model output logits [batch, seq_len, vocab_size]
  - `targets` - Target token indices [batch, seq_len]
  - `mask` - Optional mask for padding [batch, seq_len] (1 for valid, 0 for padding)

## Returns

  Per-sequence log probability sums [batch].
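
The usual computation looks like the following sketch (not necessarily this
module's exact implementation): log-softmax over the vocabulary, gather the
target-token log probabilities, mask padding, and sum over the sequence.

```elixir
import Nx.Defn

defn sequence_logprobs(logits, targets, mask) do
  # numerically stable log-softmax over the vocabulary axis
  shifted = logits - Nx.reduce_max(logits, axes: [-1], keep_axes: true)
  log_probs = shifted - Nx.log(Nx.sum(Nx.exp(shifted), axes: [-1], keep_axes: true))

  # log prob of each target token: [batch, seq_len]
  per_token =
    log_probs
    |> Nx.take_along_axis(Nx.new_axis(targets, -1), axis: -1)
    |> Nx.squeeze(axes: [-1])

  # zero out padding and sum over the sequence: [batch]
  Nx.sum(per_token * mask, axes: [-1])
end
```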

# `default_beta`

```elixir
@spec default_beta() :: float()
```

Get the default beta parameter for DPO.

# `loss`

```elixir
@spec loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) ::
  Nx.Tensor.t()
```

Compute the DPO loss given policy and reference log probabilities.

## Parameters

  - `policy_logprobs_chosen` - Log probs from policy for chosen responses [batch]
  - `policy_logprobs_rejected` - Log probs from policy for rejected responses [batch]
  - `ref_logprobs_chosen` - Log probs from reference for chosen responses [batch]
  - `ref_logprobs_rejected` - Log probs from reference for rejected responses [batch]
  - `opts` - Options including `:beta` (default: 0.1)

## Returns

  Scalar loss value.

## Formula

```
L = -log(sigmoid(beta * ((log_pi_w - log_ref_w) - (log_pi_l - log_ref_l))))
```
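
As a direct Nx rendering of this formula (a sketch, not necessarily the
module's exact implementation):

```elixir
import Nx.Defn

defn dpo_loss(pi_w, pi_l, ref_w, ref_l, opts \\ []) do
  opts = keyword!(opts, beta: 0.1)
  margin = (pi_w - ref_w) - (pi_l - ref_l)
  # mean over the batch of -log(sigmoid(beta * margin))
  -Nx.mean(Nx.log(Nx.sigmoid(opts[:beta] * margin)))
end
```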

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Get recommended defaults for DPO.

---

*Consult [api-reference.md](api-reference.md) for complete listing*
