Bayesian Neural Network layers with weight uncertainty.
Instead of point-estimate weights, each weight is a probability distribution parameterized by (mu, rho) where the actual weight is sampled as:
W = mu + softplus(rho) * epsilon, epsilon ~ N(0, 1)
This provides:
- Uncertainty estimation - Variance in predictions indicates confidence
- Regularization - KL divergence from prior acts as a learned regularizer
- Robustness - Multiple weight samples reduce overconfidence
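As a rough, hypothetical illustration (not the module's internal code), the sampling step can be written with plain Nx; the shapes and constants below are arbitrary:
# Hypothetical sketch of the reparameterization trick with plain Nx.
key = Nx.Random.key(42)
mu = Nx.broadcast(0.0, {256, 128})     # weight means
rho = Nx.broadcast(-3.0, {256, 128})   # pre-softplus std-dev parameters
{epsilon, _key} = Nx.Random.normal(key, 0.0, 1.0, shape: {256, 128})
sigma = Nx.log(Nx.add(1.0, Nx.exp(rho)))      # softplus(rho)
w = Nx.add(mu, Nx.multiply(sigma, epsilon))   # W = mu + sigma * eps
Because W is a deterministic function of (mu, rho) given epsilon, gradients flow to the distribution parameters while the randomness stays isolated in epsilon.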
Training with ELBO
The network is trained by maximizing the Evidence Lower Bound (ELBO):
ELBO = E_q[log p(D|W)] - beta * KL(q(W) || p(W))
where:
- q(W) = N(mu, softplus(rho)^2) is the learned weight posterior
- p(W) = N(0, 1) is the prior (standard normal)
- beta controls the regularization strength (typically 1/num_batches)
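For intuition, here is a minimal, hypothetical helper (not the library's elbo_loss/4) showing the quantity actually minimized, i.e. the negative ELBO with an MSE reconstruction term:
defmodule ElboSketch do
  # Hypothetical sketch: negative ELBO = reconstruction loss + beta * KL.
  def negative_elbo(predictions, targets, kl, beta \\ 1.0) do
    recon =
      predictions
      |> Nx.subtract(targets)
      |> Nx.pow(2)
      |> Nx.mean()

    Nx.add(recon, Nx.multiply(beta, kl))
  end
end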
Architecture
Input [batch, input_size]
|
v
+-------------------------------+
| Bayesian Dense (sample W) |
| W = mu + softplus(rho) * eps |
+-------------------------------+
|
v
+-------------------------------+
| Activation (ReLU) |
+-------------------------------+
| (repeat for each layer)
v
Output [batch, output_size]
Usage
# Build a Bayesian neural network
model = Bayesian.build(
input_size: 256,
hidden_sizes: [128, 64],
output_size: 10
)
# Training: use ELBO loss
loss = Bayesian.elbo_loss(predictions, targets, kl_cost, beta: 1/num_batches)
References
- Blundell et al., "Weight Uncertainty in Neural Networks" (2015)
- https://arxiv.org/abs/1505.05424
Summary
Functions
bayesian_dense/3 - Build a Bayesian dense layer using the reparameterization trick.
build/1 - Build a Bayesian Neural Network.
elbo_loss/4 - Compute the Evidence Lower Bound (ELBO) loss.
kl_cost/3 - Compute KL divergence between weight posterior q(W) and prior p(W).
Types
@type build_opt() ::
  {:activation, atom()}
  | {:hidden_sizes, [pos_integer()]}
  | {:input_size, pos_integer()}
  | {:output_size, pos_integer()}
Options for build/1.
Functions
@spec bayesian_dense(Axon.t(), pos_integer(), keyword()) :: Axon.t()
Build a Bayesian dense layer using the reparameterization trick.
Instead of learning fixed weights W, learns mu and rho parameters where sigma = softplus(rho) = log(1 + exp(rho)).
During the forward pass:
- Sample epsilon ~ N(0, 1)
- Compute W = mu + sigma * epsilon
- Output = input * W + bias
Parameters
input - Axon input node
units - Number of output units
Options
:name - Layer name prefix (default: "bayesian_dense")
Returns
An Axon node with shape [batch, units]
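A minimal usage sketch, assuming the module is aliased as Bayesian; the sizes and layer names are illustrative:
model =
  Axon.input("features", shape: {nil, 256})
  |> Bayesian.bayesian_dense(128, name: "bdense_1")
  |> Axon.relu()
  |> Bayesian.bayesian_dense(10, name: "bdense_out")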
Build a Bayesian Neural Network.
Each dense layer uses weight distributions instead of point estimates. During the forward pass, weights are sampled from the learned posterior.
Options
:input_size - Input feature dimension (required)
:hidden_sizes - List of hidden layer sizes (default: [256, 128])
:output_size - Output dimension (required)
:activation - Activation function (default: :relu)
:prior_sigma - Standard deviation of the weight prior (default: 1.0)
Returns
An Axon model: [batch, input_size] -> [batch, output_size]
Note
The model uses the reparameterization trick internally. During training, different samples produce different outputs (stochastic). For deterministic inference, use the mean weights (mu) directly by setting epsilon to zero in the params.
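Since each forward pass draws fresh weights, predictive uncertainty can be estimated by running inference several times and inspecting the spread of the outputs. A rough sketch, assuming initialized params and an input batch (the sample count of 30 is arbitrary):
# Hypothetical Monte Carlo prediction: repeated stochastic forward passes.
samples = for _ <- 1..30, do: Axon.predict(model, params, input)
stacked = Nx.stack(samples)                              # [30, batch, output_size]
mean_prediction = Nx.mean(stacked, axes: [0])            # point estimate
uncertainty = Nx.standard_deviation(stacked, axes: [0])  # per-output spread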
@spec elbo_loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Compute the Evidence Lower Bound (ELBO) loss.
loss = reconstruction_loss + beta * KL_divergence
This is the negative ELBO, to be minimized: the reconstruction loss measures how well the model fits the data, while the KL divergence regularizes the weight posterior toward the prior.
Parameters
predictions - Model predictions [batch, output_size]
targets - Ground truth targets [batch, output_size]
kl_divergence - KL cost from kl_cost/3
Options
:beta - KL weight, typically 1 / num_batches (default: 1.0)
:loss_fn - Reconstruction loss: :mse or :cross_entropy (default: :mse)
Returns
Scalar ELBO loss (to be minimized)
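A hedged call example, assuming kl is the scalar returned by kl_cost/3 and num_batches is the number of mini-batches per epoch:
# Cross-entropy reconstruction with the KL term scaled per batch.
loss = Bayesian.elbo_loss(predictions, targets, kl, beta: 1 / num_batches, loss_fn: :cross_entropy)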
@spec kl_cost(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Compute KL divergence between weight posterior q(W) and prior p(W).
For Gaussian posterior N(mu, sigma^2) and Gaussian prior N(0, prior_sigma^2):
KL = sum(log(prior_sigma/sigma) + (sigma^2 + mu^2)/(2*prior_sigma^2) - 0.5)
Parameters
mu - Weight means [...]
rho - Pre-softplus standard deviation parameters, with sigma = softplus(rho) [...]
Options
:prior_sigma- Standard deviation of the prior (default: 1.0)
Returns
Scalar KL divergence cost
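As a sanity-check sketch (not the library's implementation), the closed-form KL above can be reproduced with plain Nx, recovering sigma from rho via softplus:
defmodule KLSketch do
  # Hypothetical re-implementation of the closed-form Gaussian KL shown above.
  def kl_cost(mu, rho, prior_sigma \\ 1.0) do
    sigma = Nx.log(Nx.add(1.0, Nx.exp(rho)))   # softplus(rho)

    ratio = Nx.log(Nx.divide(prior_sigma, sigma))
    moment = Nx.divide(Nx.add(Nx.pow(sigma, 2), Nx.pow(mu, 2)), 2 * prior_sigma * prior_sigma)

    Nx.sum(Nx.subtract(Nx.add(ratio, moment), 0.5))
  end
end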