Training loop implementation for ExBurn models.
Provides a flexible training loop with support for:
- Mini-batch training with gradient computation
- Multiple optimizers (Adam, SGD with momentum / Nesterov, RMSprop)
- Learning rate scheduling (step, exponential, cosine)
- Gradient clipping (by norm and by value)
- Weight decay (L2 regularization)
- Gradient accumulation for effective larger batch sizes
- Batch shuffling each epoch
- Validation with partial batch handling
- Training metrics tracking (loss, accuracy)
- Progress reporting with ETA and throughput
- Callbacks (logging, early stopping, checkpointing)
- Public train_step/3 and compute_gradients/3 functions for building custom training loops
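The optimizer list mentions SGD with classical or Nesterov momentum. The sketch below shows one momentum step on plain lists of floats. It assumes the PyTorch-style formulation, and ExBurn's exact update rule is not documented here. The module name MomentumSketch and its step/6 function are hypothetical.

```elixir
defmodule MomentumSketch do
  # Hypothetical sketch: SGD with classical or Nesterov momentum on plain
  # lists of floats, using the PyTorch-style update. ExBurn's rule may differ.

  # One optimizer step; returns {new_params, new_velocity}.
  def step(params, grads, velocity, lr, momentum, nesterov?) do
    # v_t = momentum * v_(t-1) + g_t
    new_velocity = Enum.zip_with(velocity, grads, fn v, g -> momentum * v + g end)

    # Nesterov "looks ahead" by applying g_t + momentum * v_t instead of v_t.
    update =
      if nesterov? do
        Enum.zip_with(grads, new_velocity, fn g, v -> g + momentum * v end)
      else
        new_velocity
      end

    # Gradient-descent update: p_t = p_(t-1) - lr * update
    new_params = Enum.zip_with(params, update, fn p, u -> p - lr * u end)
    {new_params, new_velocity}
  end
end
```

Starting from zero velocity, the first classical step equals plain SGD. The Nesterov step is larger because it adds the momentum term a second time.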
Usage
```elixir
# train_data and val_data are {inputs, targets} tuples of Nx tensors
model = ExBurn.Model.compile(axon_model, loss: :cross_entropy, optimizer: :adam)

opts = [
  epochs: 10,
  batch_size: 32,
  shuffle: true,
  validation_data: val_data,
  lr_schedule: {:cosine, 0.001, 1.0e-5},
  clip_norm: 1.0,
  weight_decay: 1.0e-4,
  # 4 accumulated mini-batches of 32 act like one batch of 128
  accumulate_gradients: 4,
  accuracy: true,
  # Nesterov momentum applies to the SGD optimizer only
  nesterov: true,
  callbacks: [&ExBurn.Training.LoggingCallback.log/1]
]

trained_model = ExBurn.Training.fit(model, train_data, opts)
```
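The usage example passes lr_schedule: {:cosine, 0.001, 1.0e-5}. The sketch below shows standard cosine annealing. It assumes that the tuple means {:cosine, initial_lr, min_lr} and that the schedule decays from initial_lr at step 0 to min_lr at a given total step count. Neither assumption is confirmed by this page. CosineScheduleSketch and lr/3 are hypothetical names.

```elixir
defmodule CosineScheduleSketch do
  # Assumption: {:cosine, initial_lr, min_lr} anneals from initial_lr at step 0
  # to min_lr at total_steps. Standard cosine annealing:
  #   lr(t) = min_lr + 0.5 * (initial_lr - min_lr) * (1 + cos(pi * t / T))
  def lr({:cosine, initial_lr, min_lr}, step, total_steps)
      when is_integer(step) and step >= 0 and is_integer(total_steps) and total_steps > 0 do
    # Past the end of the schedule, hold at min_lr.
    t = min(step, total_steps)
    progress = t / total_steps
    min_lr + 0.5 * (initial_lr - min_lr) * (1 + :math.cos(:math.pi() * progress))
  end
end
```

Halfway through (t = T/2), the cosine term is zero. The rate is therefore the midpoint of the two bounds, 5.05e-4 for the values used above.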
Summary
Functions
Computes gradients for a given mini-batch.
Creates a data loader that yields mini-batches from a dataset.
Evaluates a model on a dataset.
Trains a model on the given dataset.
Profiles a single training step, returning detailed timing for each phase.
Performs a single training step: forward + backward + optimizer update.
Types
@type dataset() :: {Nx.Tensor.t(), Nx.Tensor.t()}
@type model() :: ExBurn.Model.t()
@type training_opts() :: [
  epochs: pos_integer(),
  batch_size: pos_integer(),
  shuffle: boolean(),
  validation_data: dataset() | nil,
  callbacks: [callback()],
  verbose: boolean(),
  lr_schedule: lr_schedule(),
  clip_norm: float() | nil,
  clip_value: float() | nil,
  weight_decay: float() | nil,
  accumulate_gradients: pos_integer(),
  accuracy: boolean(),
  nesterov: boolean(),
  warmup: pos_integer()
]
Functions
Computes gradients for a given mini-batch.
Supports multiple gradient computation methods via the :grad_method option:
- :numerical (default): central finite differences; slow but general
- :numerical_batch: numerical gradients computed on the full batch at once (more efficient, fewer forward passes)
Parameters
- model: the current model state
- batch: an {inputs, targets} tuple
- opts: options list
Options
- :grad_method: gradient computation method (default: :numerical)
- :epsilon: finite-difference step size (default: 1.0e-5)
Returns
A map from parameter key to gradient tensor: %{param_key => gradient_tensor}.
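The default :numerical method uses central finite differences. Each parameter is nudged by +epsilon and -epsilon, and the resulting change in loss is divided by 2 * epsilon. This costs two loss evaluations per scalar parameter, which is why the method is slow. The sketch below applies the technique to a plain function of a list of floats rather than a model and batch. FiniteDiffSketch is a hypothetical name, and the default epsilon mirrors the :epsilon option above.

```elixir
defmodule FiniteDiffSketch do
  # Central finite differences:
  #   df/dx_i ~ (f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)
  # Two evaluations of f per parameter.
  def gradient(f, params, epsilon \\ 1.0e-5) when is_function(f, 1) and is_list(params) do
    params
    |> Enum.with_index()
    |> Enum.map(fn {_value, i} ->
      plus = List.update_at(params, i, &(&1 + epsilon))
      minus = List.update_at(params, i, &(&1 - epsilon))
      (f.(plus) - f.(minus)) / (2 * epsilon)
    end)
  end
end
```

For f(x, y) = x^2 + 3y at (2.0, 1.0), the result is approximately [4.0, 3.0], matching the analytic gradient (2x, 3).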
@spec data_loader(dataset(), keyword()) :: Enumerable.t()
Creates a data loader that yields mini-batches from a dataset.
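A minimal sketch of a mini-batch loader is shown below. It pairs inputs with targets so that shuffling keeps them aligned, then chunks the pairs lazily into batches, where the last batch may be partial. This page does not document data_loader/2's option names. The :batch_size and :shuffle options and their defaults are borrowed from fit/3 as an assumption, and plain lists stand in for Nx tensors. DataLoaderSketch is hypothetical.

```elixir
defmodule DataLoaderSketch do
  # Hypothetical: dataset is {inputs, targets} as two equal-length lists.
  # Option names and defaults mirror fit/3 (an assumption).
  def batches({inputs, targets}, opts \\ []) do
    batch_size = Keyword.get(opts, :batch_size, 32)
    shuffle? = Keyword.get(opts, :shuffle, true)

    # Zip first so each input stays with its target after shuffling.
    pairs = Enum.zip(inputs, targets)
    pairs = if shuffle?, do: Enum.shuffle(pairs), else: pairs

    pairs
    # Lazily split into chunks; the final chunk may be smaller than batch_size.
    |> Stream.chunk_every(batch_size)
    # Turn each chunk of pairs back into an {inputs, targets} tuple.
    |> Stream.map(&Enum.unzip/1)
  end
end
```

Because the result is a Stream, it satisfies Enumerable.t() as in the spec, and batches are built only when consumed.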
Evaluates a model on a dataset.
Returns the average loss over the entire dataset.
When track_accuracy is true, it returns {loss, accuracy} instead, where accuracy is a float, or nil if the loss function is not :cross_entropy.
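The feature list mentions validation "with partial batch handling". One way to keep a smaller final batch from skewing the dataset-level average is to weight each batch's mean loss by its size. The sketch below does this and also computes argmax classification accuracy. Both functions are assumptions about how such an evaluation could be computed, not ExBurn's implementation. They work on plain lists, and EvaluateSketch is a hypothetical name.

```elixir
defmodule EvaluateSketch do
  # Hypothetical: each batch result is {mean_loss, batch_size}. Weighting by
  # batch size makes the average equal to the per-example mean loss.
  def average_loss(batch_results) do
    {total, count} =
      Enum.reduce(batch_results, {0.0, 0}, fn {mean_loss, n}, {sum, seen} ->
        {sum + mean_loss * n, seen + n}
      end)

    total / count
  end

  # Fraction of rows whose highest-scoring class index equals the integer label.
  def accuracy(score_rows, labels) do
    correct =
      score_rows
      |> Enum.zip(labels)
      |> Enum.count(fn {scores, label} -> argmax(scores) == label end)

    correct / length(labels)
  end

  # Index of the largest score (first one wins on ties).
  defp argmax(scores) do
    scores
    |> Enum.with_index()
    |> Enum.max_by(fn {score, _i} -> score end)
    |> elem(1)
  end
end
```

With a full batch of 4 at mean loss 1.0 and a partial batch of 1 at 4.0, the weighted average is 1.6. A naive mean of the batch means would report 2.5.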
Trains a model on the given dataset.
Options
- :epochs: number of training epochs (default: 10)
- :batch_size: mini-batch size (default: 32)
- :shuffle: shuffle the training data each epoch (default: true)
- :validation_data: validation dataset as an {inputs, targets} tuple
- :callbacks: list of callback functions called after each epoch
- :verbose: print training progress (default: true)
- :lr_schedule: learning rate schedule (default: nil)
- :clip_norm: maximum gradient norm for clipping (default: nil)
- :clip_value: maximum absolute gradient value for clipping (default: nil)
- :weight_decay: L2 regularization coefficient (default: nil)
- :accumulate_gradients: number of mini-batches to accumulate before each optimizer step, effectively multiplying the batch size (default: 1)
- :accuracy: compute and report classification accuracy (default: false)
- :nesterov: use Nesterov momentum with the SGD optimizer (default: false)
Returns
The trained ExBurn.Model struct with updated parameters.
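The sketch below shows, on plain lists of floats, what the :weight_decay, :clip_value, :clip_norm and :accumulate_gradients options conventionally do to gradients. It illustrates the standard techniques only. How ExBurn implements them on Nx tensors, and in what order it applies them, is not documented here. Averaging accumulated gradients is one common choice (summing is another), and GradientOpsSketch is a hypothetical name.

```elixir
defmodule GradientOpsSketch do
  # :weight_decay as L2 regularization adds wd * param to each gradient,
  # the gradient of (wd / 2) * ||params||^2. nil means "disabled".
  def weight_decay(grads, _params, nil), do: grads
  def weight_decay(grads, params, wd), do: Enum.zip_with(grads, params, fn g, p -> g + wd * p end)

  # :clip_value clamps every component to [-limit, limit].
  def clip_value(grads, nil), do: grads
  def clip_value(grads, limit), do: Enum.map(grads, fn g -> g |> min(limit) |> max(-limit) end)

  # :clip_norm rescales the whole vector when its L2 norm exceeds max_norm,
  # preserving the gradient's direction.
  def clip_norm(grads, nil), do: grads

  def clip_norm(grads, max_norm) do
    norm = grads |> Enum.map(&(&1 * &1)) |> Enum.sum() |> :math.sqrt()

    if norm > max_norm do
      scale = max_norm / norm
      Enum.map(grads, &(&1 * scale))
    else
      grads
    end
  end

  # :accumulate_gradients: average K mini-batch gradients before a single
  # optimizer step, approximating one K-times-larger batch.
  def accumulate(grad_list) do
    k = length(grad_list)
    Enum.zip_with(grad_list, fn components -> Enum.sum(components) / k end)
  end
end
```

Clipping by norm keeps the update direction and only shrinks its length. Clipping by value can change the direction, because each component is clamped independently.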
Profiles a single training step, returning detailed timing for each phase.
Useful for identifying bottlenecks in the training pipeline.
Parameters
- model: the current model state
- batch: an {inputs, targets} tuple
- opts: options (same as train_step/3)
Returns
A map with the keys :forward_ms, :backward_ms, :optimizer_ms and :total_ms, each a duration in milliseconds.
Example
```elixir
profile = ExBurn.Training.profile_step(model, batch)
IO.puts("Forward: #{profile.forward_ms}ms, Backward: #{profile.backward_ms}ms")
```
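Per-phase timing of this kind can be collected with Erlang's :timer.tc/1, which runs a zero-arity function and returns {microseconds, result}. The sketch below times three caller-supplied functions and reports the same keys that profile_step documents. The phase functions are hypothetical stand-ins for the real forward pass, backward pass and optimizer update, and ProfileSketch is not part of ExBurn.

```elixir
defmodule ProfileSketch do
  # Hypothetical: times three phase functions, each fed the previous result,
  # and reports milliseconds under profile_step's documented keys.
  def profile(forward, backward, optimizer) do
    {fwd_us, activations} = :timer.tc(forward)
    {bwd_us, grads} = :timer.tc(fn -> backward.(activations) end)
    {opt_us, _params} = :timer.tc(fn -> optimizer.(grads) end)

    # :timer.tc/1 reports microseconds; convert to milliseconds.
    to_ms = fn us -> us / 1000 end

    %{
      forward_ms: to_ms.(fwd_us),
      backward_ms: to_ms.(bwd_us),
      optimizer_ms: to_ms.(opt_us),
      total_ms: to_ms.(fwd_us + bwd_us + opt_us)
    }
  end
end
```

Here total_ms is the sum of the three phases. A real profiler might instead time the whole step, which would also include any overhead between phases.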
Performs a single training step: forward + backward + optimizer update.
This is a public API for custom training loops. It processes a single mini-batch and returns the updated model along with the loss value.
Parameters
- model: the current model state
- batch: an {inputs, targets} tuple for one mini-batch
- opts: options; accepts the same :clip_norm, :clip_value and :weight_decay options as fit/3
Returns
{loss, updated_model} where loss is a float.
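The sketch below shows the shape of a custom loop built around a train_step-style function. Running ExBurn itself is out of scope, so it uses a hypothetical stand-in: one plain SGD step on a 1-D linear model y = w * x + b with mean squared error and hand-derived gradients. The stand-in returns {loss, updated_model}, as train_step/3 is documented to do. With ExBurn, the call to CustomLoopSketch.train_step would be replaced by ExBurn.Training.train_step(model, batch, opts). The model map, the :lr option and the module name are all hypothetical.

```elixir
defmodule CustomLoopSketch do
  # Hypothetical stand-in for train_step/3 on a model %{w: w, b: b}.
  def train_step(%{w: w, b: b} = model, {xs, ys}, opts) do
    lr = Keyword.get(opts, :lr, 0.1)
    n = length(xs)

    # Forward pass: residuals r_i = (w * x_i + b) - y_i, loss = mean(r_i^2).
    residuals = Enum.zip_with(xs, ys, fn x, y -> w * x + b - y end)
    loss = Enum.sum(Enum.map(residuals, &(&1 * &1))) / n

    # Backward pass: dL/dw = (2/n) * sum(r_i * x_i), dL/db = (2/n) * sum(r_i).
    grad_w = 2 * Enum.sum(Enum.zip_with(residuals, xs, &(&1 * &2))) / n
    grad_b = 2 * Enum.sum(residuals) / n

    # Optimizer update (vanilla SGD).
    {loss, %{model | w: w - lr * grad_w, b: b - lr * grad_b}}
  end

  # Custom loop: fold train_step over the mini-batches for each epoch,
  # returning {last_loss, final_model}.
  def run(model, batches, epochs, opts) do
    Enum.reduce(1..epochs, {nil, model}, fn _epoch, {_loss, m} ->
      Enum.reduce(batches, {nil, m}, fn batch, {_, acc} -> train_step(acc, batch, opts) end)
    end)
  end
end
```

On data sampled from y = 2x + 1, repeated steps drive w toward 2.0 and b toward 1.0. The reported loss is computed before each update, so it lags the returned model by one step.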