Edifice.Meta.KTO (Edifice v0.2.0)


KTO: Kahneman-Tversky Optimization for RLHF from binary feedback.

KTO aligns language models using only binary thumbs-up / thumbs-down labels — no preference pairs required. It is based on Kahneman-Tversky Prospect Theory, modelling how humans perceive gains (desirable outputs) and losses (undesirable outputs) asymmetrically.

Key Innovation Over DPO

DPO needs paired (chosen, rejected) responses to the same prompt. KTO only needs a single response labelled desirable (1) or undesirable (0), which is far easier to collect at scale.
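The difference in data requirements can be sketched in plain Elixir. The module `KTOData` and its map keys are hypothetical and not part of Edifice. The sketch assumes you start from DPO-style `{prompt, chosen, rejected}` triples and shows that each pair splits into two independent KTO samples, one per label.

```elixir
defmodule KTOData do
  @moduledoc "Hypothetical helper (not part of Edifice): reshape preference pairs into KTO samples."

  # Each DPO-style triple {prompt, chosen, rejected} becomes two independent
  # KTO samples: the chosen response labelled 1 (desirable) and the rejected
  # response labelled 0 (undesirable). No pairing survives in the output.
  def from_pairs(pairs) do
    Enum.flat_map(pairs, fn {prompt, chosen, rejected} ->
      [
        %{prompt: prompt, response: chosen, label: 1},
        %{prompt: prompt, response: rejected, label: 0}
      ]
    end)
  end
end

samples = KTOData.from_pairs([{"2+2?", "4", "5"}])
IO.inspect(samples)
```

Feedback collected natively as thumbs-up / thumbs-down needs no such conversion: each record already has the single-response, single-label shape KTO consumes.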

Loss Formula

For each sample with policy log-probability log π(y|x) and reference log-probability log π_ref(y|x):

KL  = β · (log π(y|x) − log π_ref(y|x))
z   = E_batch[KL]     # partition-function estimate

# Prospect-theory value functions
desirable:    v = σ(KL − z)     # utility of a good response
undesirable:  v = σ(z − KL)     # utility of moving away from a bad response

# Per-sample loss (minimise negative utility)
L = −λ_D · σ(KL − z)      if label = 1 (desirable)
L = −λ_U · σ(z − KL)      if label = 0 (undesirable)

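λ_D and λ_U (the :desirable_weight and :undesirable_weight options) weight desirable and undesirable feedback against each other, for example to rebalance a dataset that contains many more examples of one label than the other.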
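A minimal, plain-Elixir sketch of the two per-sample loss branches, assuming desirable samples score −λ_D · σ(KL − z) and undesirable samples score −λ_U · σ(z − KL). `KTOValue` is a hypothetical module for illustration, not the Edifice source, and it works on single floats rather than Nx tensors.

```elixir
defmodule KTOValue do
  # Hypothetical illustration of the per-sample KTO terms (not the Edifice source).

  def sigmoid(x), do: 1.0 / (1.0 + :math.exp(-x))

  # kl is β · (log π − log π_ref) for one sample, z the batch reference point,
  # label the integer 1 (desirable) or 0 (undesirable).
  # Desirable:   loss = −λ_D · σ(KL − z)
  # Undesirable: loss = −λ_U · σ(z − KL)
  def sample_loss(kl, z, 1, lambda_d, _lambda_u), do: -lambda_d * sigmoid(kl - z)
  def sample_loss(kl, z, 0, _lambda_d, lambda_u), do: -lambda_u * sigmoid(z - kl)
end

# With z = 0 and λ_D = λ_U = 1, sweep the log-ratio term KL.
for kl <- [-0.5, 0.0, 0.5] do
  d = KTOValue.sample_loss(kl, 0.0, 1, 1.0, 1.0)
  u = KTOValue.sample_loss(kl, 0.0, 0, 1.0, 1.0)
  IO.puts("KL=#{kl}  desirable=#{Float.round(d, 4)}  undesirable=#{Float.round(u, 4)}")
end
```

The sweep shows the asymmetry the loss relies on: pushing KL above the reference point z lowers the loss of a desirable sample and raises the loss of an undesirable one.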

Architecture

KTO wraps any sequence model; build/1 constructs a decoder-only transformer for language modelling (identical backbone to DPO).

            Prompt x
                |
      +---------+---------+
      |                   |
      v                   v
+------------+    +--------------+
|   Policy   |    |  Reference   |  (frozen)
|   π(y|x)   |    |  π_ref(y|x)  |
+------------+    +--------------+
      |                   |
   log_pi              log_ref
      |                   |
      +---------+---------+
                |
                v
     KTO Loss + binary label

Usage

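alias Edifice.Meta.KTO

policy = KTO.build(vocab_size: 32000, hidden_size: 512, num_layers: 6)

# policy_logprobs, ref_logprobs: [batch] sequence log-probs (e.g. from KTO.compute_logprobs/3)
# labels: [batch] tensor of 1 (desirable) / 0 (undesirable)
loss = KTO.kto_loss(
  policy_logprobs,
  ref_logprobs,
  labels,
  beta: 0.1,
  desirable_weight: 1.0,
  undesirable_weight: 1.0
)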


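Summary

Types

build_opt()
Options for build/1.

kto_loss_opt()
Options for kto_loss/4.

Functions

build(opts \\ [])
Build a KTO policy model (decoder-only language model).

compute_logprobs(logits, targets, mask \\ nil)
Compute per-sequence log-probabilities from logits and target tokens.

default_beta()
Default beta (KL regularisation coefficient) for KTO.

kto_loss(policy_logprobs, ref_logprobs, labels, opts \\ [])
Compute the KTO loss from binary preference labels.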

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:hidden_size, pos_integer()}
  | {:max_seq_len, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:vocab_size, pos_integer()}

Options for build/1.

kto_loss_opt()

@type kto_loss_opt() ::
  {:beta, float()}
  | {:desirable_weight, float()}
  | {:undesirable_weight, float()}

Options for kto_loss/4.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a KTO policy model (decoder-only language model).

Options

  • :hidden_size - Model hidden dimension (default: 512)
  • :num_layers - Number of transformer layers (default: 6)
  • :num_heads - Number of attention heads (default: 8)
  • :vocab_size - Vocabulary size (default: 32000)
  • :dropout - Dropout rate (default: 0.1)
  • :max_seq_len - Maximum sequence length for positional encoding (default: 2048)
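The documented defaults can be resolved with the standard library's `Keyword.validate!/2` (Elixir 1.13+), which fills in missing keys and raises on unknown ones. This is a hypothetical sketch: `KTOBuildOpts` is not part of Edifice, and the check that `:hidden_size` is divisible by `:num_heads` is an assumption based on how multi-head attention usually splits the hidden dimension, not a documented constraint of build/1.

```elixir
defmodule KTOBuildOpts do
  # Hypothetical sketch: resolve build/1 options against the documented defaults.
  @defaults [
    hidden_size: 512,
    num_layers: 6,
    num_heads: 8,
    vocab_size: 32_000,
    dropout: 0.1,
    max_seq_len: 2048
  ]

  def resolve(opts \\ []) do
    # Keyword.validate!/2 applies the defaults and raises ArgumentError on unknown keys.
    opts = Keyword.validate!(opts, @defaults)

    # Assumption: attention splits hidden_size evenly across num_heads.
    if rem(opts[:hidden_size], opts[:num_heads]) != 0 do
      raise ArgumentError, "hidden_size must be divisible by num_heads"
    end

    opts
  end
end

opts = KTOBuildOpts.resolve(hidden_size: 256, num_layers: 4)
IO.inspect(div(opts[:hidden_size], opts[:num_heads]), label: "head_dim")
```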

Returns

An Axon model taking token indices [batch, seq_len] and returning logits [batch, seq_len, vocab_size].

compute_logprobs(logits, targets, mask \\ nil)

@spec compute_logprobs(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t() | nil) ::
  Nx.Tensor.t()

Compute per-sequence log-probabilities from logits and target tokens.

Parameters

  • logits - Model output [batch, seq_len, vocab_size]
  • targets - Target token indices [batch, seq_len]
  • mask - Optional padding mask [batch, seq_len]; 1 = valid token

Returns

Per-sequence log-probability sums [batch].
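A plain-Elixir sketch of what such a function computes, using nested lists in place of Nx tensors: a numerically stable log-softmax over the vocabulary, the log-probability of each target token, and a masked sum over the sequence. `LogProbSketch` is hypothetical. In particular, it assumes `targets[i][t]` is the token scored by `logits[i][t]` (no next-token shift), which the documentation above does not specify.

```elixir
defmodule LogProbSketch do
  # Hypothetical sketch (not the Edifice source), with lists instead of Nx tensors.
  # logits: [batch][seq_len][vocab] floats; targets: [batch][seq_len] token ids;
  # mask: [batch][seq_len] of 1/0, or nil to count every position.
  # Assumption: targets are already aligned with logits (no next-token shift here).

  # Numerically stable log-softmax: x_i - (max + log Σ exp(x_j - max)).
  defp log_softmax(row) do
    m = Enum.max(row)
    lse = m + :math.log(Enum.sum(Enum.map(row, &:math.exp(&1 - m))))
    Enum.map(row, &(&1 - lse))
  end

  def compute_logprobs(logits, targets, mask \\ nil) do
    mask = mask || Enum.map(targets, fn seq -> Enum.map(seq, fn _ -> 1 end) end)

    [logits, targets, mask]
    |> Enum.zip()
    |> Enum.map(fn {seq_logits, seq_targets, seq_mask} ->
      [seq_logits, seq_targets, seq_mask]
      |> Enum.zip()
      # log-prob of the target token at each position, zeroed where mask is 0
      |> Enum.map(fn {row, tok, m} -> m * Enum.at(log_softmax(row), tok) end)
      |> Enum.sum()
    end)
  end
end

# Uniform logits over a 4-token vocabulary, last position masked out:
uniform = [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]
IO.inspect(LogProbSketch.compute_logprobs(uniform, [[1, 2, 3]], [[1, 1, 0]]))
# a one-element list holding 2 · log(1/4)
```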

default_beta()

@spec default_beta() :: float()

Default beta (KL regularisation coefficient) for KTO.

kto_loss(policy_logprobs, ref_logprobs, labels, opts \\ [])

@spec kto_loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), [kto_loss_opt()]) ::
  Nx.Tensor.t()

Compute the KTO loss from binary preference labels.

Parameters

  • policy_logprobs - Log-probs from the policy being trained [batch]
  • ref_logprobs - Log-probs from the frozen reference model [batch]
  • labels - Binary labels [batch]; 1 = desirable, 0 = undesirable

Options

  • :beta - KL regularisation strength (default: 0.1)
  • :desirable_weight - Loss weight λ_D for desirable samples (default: 1.0)
  • :undesirable_weight - Loss weight λ_U for undesirable samples (default: 1.0)

Returns

Scalar loss value.

Formula

KL     = β · (log π − log π_ref)
z_ref  = mean(KL)            # batch estimate of partition function
reward = KL − z_ref

loss_i = −λ_D · σ(reward_i)     if label_i = 1
loss_i = −λ_U · σ(−reward_i)    if label_i = 0
L = mean(loss_i)
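To check the formula by hand, here is a list-based sketch that uses the same option names and defaults as kto_loss/4. `KTOLossSketch` is hypothetical and not the Edifice implementation. It assumes a non-empty batch and labels equal to 1 or 0, and it says nothing about gradient flow (for example, whether the real implementation stops gradients through z_ref, which the documentation does not state).

```elixir
defmodule KTOLossSketch do
  # Hypothetical, list-based sketch of the kto_loss/4 formula (not the Edifice source).

  defp sigmoid(x), do: 1.0 / (1.0 + :math.exp(-x))

  def kto_loss(policy_logprobs, ref_logprobs, labels, opts \\ []) do
    beta = Keyword.get(opts, :beta, 0.1)
    lambda_d = Keyword.get(opts, :desirable_weight, 1.0)
    lambda_u = Keyword.get(opts, :undesirable_weight, 1.0)

    # KL_i = β · (log π_i − log π_ref_i)
    kl = Enum.zip_with(policy_logprobs, ref_logprobs, fn p, r -> beta * (p - r) end)
    # z_ref = mean(KL) over the (non-empty) batch
    z_ref = Enum.sum(kl) / length(kl)

    # loss_i = −λ_D · σ(reward_i) when label is 1, otherwise −λ_U · σ(−reward_i)
    losses =
      Enum.zip_with(kl, labels, fn k, label ->
        reward = k - z_ref
        if label == 1, do: -lambda_d * sigmoid(reward), else: -lambda_u * sigmoid(-reward)
      end)

    Enum.sum(losses) / length(losses)
  end
end

# The policy raised the desirable sample and lowered the undesirable one by the
# same amount, so both terms equal −σ(0.1) and the mean is about −0.525.
IO.inspect(KTOLossSketch.kto_loss([-10.0, -12.0], [-11.0, -11.0], [1, 0], beta: 0.1))
```

Because z_ref is the batch mean of KL, adding the same constant to every policy log-probability leaves every reward_i, and therefore the loss, unchanged; only the relative log-ratios within a batch matter.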