MC Dropout for uncertainty estimation (Gal & Ghahramani, 2016).
Standard dropout is disabled at inference time. MC Dropout keeps dropout active during inference and runs multiple forward passes to estimate prediction uncertainty. This provides a practical Bayesian approximation without modifying the training procedure.
How It Works
- Train a standard network with dropout (nothing special)
- At inference time, keep dropout ON
- Run N forward passes with different dropout masks
- Mean of outputs = prediction, Variance = uncertainty
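For example, if three passes on one input yield probabilities 0.7, 0.8, and 0.9 for some class, the prediction is their mean (0.8) and the uncertainty is their variance (about 0.0067).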
Interpretation
- Low variance: Model is confident (consistent predictions across dropout masks)
- High variance: Model is uncertain (different subnetworks disagree)
- Out-of-distribution: Typically shows high variance
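As an illustration, one might flag likely out-of-distribution inputs by thresholding the variance returned by predict_with_uncertainty/4 (the cutoff 0.05 below is arbitrary, not from the source):
{_mean, variance} = MCDropout.predict_with_uncertainty(
  model, params, input, num_samples: 30
)

# Average the per-class variances, then compare against an arbitrary cutoff
flagged = variance |> Nx.mean(axes: [-1]) |> Nx.greater(0.05)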
Architecture
Input [batch, input_size]
|
v
+-------------------------------+
| Dense + ReLU + Dropout (ON) | Layer 1
+-------------------------------+
|
v
+-------------------------------+
| Dense + ReLU + Dropout (ON) | Layer 2
+-------------------------------+
|
v
Output [batch, output_size]
(Run N times, compute mean + variance)
Usage
# Build model with always-on dropout
model = MCDropout.build(
  input_size: 256,
  hidden_sizes: [128, 64],
  output_size: 10,
  dropout_rate: 0.2
)

# Get predictions with uncertainty
{mean, variance} = MCDropout.predict_with_uncertainty(
  model, params, input, num_samples: 30
)
References
- Gal & Ghahramani, "Dropout as a Bayesian Approximation" (2016)
- https://arxiv.org/abs/1506.02142
Summary
Functions
build(opts)
Build an MLP with dropout that stays active at inference time.
build_mc_layer(input, units, opts \\ [])
Build a Dense + always-on Dropout layer.
predict_with_uncertainty(model, params, input, opts \\ [])
Run N forward passes with dropout and return mean prediction + variance.
predictive_entropy(mean_probs)
Compute predictive entropy from MC Dropout samples.
Types
@type build_opt() ::
  {:activation, atom()}
  | {:dropout_rate, float()}
  | {:hidden_sizes, [pos_integer()]}
  | {:input_size, pos_integer()}
  | {:output_size, pos_integer()}
Options for build/1.
Functions
Build an MLP with dropout that stays active at inference time.
Uses Axon's training mode to keep dropout active. The key difference from a standard MLP is that dropout is applied at every layer and is intended to remain active during inference for uncertainty estimation.
Options
- :input_size - Input feature dimension (required)
- :hidden_sizes - List of hidden layer sizes (default: [256, 128])
- :output_size - Output dimension (required)
- :dropout_rate - Dropout probability (default: 0.2)
- :activation - Activation function (default: :relu)
Returns
An Axon model: [batch, input_size] -> [batch, output_size]
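As a rough sketch of how build/1 could be assembled from Axon's combinators (the Keyword handling and the "output" layer name here are illustrative, and build_mc_layer/3 is documented below):
def build(opts) do
  input_size = Keyword.fetch!(opts, :input_size)
  hidden_sizes = Keyword.get(opts, :hidden_sizes, [256, 128])
  output_size = Keyword.fetch!(opts, :output_size)

  # Stack one MC layer per hidden size, then a plain dense output head
  input = Axon.input("input", shape: {nil, input_size})

  hidden_sizes
  |> Enum.reduce(input, fn units, acc -> build_mc_layer(acc, units, opts) end)
  |> Axon.dense(output_size, name: "output")
end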
@spec build_mc_layer(Axon.t(), pos_integer(), keyword()) :: Axon.t()
Build a Dense + always-on Dropout layer.
This is the building block for MC Dropout networks. Dropout is applied after activation and is kept active even in inference mode by running the model in training mode during prediction.
Parameters
- input - Axon node
- units - Number of output units
Options
- :dropout_rate - Dropout probability (default: 0.2)
- :activation - Activation function (default: :relu)
- :name - Layer name prefix (default: "mc_layer")
Returns
An Axon node with shape [batch, units]
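A minimal sketch of such a layer using standard Axon combinators (the dropout itself is ordinary; it only stays "always on" because prediction later runs in training mode):
def build_mc_layer(input, units, opts \\ []) do
  rate = Keyword.get(opts, :dropout_rate, 0.2)
  activation = Keyword.get(opts, :activation, :relu)
  name = Keyword.get(opts, :name, "mc_layer")

  # Dense -> activation -> dropout; the mask is resampled on each pass
  input
  |> Axon.dense(units, name: "#{name}_dense")
  |> Axon.activation(activation, name: "#{name}_#{activation}")
  |> Axon.dropout(rate: rate, name: "#{name}_dropout")
end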
@spec predict_with_uncertainty(Axon.t(), map(), Nx.Tensor.t(), keyword()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Run N forward passes with dropout and return mean prediction + variance.
This is the core MC Dropout inference procedure. By running the model multiple times in training mode (dropout active), each pass uses a different random dropout mask, producing different outputs. The variance across these outputs quantifies the model's uncertainty.
Parameters
- model - Axon model built with build/1
- params - Trained model parameters
- input - Input tensor [batch, input_size]
Options
:num_samples- Number of forward passes (default: 20)
Returns
Tuple of {mean, variance} where:
- mean - Average prediction [batch, output_size]
- variance - Prediction variance [batch, output_size]
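A minimal sketch of the MC loop, assuming an Axon version where building with mode: :train returns a predict function yielding %{prediction: ...}, and where dropout masks are re-randomized on each call (in some Axon versions the dropout PRNG key lives in model state and must be threaded through explicitly):
def predict_with_uncertainty(model, params, input, opts \\ []) do
  num_samples = Keyword.get(opts, :num_samples, 20)

  # mode: :train keeps dropout active at prediction time
  {_init_fn, predict_fn} = Axon.build(model, mode: :train)

  # Each pass draws a different dropout mask, so outputs differ
  samples =
    1..num_samples
    |> Enum.map(fn _ ->
      %{prediction: pred} = predict_fn.(params, input)
      pred
    end)
    |> Nx.stack()

  # samples has shape [num_samples, batch, output_size]; reduce over axis 0
  {Nx.mean(samples, axes: [0]), Nx.variance(samples, axes: [0])}
end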
@spec predictive_entropy(Nx.Tensor.t()) :: Nx.Tensor.t()
Compute predictive entropy from MC Dropout samples.
For classification tasks, entropy provides an alternative uncertainty measure that captures total uncertainty (both aleatoric and epistemic).
H[y|x] = -sum(p * log(p)) where p = mean softmax probabilities
Parameters
mean_probs- Mean predicted probabilities[batch, num_classes]
Returns
Entropy values [batch]
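A minimal Nx sketch of this computation (the epsilon guard against log(0) is illustrative):
def predictive_entropy(mean_probs) do
  eps = 1.0e-10

  # H = -sum(p * log(p)) over the class axis
  mean_probs
  |> Nx.multiply(Nx.log(Nx.add(mean_probs, eps)))
  |> Nx.sum(axes: [-1])
  |> Nx.negate()
end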