KTO: Kahneman-Tversky Optimization for RLHF from binary feedback.
KTO aligns language models using only binary thumbs-up / thumbs-down labels — no preference pairs required. It is based on Kahneman-Tversky Prospect Theory, modelling how humans perceive gains (desirable outputs) and losses (undesirable outputs) asymmetrically.
Key Innovation Over DPO
DPO needs paired (chosen, rejected) responses to the same prompt. KTO only needs a single response labelled desirable (1) or undesirable (0), which is far easier to collect at scale.
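For illustration, a hypothetical shape for the two kinds of training data (field names are ours, not the library's):

# DPO: both responses answer the SAME prompt and are ranked against each other
dpo_example = %{
  prompt: "Explain KTO",
  chosen: "KTO aligns language models using binary feedback...",
  rejected: "I don't know."
}

# KTO: independent responses, each carrying a standalone thumbs-up / thumbs-down
kto_examples = [
  %{prompt: "Explain KTO", response: "KTO aligns language models...", label: 1},
  %{prompt: "Explain KTO", response: "I don't know.", label: 0}
]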
Loss Formula
For each sample with policy log-probability log π(y|x) and
reference log π_ref(y|x):
KL = β · (log π(y|x) − log π_ref(y|x))
z = E_batch[KL] # reference point (batch KL estimate)
# Prospect-theory value functions
desirable: v = σ(KL − z) → utility of a good response
undesirable: v = σ(z − KL) → disutility of a bad response
# Per-sample loss (minimise negative utility)
L = −λ_D · σ(KL − z)   if label = 1 (desirable)
L = −λ_U · σ(z − KL)   if label = 0 (undesirable)

λ_D and λ_U allow weighting desirable vs. undesirable feedback.
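As a minimal Nx sketch of these equations (illustrative only; the module name and anything not documented on this page are assumptions, and the reference point is taken as a plain batch mean):

defmodule KTOLossSketch do
  import Nx.Defn

  # Illustrative defn mirroring the formula above; not the library's
  # implementation. Inputs: per-sequence log-probs [batch], 0/1 labels [batch].
  defn loss(policy_logprobs, ref_logprobs, labels, opts \\ []) do
    opts = keyword!(opts, beta: 0.1, desirable_weight: 1.0, undesirable_weight: 1.0)

    kl = opts[:beta] * (policy_logprobs - ref_logprobs)  # β · (log π − log π_ref)
    z = Nx.mean(kl)                                      # reference point (batch estimate)
    reward = kl - z

    desirable = -opts[:desirable_weight] * Nx.sigmoid(reward)       # label = 1
    undesirable = -opts[:undesirable_weight] * Nx.sigmoid(-reward)  # label = 0

    Nx.mean(Nx.select(labels == 1, desirable, undesirable))
  end
end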
Architecture
KTO wraps any sequence model; build/1 constructs a decoder-only
transformer for language modelling (identical backbone to DPO).
           Prompt x
              |
      +-------+-------+
      |               |
      v               v
+-----------+   +------------+
|  Policy   |   |    Ref     | (frozen)
|  π(y|x)   |   | π_ref(y|x) |
+-----------+   +------------+
      |               |
   log_pi         log_ref
      |               |
      +-------+-------+
              |
              v
    KTO Loss + binary label

Usage
policy = KTO.build(vocab_size: 32000, hidden_size: 512, num_layers: 6)
loss = KTO.kto_loss(
  policy_logprobs,
  ref_logprobs,
  labels,
  beta: 0.1,
  desirable_weight: 1.0,
  undesirable_weight: 1.0
)

References
- Ethayarajh et al., "KTO: Model Alignment as Prospect Theoretic Optimization" (2024)
- https://arxiv.org/abs/2402.01306
Summary
Functions

build/1
Build a KTO policy model (decoder-only language model).

compute_logprobs/3
Compute per-sequence log-probabilities from logits and target tokens.

default_beta/0
Default beta (KL regularisation coefficient) for KTO.

kto_loss/4
Compute the KTO loss from binary preference labels.
Types
@type build_opt() ::
  {:dropout, float()}
  | {:hidden_size, pos_integer()}
  | {:max_seq_len, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:vocab_size, pos_integer()}
Options for build/1.
@type kto_loss_opt() :: {:beta, float()} | {:desirable_weight, float()} | {:undesirable_weight, float()}
Options for kto_loss/4.
Functions
build/1

Build a KTO policy model (decoder-only language model).
Options
- :hidden_size - Model hidden dimension (default: 512)
- :num_layers - Number of transformer layers (default: 6)
- :num_heads - Number of attention heads (default: 8)
- :vocab_size - Vocabulary size (default: 32000)
- :dropout - Dropout rate (default: 0.1)
- :max_seq_len - Maximum sequence length for positional encoding (default: 2048)
Returns
An Axon model taking token indices [batch, seq_len] and returning
logits [batch, seq_len, vocab_size].
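As a sketch of driving the returned model (the standard Axon build/init/predict flow; shapes are illustrative, and the second argument to init_fn varies slightly across Axon versions):

model = KTO.build(vocab_size: 32000, hidden_size: 512, num_layers: 6)

# Compile the Axon graph into init/predict functions
{init_fn, predict_fn} = Axon.build(model)

# Dummy token indices [batch = 2, seq_len = 16]
tokens = Nx.broadcast(Nx.tensor(0, type: :s64), {2, 16})

# On recent Axon versions pass Axon.ModelState.empty() instead of %{}
params = init_fn.(Nx.template({2, 16}, :s64), %{})

logits = predict_fn.(params, tokens)  # [2, 16, 32000]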
compute_logprobs/3

@spec compute_logprobs(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t() | nil) :: Nx.Tensor.t()
Compute per-sequence log-probabilities from logits and target tokens.
Parameters
- logits - Model output [batch, seq_len, vocab_size]
- targets - Target token indices [batch, seq_len]
- mask - Optional padding mask [batch, seq_len]; 1 = valid token
Returns
Per-sequence log-probability sums [batch].
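A minimal sketch of one way to compute this in Nx (illustrative, not necessarily the library's implementation; it assumes a 0/1 mask tensor is supplied rather than nil):

defmodule LogProbSketch do
  import Nx.Defn

  defn per_sequence_logprobs(logits, targets, mask) do
    # Numerically stable log-softmax over the vocab axis
    max = Nx.reduce_max(logits, axes: [-1], keep_axes: true)
    lse = Nx.log(Nx.sum(Nx.exp(logits - max), axes: [-1], keep_axes: true)) + max
    logp = logits - lse  # [batch, seq_len, vocab_size]

    # Pick out the log-prob of each target token
    token_logp =
      logp
      |> Nx.take_along_axis(Nx.new_axis(targets, -1), axis: -1)
      |> Nx.squeeze(axes: [-1])  # [batch, seq_len]

    # Zero out padding positions, then sum over the sequence
    Nx.sum(token_logp * mask, axes: [-1])  # [batch]
  end
end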
default_beta/0

@spec default_beta() :: float()
Default beta (KL regularisation coefficient) for KTO.
kto_loss/4

@spec kto_loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), [kto_loss_opt()]) :: Nx.Tensor.t()
Compute the KTO loss from binary preference labels.
Parameters
- policy_logprobs - Log-probs from the policy being trained [batch]
- ref_logprobs - Log-probs from the frozen reference model [batch]
- labels - Binary labels [batch]; 1 = desirable, 0 = undesirable
Options
- :beta - KL regularisation strength (default: 0.1)
- :desirable_weight - Loss weight λ_D for desirable samples (default: 1.0)
- :undesirable_weight - Loss weight λ_U for undesirable samples (default: 1.0)
Returns
Scalar loss value.
Formula
KL = β · (log π − log π_ref)
z_ref = mean(KL) # reference point (batch KL estimate)
reward = KL − z_ref
loss_i = −λ_D · σ(reward_i) if label_i = 1
loss_i = −λ_U · σ(−reward_i) if label_i = 0
L = mean(loss_i)
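Putting the pieces together, a hedged end-to-end sketch of one scoring pass (glue code only; tokenisation, parameter setup, and the optimiser step are elided, and the variable names are ours):

# Score the batch with both models; the reference model stays frozen
policy_logits = Axon.predict(policy_model, policy_params, tokens)
ref_logits = Axon.predict(ref_model, ref_params, tokens)

# Reduce to per-sequence log-probs, then apply the KTO loss
policy_logprobs = KTO.compute_logprobs(policy_logits, targets, mask)
ref_logprobs = KTO.compute_logprobs(ref_logits, targets, mask)

loss =
  KTO.kto_loss(policy_logprobs, ref_logprobs, labels,
    beta: KTO.default_beta(),
    desirable_weight: 1.0,
    undesirable_weight: 1.0
  )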