Edifice.Energy.EBM (Edifice v0.2.0)


Energy-Based Model (EBM).

Implements an energy function network that assigns scalar energy values to inputs. The model learns an energy landscape where low-energy regions correspond to high-probability data configurations. Training uses contrastive divergence to push down energy on real data and push up energy on negative samples generated via Langevin dynamics MCMC.

Architecture

Input [batch, input_size]
      |
      v
+-----------------------------+
| Energy Network (MLP):       |
|   Dense -> Act -> Dense ... |
+-----------------------------+
      |
      v
+-----------------------------+
| Scalar Output:              |
|   Dense -> 1                |
+-----------------------------+
      |
      v
Energy [batch, 1]
(lower = more likely)

Training Loop

1. Compute E(x_real) on real data
2. Sample x_neg via Langevin dynamics from current model
3. Compute E(x_neg) on negative samples
4. Loss = E(x_real) - E(x_neg) (+ regularization)
5. Update parameters to minimize loss
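The loop above can be sketched end-to-end on a toy scalar problem in plain Elixir (no Nx). `ToyEBM`, the quadratic energy `E(x) = (x - theta)^2`, and all constants are illustrative only, not part of the Edifice API; the Langevin noise term is omitted to keep the sketch deterministic.

```elixir
# Toy EBM with a single parameter: E(x) = (x - theta)^2, so the lowest-energy
# point is x = theta. Training should pull theta toward the "real" data point.
defmodule ToyEBM do
  # Energy of a sample x under parameter theta (step 1 / step 3)
  def energy(theta, x), do: (x - theta) * (x - theta)

  # dE/dx, used by Langevin sampling of negatives (step 2)
  def grad_x(theta, x), do: 2.0 * (x - theta)

  # d/dtheta of the loss E(x_real) - E(x_neg) (step 4)
  def grad_theta(theta, x_real, x_neg) do
    -2.0 * (x_real - theta) + 2.0 * (x_neg - theta)
  end

  # Deterministic Langevin descent (noise omitted to keep the sketch exact)
  def langevin(_theta, x, 0, _step), do: x

  def langevin(theta, x, steps, step) do
    langevin(theta, x - step * grad_x(theta, x), steps - 1, step)
  end

  # Steps 1-5: sample negatives from the current model, then update theta
  def train(theta, _x_real, 0, _lr), do: theta

  def train(theta, x_real, epochs, lr) do
    x_neg = langevin(theta, 0.0, 50, 0.1)
    train(theta - lr * grad_theta(theta, x_real, x_neg), x_real, epochs - 1, lr)
  end
end

# Real data lives at x = 3.0; theta recovers it
theta = ToyEBM.train(0.0, 3.0, 50, 0.1)
```

Because the negatives converge onto the current energy minimum, the loss gradient vanishes exactly when `theta` sits on the real data, which is the fixed point the real training loop also seeks.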

Langevin Dynamics Sampling

Starting from noise, samples are iteratively refined by following the negative gradient of the energy function, with added noise for exploration:

x_{t+1} = x_t - (step_size / 2) * grad_E(x_t) + sqrt(step_size) * noise

Usage

# Build energy function
model = EBM.build(input_size: 784, hidden_sizes: [256, 128])

# Training
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({32, 784}, :f32), Axon.ModelState.empty())

# Compute energies
real_energy = predict_fn.(params, %{"input" => real_data})
neg_samples = EBM.langevin_sample(predict_fn, params, 784, steps: 60)
neg_energy = predict_fn.(params, %{"input" => neg_samples})

# Contrastive divergence loss
loss = EBM.contrastive_divergence_loss(real_energy, neg_energy)

References

  • "A Tutorial on Energy-Based Learning" (LeCun et al., 2006)
  • "Implicit Generation and Modeling with Energy-Based Models" (Du & Mordatch, 2019)

Summary

Types

Options for build/1.

Functions

Build an Energy-Based Model.

Build the energy function network from an existing Axon input.

Contrastive divergence loss for training energy-based models.

Compute the energy of a batch of inputs using a compiled model.

Generate samples via Langevin dynamics MCMC.

Types

build_opt()

@type build_opt() ::
  {:input_size, pos_integer()}
  | {:hidden_sizes, [pos_integer()]}
  | {:activation, atom()}
  | {:dropout, float()}
  | {:spectral_norm, boolean()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build an Energy-Based Model.

Constructs an MLP that maps inputs to scalar energy values. The network uses the specified hidden layers with activations, culminating in a single linear output neuron (the energy).

Options

  • :input_size - Input feature dimension (required)
  • :hidden_sizes - List of hidden layer sizes (default: [256, 128])
  • :activation - Activation function (default: :silu)
  • :dropout - Dropout rate (default: 0.0)
  • :spectral_norm - Apply spectral normalization for Lipschitz constraint (default: false)

Returns

An Axon model. Input: {batch, input_size}, Output: {batch, 1} (energy).

build_energy_fn(input, opts \\ [])

@spec build_energy_fn(
  Axon.t(),
  keyword()
) :: Axon.t()

Build the energy function network from an existing Axon input.

This is the core builder that can be composed into larger architectures.

Parameters

  • input - Axon input node
  • opts - Options (same as build/1 minus :input_size)

Returns

An Axon node outputting scalar energy {batch, 1}.

contrastive_divergence_loss(real_energy, neg_energy, opts \\ [])

@spec contrastive_divergence_loss(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) ::
  Nx.Tensor.t()

Contrastive divergence loss for training energy-based models.

The loss pushes down energy on real data and pushes up energy on negative (generated) samples:

L = E(x_real) - E(x_neg) + alpha * (E(x_real)^2 + E(x_neg)^2)

The regularization term prevents energies from diverging to extreme values.
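As a plain-Elixir sketch of this formula on lists of per-sample energies (the mean reduction over the batch and the `cd_loss` name are assumptions here; the library operates on `{batch, 1}` tensors):

```elixir
# Contrastive divergence loss on plain lists of scalar energies:
# L = mean(E_real) - mean(E_neg) + alpha * (mean(E_real^2) + mean(E_neg^2))
mean = fn xs -> Enum.sum(xs) / length(xs) end

cd_loss = fn real, neg, alpha ->
  reg = mean.(Enum.map(real, &(&1 * &1))) + mean.(Enum.map(neg, &(&1 * &1)))
  mean.(real) - mean.(neg) + alpha * reg
end

# Real data already at lower energy than negatives -> negative loss
cd_loss.([-1.0, -2.0], [1.0, 2.0], 0.01)
```

Note how the `alpha` term penalizes large energies of either sign, keeping the landscape from drifting without changing which configurations are relatively preferred.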

Parameters

  • real_energy - Energy of real data samples {batch, 1}
  • neg_energy - Energy of negative (generated) samples {batch, 1}
  • opts - Options

Options

  • :reg_alpha - Regularization strength on energy magnitudes (default: 0.01)

Returns

Scalar loss tensor.

energy(predict_fn, params, input)

@spec energy(function(), map(), Nx.Tensor.t()) :: Nx.Tensor.t()

Compute the energy of a batch of inputs using a compiled model.

Convenience wrapper that handles the input map construction.

Parameters

  • predict_fn - Compiled energy function
  • params - Model parameters
  • input - Input tensor {batch, input_size}

Returns

Energy tensor {batch, 1}.

langevin_sample(predict_fn, params, input_size, opts \\ [])

@spec langevin_sample(function(), map(), pos_integer(), keyword()) :: Nx.Tensor.t()

Generate samples via Langevin dynamics MCMC.

Starting from random noise, the sampler iteratively refines candidates by following the negative gradient of the energy function with injected Gaussian noise:

x_{t+1} = x_t - (step_size / 2) * grad_E(x_t) + sqrt(step_size) * N(0, noise_scale)

The noise term ensures exploration and, in the limit, samples from the Boltzmann distribution p(x) proportional to exp(-E(x)).
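The update above can be sketched on a scalar with a hypothetical quadratic energy `E(x) = x^2 / 2` (so `grad_E(x) = x`). A small `step_size` of `0.01` is used here so the toy dynamics converge; the library's default of `10.0` is tuned for high-dimensional inputs, typically together with `:clamp`.

```elixir
# One Langevin step: x - (step_size / 2) * grad_E(x) + sqrt(step_size) * N(0, noise_scale)
step_size = 0.01
noise_scale = 0.005
grad_e = fn x -> x end

# noise ~ N(0, 1) is passed in explicitly so the step itself is deterministic
langevin_step = fn x, noise ->
  x - step_size / 2 * grad_e.(x) + :math.sqrt(step_size) * noise_scale * noise
end

# Iterating from x = 5.0 drifts toward the energy minimum at 0, where only
# the stationary noise floor remains
x_final =
  Enum.reduce(1..2000, 5.0, fn _, x ->
    langevin_step.(x, :rand.normal())
  end)
```

With the noise switched off this is plain gradient descent on the energy; the injected noise is what lets the chain escape local minima and, in the limit, sample the Boltzmann distribution described above.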

Parameters

  • predict_fn - Compiled energy function (params, input) -> energy
  • params - Model parameters
  • input_size - Dimension of samples to generate
  • opts - Options

Options

  • :batch_size - Number of samples to generate (default: 32)
  • :steps - Number of Langevin steps (default: 60)
  • :step_size - Langevin step size (default: 10.0)
  • :noise_scale - Scale of injected noise (default: 0.005)
  • :init - Initial samples, or nil for uniform noise (default: nil)
  • :clamp - Clamp range {min, max} or nil (default: nil)

Returns

Generated samples tensor {batch_size, input_size}.