Training Models with ExBurn

Overview

ExBurn provides a complete training pipeline: define a model with Axon, compile it with ExBurn.Model.compile/2, and train it with ExBurn.Training.fit/3. The training loop supports multiple optimizers, learning rate schedules, gradient clipping, weight decay, and callbacks.

Defining a Model with Axon

model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(256, activation: :relu, name: "hidden1")
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(128, activation: :relu, name: "hidden2")
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(10, name: "output")

The nil in the shape leaves the batch dimension unspecified, so the model accepts batches of any size.

Compiling a Model

compiled = ExBurn.Model.compile(model,
  loss: :cross_entropy,       # :cross_entropy | :mse | :binary_cross_entropy
  optimizer: :adam,           # :adam | :sgd | :rmsprop
  learning_rate: 0.001,
  device: :gpu,               # :gpu | :cpu
  weight_decay: 1.0e-4        # L2 regularization (default: 0.0)
)

What compile/2 Does

  1. Builds the Axon expression graph via Axon.build/2
  2. Initializes parameters with Glorot/Xavier uniform initialization for weights, zeros for biases
  3. Optionally moves parameters to GPU via BurnBridge.to_gpu/1
  4. Initializes optimizer state (first- and second-moment estimates for Adam, a velocity buffer for SGD, etc.)
  5. Returns an ExBurn.Model struct ready for training
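
Step 2 above can be illustrated with a minimal sketch of Glorot/Xavier uniform initialization in plain Elixir. The module name GlorotSketch and the list-of-lists weight layout are assumptions made for illustration; ExBurn's own initializer works on tensors and may differ in detail.

```elixir
defmodule GlorotSketch do
  # Glorot/Xavier uniform bound: sqrt(6 / (fan_in + fan_out)).
  def limit(fan_in, fan_out), do: :math.sqrt(6 / (fan_in + fan_out))

  # Draw a fan_in x fan_out weight matrix (a list of rows) from U(-limit, limit).
  def weights(fan_in, fan_out) do
    l = limit(fan_in, fan_out)

    for _ <- 1..fan_in do
      for _ <- 1..fan_out, do: (:rand.uniform() * 2 - 1) * l
    end
  end

  # Biases start at zero.
  def biases(fan_out), do: List.duplicate(0.0, fan_out)
end

# Bound for the 784 -> 256 "hidden1" layer from the example model
IO.inspect(GlorotSketch.limit(784, 256))
```

The bound shrinks as layers get wider, which keeps the variance of activations roughly constant from layer to layer.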

Inspecting a Model

# Keras/PyTorch-style summary
IO.puts(ExBurn.Model.summary(compiled))
# ╔══════════════════════════════════════════════════════════╗
# ║                   ExBurn Model Summary                  ║
# ╠══════════════════════════════════════════════════════════╣
# ║  Layer                 Type         Output Shape        ║
# ║  hidden1               Dense        [nil, 256]          ║
# ║  dropout_1             Dropout      [nil, 256]          ║
# ║  hidden2               Dense        [nil, 128]          ║
# ║  output                Dense        [nil, 10]           ║
# ╠══════════════════════════════════════════════════════════╣
# ║  Total params:     235,146                               ║
# ║  Trainable params: 235,146                               ║
# ╚══════════════════════════════════════════════════════════╝

# Access components
ExBurn.Model.parameters(compiled)     # parameter map
ExBurn.Model.loss_function(compiled)  # :cross_entropy
ExBurn.Model.optimizer(compiled)      # :adam
ExBurn.Model.weight_decay(compiled)   # 0.0001
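
The parameter total in the summary can be checked by hand: a dense layer with n_in inputs and n_out units holds n_in * n_out weights plus n_out biases, and dropout layers hold none. A small sketch of that arithmetic follows (ParamCountSketch is a hypothetical helper, not part of ExBurn):

```elixir
defmodule ParamCountSketch do
  # A dense layer has n_in * n_out weights plus n_out biases.
  def dense_params(n_in, n_out), do: n_in * n_out + n_out

  # Walk consecutive layer widths and sum the per-layer counts.
  def total(widths) do
    widths
    |> Enum.chunk_every(2, 1, :discard)
    |> Enum.map(fn [n_in, n_out] -> dense_params(n_in, n_out) end)
    |> Enum.sum()
  end
end

# Input 784 -> hidden1 256 -> hidden2 128 -> output 10
IO.inspect(ParamCountSketch.total([784, 256, 128, 10]))
# prints 235146
```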

Training

trained = ExBurn.Training.fit(compiled, {train_x, train_y},
  epochs: 10,
  batch_size: 32,
  shuffle: true,
  validation_data: {val_x, val_y},
  verbose: true
)

Training Options

| Option | Type | Default | Description |
|---|---|---|---|
| :epochs | pos_integer() | 10 | Number of training epochs |
| :batch_size | pos_integer() | 32 | Mini-batch size |
| :shuffle | boolean() | true | Shuffle training data each epoch |
| :validation_data | {tensor, tensor} | nil | Validation dataset |
| :callbacks | [function()] | [] | Callback functions called after each epoch |
| :verbose | boolean() | true | Print training progress |
| :lr_schedule | see below | nil | Learning rate schedule |
| :clip_norm | float() | nil | Max gradient norm for clipping |
| :clip_value | float() | nil | Max absolute gradient value |
| :weight_decay | float() | nil | L2 regularization coefficient |
| :accumulate_gradients | pos_integer() | 1 | Accumulate N batches before optimizer step |
| :accuracy | boolean() | false | Compute classification accuracy |
| :nesterov | boolean() | false | Nesterov momentum (SGD only) |

Learning Rate Schedules

# Step decay, {:step, base_lr, step_size, gamma}: multiply LR by gamma every step_size epochs
lr_schedule: {:step, 0.001, 10, 0.5}

# Exponential decay, {:exponential, base_lr, gamma}: LR = base_lr * gamma^epoch
lr_schedule: {:exponential, 0.001, 0.95}

# Cosine annealing, {:cosine, base_lr, min_lr}: smoothly decay from base_lr to min_lr
lr_schedule: {:cosine, 0.001, 1.0e-5}
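
To make the three schedules concrete, here is a sketch of how each tuple could map an epoch number to a learning rate. It assumes zero-based epochs and that the cosine schedule anneals over the total number of training epochs. Both are assumptions, since the text above does not say how ExBurn counts epochs or picks the annealing period. LrScheduleSketch is a hypothetical module.

```elixir
defmodule LrScheduleSketch do
  # Step decay: multiply by gamma once every step_size epochs.
  def lr({:step, base_lr, step_size, gamma}, epoch, _total_epochs),
    do: base_lr * :math.pow(gamma, div(epoch, step_size))

  # Exponential decay: base_lr * gamma^epoch.
  def lr({:exponential, base_lr, gamma}, epoch, _total_epochs),
    do: base_lr * :math.pow(gamma, epoch)

  # Cosine annealing: base_lr at epoch 0, min_lr at epoch == total_epochs.
  def lr({:cosine, base_lr, min_lr}, epoch, total_epochs) do
    cosine = :math.cos(:math.pi() * epoch / total_epochs)
    min_lr + 0.5 * (base_lr - min_lr) * (1 + cosine)
  end
end

# Learning rate of the step schedule at epochs 0, 10 and 20
for epoch <- [0, 10, 20] do
  IO.inspect(LrScheduleSketch.lr({:step, 0.001, 10, 0.5}, epoch, 30))
end
```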

Gradient Clipping

# Clip by global norm (prevents exploding gradients)
clip_norm: 1.0

# Clip by absolute value
clip_value: 5.0

# Both can be used together
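
The difference between the two clipping modes is easiest to see on plain numbers. The sketch below assumes gradients are stored as a map from parameter name to a flat list of floats, which is a simplification chosen for illustration; ClipSketch is a hypothetical module and does not reflect ExBurn internals.

```elixir
defmodule ClipSketch do
  # Global L2 norm over every gradient value of every parameter.
  def global_norm(grads) do
    grads
    |> Map.values()
    |> List.flatten()
    |> Enum.reduce(0.0, fn g, acc -> acc + g * g end)
    |> :math.sqrt()
  end

  # Rescale all gradients by max_norm / norm when the norm exceeds max_norm.
  # The direction of the overall gradient is preserved.
  def clip_by_norm(grads, max_norm) do
    norm = global_norm(grads)
    scale = if norm > max_norm, do: max_norm / norm, else: 1.0
    Map.new(grads, fn {name, gs} -> {name, Enum.map(gs, &(&1 * scale))} end)
  end

  # Clamp each value independently to [-limit, limit].
  # This can change the direction of the overall gradient.
  def clip_by_value(grads, limit) do
    Map.new(grads, fn {name, gs} ->
      {name, Enum.map(gs, &(&1 |> max(-limit) |> min(limit)))}
    end)
  end
end

grads = %{"w" => [3.0, 4.0]}
IO.inspect(ClipSketch.global_norm(grads))
# prints 5.0
IO.inspect(ClipSketch.clip_by_value(%{"w" => [-7.0, 2.0]}, 5.0))
# prints %{"w" => [-5.0, 2.0]}
```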

Gradient Accumulation

Useful when GPU memory limits the batch size. Gradients are accumulated across N mini-batches before a single optimizer step is performed:

# Effective batch_size = 32 * 4 = 128
ExBurn.Training.fit(model, data,
  batch_size: 32,
  accumulate_gradients: 4
)
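
As a rough sketch of the idea, the gradients of N equally sized mini-batches can be averaged, which matches the gradient of a mean loss over the combined batch. Whether ExBurn averages or sums the accumulated gradients is not stated here, so the averaging below is an assumption; AccumulateSketch is a hypothetical module operating on flat lists.

```elixir
defmodule AccumulateSketch do
  # Element-wise sum of two gradient lists.
  defp add(a, b), do: Enum.zip_with(a, b, &(&1 + &2))

  # Average the per-mini-batch gradients into one gradient for the optimizer step.
  def accumulate(grad_list) do
    n = length(grad_list)

    grad_list
    |> Enum.reduce(&add/2)
    |> Enum.map(&(&1 / n))
  end
end

# Two mini-batches, two parameters each
IO.inspect(AccumulateSketch.accumulate([[1.0, 2.0], [3.0, 4.0]]))
# prints [2.0, 3.0]
```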

Optimizers

Adam (default)

Adaptive learning rate with momentum. Good default for most tasks.

ExBurn.Model.compile(model, optimizer: :adam, learning_rate: 0.001)
# Internal state: m (1st moment), v (2nd moment), t (timestep)
# beta1=0.9, beta2=0.999, epsilon=1e-8
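
For reference, the standard Adam update with the state and hyperparameters listed above looks roughly like this for a single scalar parameter. This is a textbook sketch with bias correction, not ExBurn's tensor implementation; AdamSketch is a hypothetical module.

```elixir
defmodule AdamSketch do
  @beta1 0.9
  @beta2 0.999
  @eps 1.0e-8

  # One Adam update for a single scalar parameter.
  # state is {m, v, t}; returns {new_param, new_state}.
  def step(param, grad, {m, v, t}, lr) do
    t = t + 1
    m = @beta1 * m + (1 - @beta1) * grad
    v = @beta2 * v + (1 - @beta2) * grad * grad

    # Bias correction compensates for m and v starting at zero.
    m_hat = m / (1 - :math.pow(@beta1, t))
    v_hat = v / (1 - :math.pow(@beta2, t))

    {param - lr * m_hat / (:math.sqrt(v_hat) + @eps), {m, v, t}}
  end
end

# On the first step the update size is close to the learning rate,
# whatever the gradient magnitude.
{param, _state} = AdamSketch.step(1.0, 2.0, {0.0, 0.0, 0}, 0.001)
IO.inspect(param)
```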

SGD with Momentum

ExBurn.Model.compile(model, optimizer: :sgd, learning_rate: 0.01)
# momentum=0.9

With Nesterov momentum (often converges faster):

ExBurn.Training.fit(model, data, nesterov: true)
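
Several formulations of momentum exist. The sketch below uses the common one where the velocity accumulates raw gradients and Nesterov applies a look-ahead correction to the update. It is an assumption that ExBurn follows this exact variant; SgdSketch is a hypothetical module working on a single scalar parameter.

```elixir
defmodule SgdSketch do
  @momentum 0.9

  # One SGD-with-momentum update; returns {new_param, new_velocity}.
  def step(param, grad, velocity, lr, opts \\ []) do
    velocity = @momentum * velocity + grad

    update =
      if Keyword.get(opts, :nesterov, false) do
        # Nesterov: look ahead along the freshly updated velocity.
        grad + @momentum * velocity
      else
        velocity
      end

    {param - lr * update, velocity}
  end
end

IO.inspect(SgdSketch.step(1.0, 1.0, 0.0, 0.1))
IO.inspect(SgdSketch.step(1.0, 1.0, 0.0, 0.1, nesterov: true))
```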

RMSprop

Good for recurrent networks and non-stationary objectives:

ExBurn.Model.compile(model, optimizer: :rmsprop, learning_rate: 0.001)
# decay=0.9, epsilon=1e-8
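
A textbook RMSprop step with these hyperparameters divides each gradient by a running root-mean-square of recent gradients. The sketch below handles one scalar parameter and assumes no momentum term; RmspropSketch is a hypothetical module, not ExBurn's implementation.

```elixir
defmodule RmspropSketch do
  @decay 0.9
  @eps 1.0e-8

  # One RMSprop update; avg_sq is the running average of squared gradients.
  # Returns {new_param, new_avg_sq}.
  def step(param, grad, avg_sq, lr) do
    avg_sq = @decay * avg_sq + (1 - @decay) * grad * grad
    {param - lr * grad / (:math.sqrt(avg_sq) + @eps), avg_sq}
  end
end

IO.inspect(RmspropSketch.step(1.0, 2.0, 0.0, 0.01))
```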

Callbacks

Callbacks are functions that receive a metrics map after each epoch and return it (possibly modified).

Built-in Callbacks

# Logging
callbacks: [&ExBurn.Training.LoggingCallback.log/1]

# Early stopping (patience=5 epochs, min_delta=1e-4)
callbacks: [ExBurn.Training.EarlyStoppingCallback.wait(5, 1.0e-4)]

# Checkpoint every 5 epochs
callbacks: [ExBurn.Training.CheckpointCallback.every(5, "/checkpoints")]

Custom Callbacks

The metrics map has this structure:

%{
  epoch: 5,
  loss: 0.0234,
  val_loss: 0.0312,       # if validation_data provided
  accuracy: 0.98,          # if accuracy: true
  val_accuracy: 0.95,      # if validation_data + accuracy
  model: %ExBurn.Model{}   # current model state
}

Return Map.put(metrics, :stop_training, true) to halt training early:

custom_callback = fn
  %{loss: loss} = metrics when loss < 0.01 ->
    IO.puts("Converged!")
    Map.put(metrics, :stop_training, true)

  metrics ->
    metrics
end
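
To show how such callbacks compose, here is a sketch of an epoch loop that threads the metrics map through each callback in order and halts once :stop_training is set. The loop itself is an assumption about how a trainer might honor the flag, not ExBurn's actual implementation; CallbackLoopSketch and epoch_fun are hypothetical names, and the losses are made up.

```elixir
defmodule CallbackLoopSketch do
  # Thread the metrics map through each callback in order.
  def run_callbacks(metrics, callbacks),
    do: Enum.reduce(callbacks, metrics, fn callback, acc -> callback.(acc) end)

  # Run up to `epochs` epochs. `epoch_fun` stands in for one epoch of training
  # and returns that epoch's metrics map. Returns the metrics history.
  def run(epochs, callbacks, epoch_fun) do
    1..epochs
    |> Enum.reduce_while([], fn epoch, history ->
      metrics = run_callbacks(epoch_fun.(epoch), callbacks)
      history = [metrics | history]

      if Map.get(metrics, :stop_training, false) do
        {:halt, history}
      else
        {:cont, history}
      end
    end)
    |> Enum.reverse()
  end
end

stop_when_converged = fn
  %{loss: loss} = metrics when loss < 0.01 -> Map.put(metrics, :stop_training, true)
  metrics -> metrics
end

# Fake losses that halve each epoch: 0.5, 0.25, 0.125, ...
fake_epoch = fn epoch -> %{epoch: epoch, loss: :math.pow(0.5, epoch)} end

history = CallbackLoopSketch.run(100, [stop_when_converged], fake_epoch)
IO.inspect(length(history))
# prints 7
```

Training stops at epoch 7 because 0.5^7 = 0.0078125 is the first fake loss below 0.01.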

Evaluating a Model

# Returns average loss
loss = ExBurn.Training.evaluate(model, {test_x, test_y})

# Returns {loss, accuracy} tuple
{loss, accuracy} = ExBurn.Training.evaluate(model, {test_x, test_y}, true)

Inference

# Using the model's forward pass (GPU via defn compiler)
{:ok, output} = ExBurn.Model.forward(compiled, input_tensor)

# Using Axon predict (CPU via BinaryBackend)
{:ok, output} = ExBurn.Model.predict(compiled, input_tensor)

Saving and Loading

# Save (compressed Erlang term format)
ExBurn.Model.save(trained, "model.bin")

# Load
{:ok, model} = ExBurn.Model.load(trained, "model.bin")

# Serialize to binary (for network transfer)
binary = ExBurn.Model.serialize_params(trained)
{:ok, params} = ExBurn.Model.deserialize_params(binary)
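
The "compressed Erlang term format" mentioned above refers to what :erlang.term_to_binary/2 produces with the :compressed option. A minimal round trip on a plain map is sketched below; the nested-map parameter layout and the file name are assumptions for illustration, and ExBurn's own files may carry additional fields.

```elixir
# Hypothetical parameter map, for illustration only
params = %{"hidden1" => %{"kernel" => [[0.1, -0.2]], "bias" => [0.0]}}

# Encode to the external term format with zlib compression
binary = :erlang.term_to_binary(params, [:compressed])

path = Path.join(System.tmp_dir!(), "params_sketch.bin")
File.write!(path, binary)

# :safe makes decoding fail rather than create new atoms from untrusted input
restored = path |> File.read!() |> :erlang.binary_to_term([:safe])

IO.inspect(restored == params)
# prints true
```

Decoding with :safe is a sensible default whenever the binary arrives over the network.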

Freezing Layers (Fine-tuning)

Freeze layers to prevent them from updating during training:

# Freeze specific layers
model = ExBurn.Model.freeze(compiled, ["hidden1"])

# Check if a layer is frozen
ExBurn.Model.frozen?(model, "hidden1")  # true

# Unfreeze
model = ExBurn.Model.unfreeze(model, ["hidden1"])

# Get all frozen layer names
ExBurn.Model.frozen_layers(model)  # #MapSet<["hidden1"]>
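
Conceptually, freezing means the optimizer skips any layer whose name is in the frozen set. The sketch below shows that with a plain gradient step over maps of flat lists; it is an illustration of the idea under those simplifying assumptions, and FreezeSketch is a hypothetical module.

```elixir
defmodule FreezeSketch do
  # Apply a plain gradient step to every layer except the frozen ones.
  def update(params, grads, frozen, lr) do
    Map.new(params, fn {layer, values} ->
      if MapSet.member?(frozen, layer) do
        {layer, values}
      else
        layer_grads = Map.fetch!(grads, layer)
        {layer, Enum.zip_with(values, layer_grads, fn p, g -> p - lr * g end)}
      end
    end)
  end
end

params = %{"hidden1" => [1.0], "output" => [1.0]}
grads = %{"hidden1" => [1.0], "output" => [1.0]}

IO.inspect(FreezeSketch.update(params, grads, MapSet.new(["hidden1"]), 0.5))
# prints %{"hidden1" => [1.0], "output" => [0.5]}
```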

Device Management

# Move model to GPU
gpu_model = ExBurn.Model.to_device(compiled, :gpu)

# Move model to CPU
cpu_model = ExBurn.Model.to_device(compiled, :cpu)

# No-op if already on target device
same_model = ExBurn.Model.to_device(cpu_model, :cpu)

Custom Training Loops

For full control, use train_step/3 directly:

{loss, updated_model} = ExBurn.Training.train_step(model, {batch_x, batch_y},
  clip_norm: 1.0,
  grad_method: :numerical_batch
)

Compute gradients separately:

grads = ExBurn.Training.compute_gradients(model, {batch_x, batch_y},
  grad_method: :numerical  # or :numerical_batch
)
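
The option names suggest finite-difference gradients. How ExBurn implements :numerical and :numerical_batch is not described here, so the following is only a generic central-difference sketch over a flat list of parameters, with NumericalGradSketch as a hypothetical module.

```elixir
defmodule NumericalGradSketch do
  # Central difference for each parameter i:
  # (f(p + h * e_i) - f(p - h * e_i)) / (2 * h)
  def gradient(loss_fun, params, h \\ 1.0e-5) do
    params
    |> Enum.with_index()
    |> Enum.map(fn {p, i} ->
      plus = List.replace_at(params, i, p + h)
      minus = List.replace_at(params, i, p - h)
      (loss_fun.(plus) - loss_fun.(minus)) / (2 * h)
    end)
  end
end

# loss = x^2 + 3y has the analytic gradient [2x, 3]
loss = fn [x, y] -> x * x + 3 * y end
IO.inspect(NumericalGradSketch.gradient(loss, [2.0, 1.0]))
# approximately [4.0, 3.0]
```

This approach needs two loss evaluations per parameter, so it scales poorly to large models and is mostly useful for checking other gradient implementations.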

Loss Functions

| Loss | Use Case | Target Format |
|---|---|---|
| :cross_entropy | Multi-class classification | One-hot or integer class indices |
| :mse | Regression | Continuous values |
| :binary_cross_entropy | Binary classification | 0.0 or 1.0 |
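
The three losses can be written out on plain lists as follows. The sketch assumes :cross_entropy is applied to raw logits with a one-hot target, and that a small epsilon guards the logarithm. Both are assumptions, since the table does not say whether ExBurn expects logits or probabilities. LossSketch is a hypothetical module.

```elixir
defmodule LossSketch do
  @eps 1.0e-7

  # Mean squared error over paired predictions and targets.
  def mse(pred, target) do
    squared = Enum.zip_with(pred, target, fn p, t -> (p - t) * (p - t) end)
    Enum.sum(squared) / length(squared)
  end

  # Numerically stable softmax: subtract the max logit before exponentiating.
  def softmax(logits) do
    max_logit = Enum.max(logits)
    exps = Enum.map(logits, &:math.exp(&1 - max_logit))
    sum = Enum.sum(exps)
    Enum.map(exps, &(&1 / sum))
  end

  # Cross-entropy for one example: logits against a one-hot target.
  def cross_entropy(logits, one_hot) do
    logits
    |> softmax()
    |> Enum.zip_with(one_hot, fn p, t -> -t * :math.log(p + @eps) end)
    |> Enum.sum()
  end

  # Binary cross-entropy for one predicted probability p and target t (0.0 or 1.0).
  def binary_cross_entropy(p, t) do
    -(t * :math.log(p + @eps) + (1 - t) * :math.log(1 - p + @eps))
  end
end

IO.inspect(LossSketch.mse([1.0, 2.0], [1.0, 4.0]))
# prints 2.0
IO.inspect(LossSketch.cross_entropy([0.0, 0.0], [1.0, 0.0]))
# close to ln(2), about 0.693
```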