# `Edifice.Meta.KTO`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/meta/kto.ex#L1)

KTO: Kahneman-Tversky Optimization for RLHF from binary feedback.

KTO aligns language models using only binary thumbs-up / thumbs-down
labels — no preference pairs required. It is based on Kahneman-Tversky
Prospect Theory, modelling how humans perceive gains (desirable outputs)
and losses (undesirable outputs) asymmetrically.

## Key Innovation Over DPO

DPO needs *paired* (chosen, rejected) responses to the same prompt.
KTO only needs a single response labelled desirable (1) or undesirable
(0), which is far easier to collect at scale.

## Loss Formula

For each sample with policy log-probability `log π(y|x)` and
reference `log π_ref(y|x)`:

```
KL  = β · (log π(y|x) − log π_ref(y|x))
z   = E_batch[KL]     # partition-function estimate

# Prospect-theory value functions (σ = logistic sigmoid)
desirable:    v = σ(KL − z)   → utility of a good response
undesirable:  v = σ(z − KL)   → disutility of a bad response

# Per-sample loss (minimising the loss maximises utility)
L = −λ_D · σ(KL − z)       if label = 1 (desirable)
L = −λ_U · σ(z − KL)       if label = 0 (undesirable)
```

`λ_D` and `λ_U` weight desirable and undesirable feedback separately,
which is useful when the two classes are imbalanced in the dataset.
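As a sanity check, the loss can be reproduced numerically. The NumPy sketch below is illustrative only (the module itself is Elixir/Nx); the names `kl`, `z`, and `reward` follow the formula above and are not part of the API.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

beta = 0.1
# Per-sequence log-probs for a batch of 4 samples.
log_pi  = np.array([-10.0, -12.0, -9.0, -11.0])   # policy
log_ref = np.array([-11.0, -11.5, -9.5, -10.0])   # frozen reference
labels  = np.array([1, 0, 1, 0])                  # 1 = desirable

kl = beta * (log_pi - log_ref)   # implicit per-sample reward
z = kl.mean()                    # batch partition-function estimate
reward = kl - z

lam_d, lam_u = 1.0, 1.0
# sigmoid(-reward) == sigmoid(z - kl), matching the undesirable branch.
loss = np.where(labels == 1,
                -lam_d * sigmoid(reward),
                -lam_u * sigmoid(-reward))
```

Note that samples whose policy moved toward the reference-relative mean contribute less utility, so the gradient pushes desirable samples above `z` and undesirable samples below it.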

## Architecture

KTO wraps any sequence model; `build/1` constructs a decoder-only
transformer for language modelling (identical backbone to DPO).

```
Prompt x
    |
    +--------+---------+
    |                  |
    v                  v
+----------+     +-----------+
|  Policy  |     |    Ref    |  (frozen)
|  π(y|x)  |     | π_ref(y|x)|
+----------+     +-----------+
    |                  |
log_pi             log_ref
    |                  |
    +--------+---------+
             |
             v
      KTO Loss + binary label
```

## Usage

    policy = KTO.build(vocab_size: 32000, hidden_size: 512, num_layers: 6)

    loss = KTO.kto_loss(
      policy_logprobs,
      ref_logprobs,
      labels,
      beta: 0.1,
      desirable_weight: 1.0,
      undesirable_weight: 1.0
    )

## References

- Ethayarajh et al., "KTO: Model Alignment as Prospect Theoretic Optimization" (2024)
- https://arxiv.org/abs/2402.01306

# `build_opt`

```elixir
@type build_opt() ::
  {:dropout, float()}
  | {:hidden_size, pos_integer()}
  | {:max_seq_len, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:vocab_size, pos_integer()}
```

Options for `build/1`.

# `kto_loss_opt`

```elixir
@type kto_loss_opt() ::
  {:beta, float()}
  | {:desirable_weight, float()}
  | {:undesirable_weight, float()}
```

Options for `kto_loss/4`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a KTO policy model (decoder-only language model).

## Options

  - `:hidden_size` - Model hidden dimension (default: 512)
  - `:num_layers` - Number of transformer layers (default: 6)
  - `:num_heads` - Number of attention heads (default: 8)
  - `:vocab_size` - Vocabulary size (default: 32000)
  - `:dropout` - Dropout rate (default: 0.1)
  - `:max_seq_len` - Maximum sequence length for positional encoding (default: 2048)

## Returns

  An Axon model taking token indices `[batch, seq_len]` and returning
  logits `[batch, seq_len, vocab_size]`.

# `compute_logprobs`

```elixir
@spec compute_logprobs(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t() | nil) ::
  Nx.Tensor.t()
```

Compute per-sequence log-probabilities from logits and target tokens.

## Parameters

  - `logits` - Model output `[batch, seq_len, vocab_size]`
  - `targets` - Target token indices `[batch, seq_len]`
  - `mask` - Optional padding mask `[batch, seq_len]`; 1 = valid token

## Returns

  Per-sequence log-probability sums `[batch]`.
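The computation can be sketched in NumPy (a hypothetical standalone helper for illustration, not the module's Nx implementation): take a log-softmax over the vocabulary, gather each target token's log-prob, zero out padding, and sum over the sequence.

```python
import numpy as np

def compute_logprobs(logits, targets, mask=None):
    # logits: [batch, seq_len, vocab], targets: [batch, seq_len]
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Gather the log-prob of each target token via integer indexing.
    batch_idx = np.arange(logits.shape[0])[:, None]
    pos_idx = np.arange(logits.shape[1])[None, :]
    token_lp = log_probs[batch_idx, pos_idx, targets]   # [batch, seq_len]
    if mask is not None:
        token_lp = token_lp * mask                      # drop padding positions
    return token_lp.sum(axis=-1)                        # [batch]

# Uniform logits over a vocab of 5: every token has log-prob -log(5).
logits = np.zeros((2, 3, 5))
targets = np.array([[0, 1, 2], [3, 4, 0]])
mask = np.array([[1, 1, 0], [1, 1, 1]])
out = compute_logprobs(logits, targets, mask)
```

With uniform logits, each valid position contributes `-log(5)`, so the masked sums are `-2·log 5` and `-3·log 5` respectively.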

# `default_beta`

```elixir
@spec default_beta() :: float()
```

Default beta (KL regularisation coefficient) for KTO.

# `kto_loss`

```elixir
@spec kto_loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), [kto_loss_opt()]) ::
  Nx.Tensor.t()
```

Compute the KTO loss from binary preference labels.

## Parameters

  - `policy_logprobs` - Log-probs from the policy being trained `[batch]`
  - `ref_logprobs` - Log-probs from the frozen reference model `[batch]`
  - `labels` - Binary labels `[batch]`; 1 = desirable, 0 = undesirable

## Options

  - `:beta` - KL regularisation strength (default: 0.1)
  - `:desirable_weight` - Loss weight λ_D for desirable samples (default: 1.0)
  - `:undesirable_weight` - Loss weight λ_U for undesirable samples (default: 1.0)

## Returns

  Scalar loss value.

## Formula

```
KL     = β · (log π − log π_ref)
z_ref  = mean(KL)            # batch estimate of partition function
reward = KL − z_ref

loss_i = −λ_D · σ(reward_i)     if label_i = 1
loss_i = −λ_U · σ(−reward_i)    if label_i = 0
L = mean(loss_i)
```
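The formula can be mirrored end-to-end as a NumPy sketch (an illustrative standalone re-implementation, not the module's code; only the option names are borrowed from `kto_loss/4`). It makes the role of the weights concrete: `undesirable_weight` scales only the label-0 terms.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def kto_loss(policy_logprobs, ref_logprobs, labels,
             beta=0.1, desirable_weight=1.0, undesirable_weight=1.0):
    kl = beta * (policy_logprobs - ref_logprobs)
    reward = kl - kl.mean()                    # z_ref = mean(KL)
    per_sample = np.where(labels == 1,
                          -desirable_weight * sigmoid(reward),
                          -undesirable_weight * sigmoid(-reward))
    return per_sample.mean()

lp  = np.array([-10.0, -12.0])   # policy log-probs
ref = np.array([-11.0, -11.0])   # reference log-probs
lab = np.array([1, 0])           # desirable, undesirable
base  = kto_loss(lp, ref, lab)
heavy = kto_loss(lp, ref, lab, undesirable_weight=2.0)
# heavy differs from base only through the label-0 term
```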

---

*Consult [api-reference.md](api-reference.md) for complete listing*
