ExFairness.Mitigation.Reweighting (ExFairness v0.5.1)
Sample reweighting for fairness-aware machine learning.
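Reweighting is a pre-processing technique. Instead of changing features or labels, it assigns a weight to each training sample so that the weighted data satisfies a fairness criterion. Samples whose (group, label) combination is underrepresented receive higher weights, and overrepresented combinations receive lower ones.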
How It Works
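For demographic parity, the weight for a sample with sensitive attribute value a and label y is:
w(a, y) = P(Y = y) / P(A = a, Y = y)
Under this weighting, the total weight of each (a, y) cell is proportional to P(Y = y) and does not depend on the group a. As a result, the weighted share of label y is P(Y = y) in every group, which is the condition demographic parity targets.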
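The formula can be checked with a minimal sketch in plain Elixir. It uses lists rather than Nx tensors, and the module name ReweightingSketch is hypothetical. It is not the ExFairness implementation, which may differ in validation and edge-case handling. The sketch estimates both probabilities from counts and then rescales the weights to mean 1.0, as described under Returns.

```elixir
defmodule ReweightingSketch do
  @moduledoc """
  Hypothetical sketch (not the ExFairness implementation) of
  demographic-parity reweighting on plain lists:
  w(a, y) = P(Y = y) / P(A = a, Y = y), rescaled to mean 1.0.
  """

  @spec weights([0 | 1], [0 | 1]) :: [float()]
  def weights(labels, groups) when length(labels) == length(groups) and labels != [] do
    n = length(labels)
    pairs = Enum.zip(groups, labels)

    # Empirical probabilities from counts: P(Y = y) and P(A = a, Y = y).
    p_label = labels |> Enum.frequencies() |> Map.new(fn {y, c} -> {y, c / n} end)
    p_joint = pairs |> Enum.frequencies() |> Map.new(fn {cell, c} -> {cell, c / n} end)

    # Every sample gets the raw weight of its (a, y) cell.
    raw = Enum.map(pairs, fn {a, y} -> p_label[y] / p_joint[{a, y}] end)

    # Rescale so the weights average to 1.0 (up to floating-point rounding).
    mean = Enum.sum(raw) / n
    Enum.map(raw, &(&1 / mean))
  end
end
```

Take the labels and sensitive attribute from the module's Examples section. Group 0 has 4 positives out of 10, and group 1 has 2 out of 10. This gives P(Y = 1) = 0.3, and the four cells get weights of about 0.75 (a = 0, y = 1), 1.17 (a = 0, y = 0), 1.5 (a = 1, y = 1) and 0.875 (a = 1, y = 0). The weighted share of positives is then 0.3 in both groups, compared with 0.4 and 0.2 unweighted.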
For equalized odds, weights are computed to balance both positive and negative outcomes across groups.
Usage
Compute the weights once during data preparation, then pass them to your training algorithm's sample_weight parameter, or to whatever per-sample weighting option it provides.
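The weights take effect only through the training objective. The following is a hypothetical illustration in plain Elixir, not an ExFairness or Axon API. A weighted binary cross-entropy scales each sample's loss term by its weight and divides by the total weight. With all weights equal to 1.0, it reduces to the ordinary mean log loss.

```elixir
defmodule WeightedLossSketch do
  @moduledoc """
  Hypothetical example of consuming sample weights: a weighted binary
  cross-entropy where each sample's loss term is scaled by its weight.
  """

  # Small constant used to keep probabilities strictly inside (0, 1).
  @eps 1.0e-12

  @spec weighted_log_loss([float()], [0 | 1], [float()]) :: float()
  def weighted_log_loss(probs, labels, weights) do
    terms =
      [probs, labels, weights]
      |> Enum.zip()
      |> Enum.map(fn {p, y, w} ->
        # Clamp p so :math.log never receives 0.
        p = p |> max(@eps) |> min(1.0 - @eps)
        w * -(y * :math.log(p) + (1 - y) * :math.log(1.0 - p))
      end)

    # Divide by the total weight: a weighted mean of per-sample losses.
    Enum.sum(terms) / Enum.sum(weights)
  end
end
```

The same idea applies when training on Nx tensors: multiply the per-sample losses by the weight tensor before reducing.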
References
- Kamiran, F., & Calders, T. (2012). "Data preprocessing techniques for classification without discrimination." Knowledge and Information Systems, 33(1), 1–33.
Examples
iex> labels = Nx.tensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0])
iex> sensitive = Nx.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
iex> weights = ExFairness.Mitigation.Reweighting.compute_weights(labels, sensitive)
iex> Nx.size(weights)
20
Functions
@spec compute_weights(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Computes sample weights for fairness-aware training.
Parameters
- labels - Binary labels tensor (0 or 1)
- sensitive_attr - Binary sensitive attribute tensor (0 or 1)
- opts - Options:
  - :target - Target fairness metric (:demographic_parity or :equalized_odds, default: :demographic_parity)
  - :min_per_group - Minimum samples per group for validation (default: 10)
Returns
A tensor of sample weights (same shape as labels). Weights are normalized to have mean 1.0.
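Only the relative weights matter, so normalizing to mean 1.0 changes no weighted rate. It also keeps the total weight equal to the number of samples, so a sum-reduced loss stays on the same scale as in unweighted training. The hypothetical helpers below work on plain Elixir lists and check both properties. A returned Nx tensor can be converted with Nx.to_flat_list/1 first. WeightCheckSketch and its functions are illustrative names, not part of ExFairness.

```elixir
defmodule WeightCheckSketch do
  @moduledoc """
  Hypothetical helpers (not part of ExFairness) for checking a weight
  vector: its mean, and the weighted rate of label 1 within each group.
  """

  @spec mean([number()]) :: float()
  def mean(weights), do: Enum.sum(weights) / length(weights)

  @spec weighted_positive_rate([0 | 1], [term()], [number()]) :: %{term() => float()}
  def weighted_positive_rate(labels, groups, weights) do
    [labels, groups, weights]
    |> Enum.zip()
    # Bucket samples by their sensitive-attribute value.
    |> Enum.group_by(fn {_y, a, _w} -> a end)
    |> Map.new(fn {a, rows} ->
      # Weighted positives divided by total weight within the group.
      total = rows |> Enum.map(fn {_y, _a, w} -> w end) |> Enum.sum()
      positive = rows |> Enum.map(fn {y, _a, w} -> y * w end) |> Enum.sum()
      {a, positive / total}
    end)
  end
end
```

Multiplying every weight by the same constant leaves weighted_positive_rate/3 unchanged. This is why the rescaling to mean 1.0 does not affect the fairness property the weights were built for.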
Examples
iex> labels = Nx.tensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0])
iex> sensitive = Nx.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
iex> weights = ExFairness.Mitigation.Reweighting.compute_weights(labels, sensitive)
iex> mean = Nx.mean(weights) |> Nx.to_number()
iex> Float.round(mean, 2)
1.0