Edifice.Probabilistic.MCDropout (Edifice v0.2.0)


MC Dropout for uncertainty estimation (Gal & Ghahramani, 2016).

Standard dropout is disabled at inference time. MC Dropout keeps dropout active during inference and runs multiple forward passes to estimate prediction uncertainty. This provides a practical Bayesian approximation without modifying the training procedure.

How It Works

  1. Train a standard network with dropout (nothing special)
  2. At inference time, keep dropout ON
  3. Run N forward passes with different dropout masks
  4. Mean of outputs = prediction, Variance = uncertainty
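The four steps above can be sketched in plain Elixir with a toy one-hidden-layer network. This is a hedged illustration, not Edifice's implementation. It uses lists instead of Nx tensors, and several details are assumptions made for the example: the hypothetical `MCDropoutSketch` module, its functions, the bias-free weights, and the inverted-dropout scaling.

```elixir
defmodule MCDropoutSketch do
  # Inverted dropout (assumed): keep each unit with probability 1 - rate and rescale kept
  # units by 1 / (1 - rate), so the expected activation matches the no-dropout value.
  def dropout(values, rate) do
    keep = 1.0 - rate
    Enum.map(values, fn v -> if :rand.uniform() < keep, do: v / keep, else: 0.0 end)
  end

  # One stochastic forward pass of a toy network (no biases):
  # w1 holds one weight vector per hidden unit, w2 holds one output weight per hidden unit.
  def forward(x, w1, w2, rate) do
    hidden =
      w1
      |> Enum.map(fn w -> max(dot(x, w), 0.0) end)
      |> dropout(rate)

    dot(hidden, w2)
  end

  # Steps 2-4: N passes, each with a fresh dropout mask, then mean and population variance.
  def predict(x, w1, w2, rate, n) do
    samples = for _ <- 1..n, do: forward(x, w1, w2, rate)
    mean = Enum.sum(samples) / n
    variance = Enum.sum(Enum.map(samples, fn s -> (s - mean) * (s - mean) end)) / n
    {mean, variance}
  end

  defp dot(a, b), do: a |> Enum.zip_with(b, &(&1 * &2)) |> Enum.sum()
end

:rand.seed(:exsss, {1, 2, 3})
w1 = [[0.5, -0.2], [0.3, 0.8], [-0.4, 0.1]]
w2 = [1.0, -0.5, 0.7]
{mean, variance} = MCDropoutSketch.predict([1.0, 2.0], w1, w2, 0.2, 30)
```

With `rate` set to 0.0 every pass is identical and the variance collapses to zero. This is the standard-inference behaviour that MC Dropout deliberately avoids.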

Interpretation

  • Low variance: Model is confident (consistent predictions across dropout masks)
  • High variance: Model is uncertain (different subnetworks disagree)
  • Out-of-distribution inputs: Typically produce high variance

Architecture

Input [batch, input_size]
      |
      v
+-------------------------------+
| Dense + ReLU + Dropout (ON)   |  Layer 1
+-------------------------------+
      |
      v
+-------------------------------+
| Dense + ReLU + Dropout (ON)   |  Layer 2
+-------------------------------+
      |
      v
Output [batch, output_size]

(Run N times, compute mean + variance)

Usage

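alias Edifice.Probabilistic.MCDropout

# Build model with always-on dropout
model = MCDropout.build(
  input_size: 256,
  hidden_sizes: [128, 64],
  output_size: 10,
  dropout_rate: 0.2
)

# params: the parameters obtained by training `model` as usual
# input: an Nx tensor of shape [batch, 256]

# Get predictions with uncertainty
{mean, variance} = MCDropout.predict_with_uncertainty(
  model, params, input, num_samples: 30
)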

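References

  • Gal, Y. & Ghahramani, Z. (2016). Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning. ICML 2016.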

Summary

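Types

  • build_opt() - Options for build/1.

Functions

  • build(opts \\ []) - Build an MLP with dropout that stays active at inference time.
  • build_mc_layer(input, units, opts \\ []) - Build a Dense + always-on Dropout layer.
  • predict_with_uncertainty(model, params, input, opts \\ []) - Run N forward passes with dropout and return mean prediction + variance.
  • predictive_entropy(mean_probs) - Compute predictive entropy from MC Dropout samples.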

Types

build_opt()

@type build_opt() ::
  {:activation, atom()}
  | {:dropout_rate, float()}
  | {:hidden_sizes, [pos_integer()]}
  | {:input_size, pos_integer()}
  | {:output_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build an MLP with dropout that stays active at inference time.

Uses Axon's training mode to keep dropout active. The key difference from a standard MLP is that dropout is applied at every layer and is intended to remain active during inference for uncertainty estimation.

Options

  • :input_size - Input feature dimension (required)
  • :hidden_sizes - List of hidden layer sizes (default: [256, 128])
  • :output_size - Output dimension (required)
  • :dropout_rate - Dropout probability (default: 0.2)
  • :activation - Activation function (default: :relu)

Returns

An Axon model: [batch, input_size] -> [batch, output_size]

build_mc_layer(input, units, opts \\ [])

@spec build_mc_layer(Axon.t(), pos_integer(), keyword()) :: Axon.t()

Build a Dense + always-on Dropout layer.

This is the building block for MC Dropout networks. Dropout is applied after the activation. It is not switched on in inference mode. Instead, it stays active because the model is run in training mode during prediction.

Parameters

  • input - Axon node
  • units - Number of output units

Options

  • :dropout_rate - Dropout probability (default: 0.2)
  • :activation - Activation function (default: :relu)
  • :name - Layer name prefix (default: "mc_layer")

Returns

An Axon node with shape [batch, units]
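The sketch below shows, for a single layer, how standard inference differs from MC inference. It is an illustration under stated assumptions. `MCLayerSketch` and its `:mode` option are hypothetical, the activation is fixed to ReLU, dropout uses inverted scaling, and plain lists stand in for Axon nodes and Nx tensors.

```elixir
defmodule MCLayerSketch do
  # Dense + ReLU + dropout for a single example; plain lists stand in for tensors.
  # weights: one weight vector per output unit; bias: one value per output unit.
  def forward(x, weights, bias, opts \\ []) do
    rate = Keyword.get(opts, :dropout_rate, 0.2)
    # :train samples a dropout mask (what MC Dropout relies on at prediction time);
    # :inference makes dropout the identity, as in standard inference.
    mode = Keyword.get(opts, :mode, :train)

    weights
    |> Enum.zip_with(bias, fn w, b -> dot(x, w) + b end)
    |> Enum.map(fn v -> max(v, 0.0) end)
    |> dropout(rate, mode)
  end

  defp dropout(values, _rate, :inference), do: values

  defp dropout(values, rate, :train) do
    keep = 1.0 - rate
    # Inverted dropout: zero a unit with probability `rate`, rescale survivors by 1 / keep
    Enum.map(values, fn v -> if :rand.uniform() < keep, do: v / keep, else: 0.0 end)
  end

  defp dot(a, b), do: a |> Enum.zip_with(b, &(&1 * &2)) |> Enum.sum()
end

x = [1.0, 2.0]
weights = [[1.0, 1.0], [-1.0, 0.5]]
bias = [0.0, 0.5]

MCLayerSketch.forward(x, weights, bias, mode: :inference)
# → [3.0, 0.5] on every call

MCLayerSketch.forward(x, weights, bias, mode: :train, dropout_rate: 0.5)
# random: each unit is either dropped (0.0) or doubled
```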

predict_with_uncertainty(model, params, input, opts \\ [])

@spec predict_with_uncertainty(Axon.t(), map(), Nx.Tensor.t(), keyword()) ::
  {Nx.Tensor.t(), Nx.Tensor.t()}

Run N forward passes with dropout and return mean prediction + variance.

This is the core MC Dropout inference procedure. By running the model multiple times in training mode (dropout active), each pass uses a different random dropout mask, producing different outputs. The variance across these outputs quantifies the model's uncertainty.

Parameters

  • model - Axon model built with build/1
  • params - Trained model parameters
  • input - Input tensor [batch, input_size]

Options

  • :num_samples - Number of forward passes (default: 20)

Returns

Tuple of {mean, variance} where:

  • mean - Average prediction [batch, output_size]
  • variance - Prediction variance [batch, output_size]
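The reduction that turns N sampled outputs into `{mean, variance}` can be sketched in plain Elixir over nested lists shaped [num_samples][batch][output_size]. The helper `MCStats.mean_and_variance/1` is hypothetical. Using the population variance (dividing by N) is also an assumption, since the normalization predict_with_uncertainty/4 uses is not stated above.

```elixir
defmodule MCStats do
  # samples: N nested lists, each shaped [batch][output_size] (one per forward pass).
  # Returns {mean, variance}, each shaped [batch][output_size], reduced over the N passes.
  def mean_and_variance(samples) do
    n = length(samples)

    # per_element[b][j] = {mean, variance} of the N values at batch row b, output j
    per_element =
      Enum.zip_with(samples, fn rows ->
        Enum.zip_with(rows, fn values -> moments(values, n) end)
      end)

    mean = Enum.map(per_element, fn row -> Enum.map(row, &elem(&1, 0)) end)
    variance = Enum.map(per_element, fn row -> Enum.map(row, &elem(&1, 1)) end)
    {mean, variance}
  end

  # Population variance: divides by N (an unbiased estimate would divide by N - 1)
  defp moments(values, n) do
    mean = Enum.sum(values) / n
    variance = Enum.sum(Enum.map(values, fn v -> (v - mean) * (v - mean) end)) / n
    {mean, variance}
  end
end

# Two passes over a batch of one example with two outputs
MCStats.mean_and_variance([[[1.0, 0.0]], [[3.0, 0.0]]])
# → {[[2.0, 0.0]], [[1.0, 0.0]]}
```

The second output agrees across both passes, so its variance is zero. The first output disagrees, so its variance is positive.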

predictive_entropy(mean_probs)

@spec predictive_entropy(Nx.Tensor.t()) :: Nx.Tensor.t()

Compute predictive entropy from MC Dropout samples.

For classification tasks, entropy provides an alternative uncertainty measure that captures total uncertainty (both aleatoric and epistemic).

H[y|x] = -sum_c p_c * log(p_c), where p_c is the mean softmax probability of class c across the N forward passes

Parameters

  • mean_probs - Mean predicted probabilities [batch, num_classes]

Returns

Entropy values [batch]
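The entropy formula above can be checked with a small plain-Elixir sketch. The `EntropySketch` module and its functions are hypothetical and handle one example at a time. Applying `entropy/1` to each row would give the [batch] result. The natural logarithm is an assumption, since the base is not specified above. `mean_probs/1` averages per-pass softmax outputs into the kind of `mean_probs` input that predictive_entropy/1 expects.

```elixir
defmodule EntropySketch do
  # Numerically stable softmax for one row of logits (subtract the max before exp)
  def softmax(logits) do
    m = Enum.max(logits)
    exps = Enum.map(logits, fn l -> :math.exp(l - m) end)
    total = Enum.sum(exps)
    Enum.map(exps, fn e -> e / total end)
  end

  # Average the class probabilities of N passes for one example: [N][num_classes] -> [num_classes]
  def mean_probs(per_pass_probs) do
    n = length(per_pass_probs)
    Enum.zip_with(per_pass_probs, fn ps -> Enum.sum(ps) / n end)
  end

  # H = -sum_c p_c * log(p_c), natural log; p_c == 0 contributes 0 (the limit of p log p)
  def entropy(probs) do
    probs
    |> Enum.map(fn p -> if p > 0.0, do: -p * :math.log(p), else: 0.0 end)
    |> Enum.sum()
  end
end

# Two passes that confidently disagree
probs = EntropySketch.mean_probs([[1.0, 0.0], [0.0, 1.0]])
# → [0.5, 0.5]
EntropySketch.entropy(probs)
# ≈ 0.693 (ln 2), the maximum for two classes
```

Each pass alone has zero entropy, but their average has the maximum entropy for two classes. This is how disagreement between dropout masks raises the predictive entropy.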