ExBurn.Training (ex_burn v0.3.0)

Training loop implementation for ExBurn models.

Provides a flexible training loop with support for:

  • Mini-batch training with gradient computation
  • Multiple optimizers (Adam, SGD with momentum / Nesterov, RMSprop)
  • Learning rate scheduling (step, exponential, cosine)
  • Gradient clipping (by norm and by value)
  • Weight decay (L2 regularization)
  • Gradient accumulation for effective larger batch sizes
  • Batch shuffling each epoch
  • Validation with partial batch handling
  • Training metrics tracking (loss, accuracy)
  • Progress reporting with ETA and throughput
  • Callbacks (logging, early stopping, checkpointing)
  • Public train_step/3 and compute_gradients/3 for custom loops

Usage

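# Compile an Axon model into an ExBurn model with a loss and an optimizer.
model = ExBurn.Model.compile(axon_model, loss: :cross_entropy, optimizer: :adam)

# train_data and val_data are {inputs, targets} tuples of Nx tensors (see dataset()).
opts = [
  epochs: 10,
  batch_size: 32,
  shuffle: true,
  validation_data: val_data,
  lr_schedule: {:cosine, 0.001, 1.0e-5},
  clip_norm: 1.0,
  weight_decay: 1.0e-4,
  # One optimizer step per 4 mini-batches.
  accumulate_gradients: 4,
  accuracy: true,
  # Documented as an option for the SGD optimizer (see fit/3).
  nesterov: true,
  callbacks: [&ExBurn.Training.LoggingCallback.log/1]
]

trained_model = ExBurn.Training.fit(model, train_data, opts)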

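Summary

Functions

compute_gradients(model, arg, opts \\ [])

Computes gradients for a given mini-batch.

data_loader(arg, opts \\ [])

Creates a data loader that yields mini-batches from a dataset.

evaluate(model, arg, track_accuracy \\ false)

Evaluates a model on a dataset.

fit(model, arg, opts \\ [])

Trains a model on the given dataset.

profile_step(model, arg, opts \\ [])

Profiles a single training step, returning detailed timing for each phase.

train_step(model, arg, opts \\ [])

Performs a single training step: forward + backward + optimizer update.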
Types

callback()

@type callback() :: (map() -> map())
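
The type only states that a callback receives a map and returns a map. As a minimal sketch of a function with that shape, the module below logs two values and passes the map through. The module name MyCallbacks and the :epoch and :loss keys are assumptions made for illustration; this page does not document which keys the map contains.

```elixir
defmodule MyCallbacks do
  # Hypothetical callback matching callback(): map in, map out.
  # The :epoch and :loss keys are assumed for illustration only.
  def log_loss(state) when is_map(state) do
    epoch = Map.get(state, :epoch, "?")
    loss = Map.get(state, :loss, "?")
    IO.puts("epoch=#{epoch} loss=#{loss}")

    # Return the map unchanged so any later callback sees the same state.
    state
  end
end
```

A function like this could be passed as callbacks: [&MyCallbacks.log_loss/1].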

dataset()

@type dataset() :: {Nx.Tensor.t(), Nx.Tensor.t()}

lr_schedule()

@type lr_schedule() ::
  {:step, float(), pos_integer(), float()}
  | {:exponential, float(), float()}
  | {:cosine, float(), float()}
  | nil
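
This page does not say what the fields of each schedule tuple mean. The sketch below shows the conventional formulas for step, exponential, and cosine decay under an assumed field order. It is not ExBurn's implementation, and the module name LrScheduleSketch is hypothetical.

```elixir
defmodule LrScheduleSketch do
  # Assumed meaning of the tuple fields (not confirmed by this page):
  #   {:step, initial_lr, step_size, gamma}
  #   {:exponential, initial_lr, gamma}
  #   {:cosine, initial_lr, min_lr}
  # `epoch` is zero-based; `total_epochs` is only used by the cosine schedule.

  # Multiply the rate by gamma once every `step_size` epochs.
  def lr({:step, lr0, step_size, gamma}, epoch, _total_epochs) do
    lr0 * :math.pow(gamma, div(epoch, step_size))
  end

  # Multiply the rate by gamma every epoch.
  def lr({:exponential, lr0, gamma}, epoch, _total_epochs) do
    lr0 * :math.pow(gamma, epoch)
  end

  # Anneal from lr0 down to lr_min along half a cosine period.
  def lr({:cosine, lr0, lr_min}, epoch, total_epochs) do
    progress = epoch / max(total_epochs, 1)
    lr_min + 0.5 * (lr0 - lr_min) * (1.0 + :math.cos(:math.pi() * progress))
  end
end
```

Under these assumptions, {:cosine, 0.001, 1.0e-5} starts at 0.001 and reaches 1.0e-5 when epoch equals total_epochs.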

model()

@type model() :: ExBurn.Model.t()

training_opts()

@type training_opts() :: [
  epochs: pos_integer(),
  batch_size: pos_integer(),
  shuffle: boolean(),
  validation_data: dataset() | nil,
  callbacks: [callback()],
  verbose: boolean(),
  lr_schedule: lr_schedule(),
  clip_norm: float() | nil,
  clip_value: float() | nil,
  weight_decay: float() | nil,
  accumulate_gradients: pos_integer(),
  accuracy: boolean(),
  nesterov: boolean(),
  warmup: pos_integer()
]

Functions

compute_gradients(model, arg, opts \\ [])

@spec compute_gradients(model(), dataset(), keyword()) :: map()

Computes gradients for a given mini-batch.

Supports multiple gradient computation methods via the :grad_method option:

  • :numerical — Central finite differences (default, slow but general)
  • :numerical_batch — Numerical gradients computed on the full batch at once (more efficient, fewer forward passes)

Parameters

  • model — The current model state
  • batch — A {inputs, targets} tuple
  • opts — Options list

Options

  • :grad_method — Gradient computation method (default: :numerical)
  • :epsilon — Finite difference step size (default: 1.0e-5)

Returns

A map of the form %{param_key => gradient_tensor}.
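
To illustrate what central finite differences compute, here is a minimal sketch over a flat list of floats rather than Nx tensors. It is not ExBurn's code; the module name FiniteDiffSketch and the loss_fun argument are hypothetical.

```elixir
defmodule FiniteDiffSketch do
  # Central finite differences for a scalar loss over a list of floats:
  #   dL/dp_i ~ (L(p + eps * e_i) - L(p - eps * e_i)) / (2 * eps)
  # `loss_fun` takes the parameter list and returns a float.
  def gradient(loss_fun, params, epsilon \\ 1.0e-5) do
    params
    |> Enum.with_index()
    |> Enum.map(fn {_value, i} ->
      # Perturb only parameter i, once in each direction.
      plus = List.update_at(params, i, &(&1 + epsilon))
      minus = List.update_at(params, i, &(&1 - epsilon))
      (loss_fun.(plus) - loss_fun.(minus)) / (2 * epsilon)
    end)
  end
end
```

Each parameter costs two evaluations of the loss, which is why this approach is general but slow for models with many parameters.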

data_loader(arg, opts \\ [])

@spec data_loader(
  dataset(),
  keyword()
) :: Enumerable.t()

Creates a data loader that yields mini-batches from a dataset.
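
The options accepted by data_loader/2 are not listed on this page. As a rough illustration of mini-batching with a partial final batch, the sketch below works on a plain list of samples instead of a {inputs, targets} tensor tuple. The module name BatchSketch and its arguments are hypothetical.

```elixir
defmodule BatchSketch do
  # Splits a list of samples into mini-batches of `batch_size`.
  # The last batch may be smaller than `batch_size` (a partial batch).
  def batches(samples, batch_size, shuffle? \\ false) do
    samples = if shuffle?, do: Enum.shuffle(samples), else: samples
    Enum.chunk_every(samples, batch_size)
  end
end
```

With five samples and a batch size of two, this yields two full batches followed by one batch holding a single sample.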

evaluate(model, arg, track_accuracy \\ false)

@spec evaluate(model(), dataset(), boolean()) :: float() | {float(), float() | nil}

Evaluates a model on a dataset.

Returns the average loss over the entire dataset as a float. When track_accuracy is true, returns {loss, accuracy} instead, where accuracy is a float, or nil if the loss function is not :cross_entropy.
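
Because the return shape depends on track_accuracy, calling code may want to normalize it. The helper below is a small sketch of that idea; the module name EvalResultSketch is hypothetical and not part of ExBurn.

```elixir
defmodule EvalResultSketch do
  # evaluate/3 is documented to return a bare float, or {loss, accuracy}
  # when accuracy tracking is requested. Both shapes become one map here.
  def normalize(loss) when is_float(loss) do
    %{loss: loss, accuracy: nil}
  end

  def normalize({loss, accuracy})
      when is_float(loss) and (is_float(accuracy) or is_nil(accuracy)) do
    %{loss: loss, accuracy: accuracy}
  end
end
```

For example, the result of ExBurn.Training.evaluate(model, val_data, true) could be piped into EvalResultSketch.normalize/1.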

fit(model, arg, opts \\ [])

@spec fit(model(), dataset(), keyword()) :: model()

Trains a model on the given dataset.

Options

  • :epochs — Number of training epochs (default: 10)
  • :batch_size — Mini-batch size (default: 32)
  • :shuffle — Shuffle training data each epoch (default: true)
  • :validation_data — Validation dataset as {inputs, targets} tuple
  • :callbacks — List of callback functions called after each epoch
  • :verbose — Print training progress (default: true)
  • :lr_schedule — Learning rate schedule (default: nil)
  • :clip_norm — Max gradient norm for clipping (default: nil)
  • :clip_value — Max absolute gradient value for clipping (default: nil)
  • :weight_decay — L2 regularization coefficient (default: nil)
  • :accumulate_gradients — Number of mini-batches to accumulate before an optimizer step, effectively multiplying batch size (default: 1)
  • :accuracy — Compute and report classification accuracy (default: false)
  • :nesterov — Use Nesterov momentum for SGD optimizer (default: false)

Returns

The trained ExBurn.Model struct with updated parameters.
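
The :clip_value and :clip_norm options correspond to two common clipping strategies. The sketch below shows both on a flat list of floats. It is an illustration under that simplification, not ExBurn's implementation; in particular, this page does not say whether the norm is taken per tensor or across all parameters. The module name ClipSketch is hypothetical.

```elixir
defmodule ClipSketch do
  # Clip by value: clamp every gradient entry into [-max_value, max_value].
  def by_value(grads, max_value) do
    Enum.map(grads, fn g -> g |> max(-max_value) |> min(max_value) end)
  end

  # Clip by norm: if the L2 norm exceeds max_norm, rescale all entries
  # by the same factor so the norm equals max_norm; otherwise leave as is.
  def by_norm(grads, max_norm) do
    norm = :math.sqrt(Enum.reduce(grads, 0.0, fn g, acc -> acc + g * g end))

    if norm > max_norm do
      scale = max_norm / norm
      Enum.map(grads, &(&1 * scale))
    else
      grads
    end
  end
end
```

Clipping by value can change the direction of the gradient, while clipping by norm preserves the direction and only shrinks the magnitude.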

profile_step(model, arg, opts \\ [])

@spec profile_step(model(), dataset(), keyword()) :: map()

Profiles a single training step, returning detailed timing for each phase.

Useful for identifying bottlenecks in the training pipeline.

Parameters

  • model — The current model state
  • batch — A {inputs, targets} tuple
  • opts — Options (same as train_step/3)

Returns

A map with the keys :forward_ms, :backward_ms, :optimizer_ms, and :total_ms.

Example

profile = ExBurn.Training.profile_step(model, batch)
IO.puts("Forward: #{profile.forward_ms}ms, Backward: #{profile.backward_ms}ms")
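
When looking for a bottleneck, each phase's share of the total is often more useful than the raw timings. The helper below sketches that calculation, assuming the four documented keys hold numbers. The module name ProfileSketch is hypothetical.

```elixir
defmodule ProfileSketch do
  # Converts the timing map into each phase's percentage of the total,
  # rounded to one decimal place. Assumes all four values are numbers.
  def shares(%{forward_ms: f, backward_ms: b, optimizer_ms: o, total_ms: total})
      when total > 0 do
    %{
      forward: Float.round(100 * f / total, 1),
      backward: Float.round(100 * b / total, 1),
      optimizer: Float.round(100 * o / total, 1)
    }
  end
end
```

The map returned by profile_step/3 could be passed directly to ProfileSketch.shares/1.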

train_step(model, arg, opts \\ [])

@spec train_step(model(), dataset(), keyword()) :: {float(), model()}

Performs a single training step: forward + backward + optimizer update.

This is a public API for custom training loops. It processes a single mini-batch and returns the updated model along with the loss value.

Parameters

  • model — The current model state
  • batch — A {inputs, targets} tuple for one mini-batch
  • opts — Options (same as fit/3 for clip_norm, clip_value, weight_decay)

Returns

{loss, updated_model} where loss is a float.
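
A custom loop built on this return shape has to thread the updated model from one batch to the next. The sketch below does this for one epoch with an injected step function, so it runs without ExBurn. The module name CustomLoopSketch and the step_fun argument are hypothetical.

```elixir
defmodule CustomLoopSketch do
  # Runs one epoch over `batches`, threading the model through `step_fun`.
  # `step_fun` takes (model, batch) and returns {loss, updated_model},
  # the same shape that train_step/3 is documented to return.
  def run_epoch(model, batches, step_fun) do
    {final_model, total_loss, count} =
      Enum.reduce(batches, {model, 0.0, 0}, fn batch, {m, loss_sum, n} ->
        {loss, updated} = step_fun.(m, batch)
        {updated, loss_sum + loss, n + 1}
      end)

    # Mean loss over the epoch, or nil when there were no batches.
    mean_loss = if count > 0, do: total_loss / count, else: nil
    {mean_loss, final_model}
  end
end

# With ExBurn, step_fun could wrap the documented train_step/3, for example:
#   step = fn m, batch -> ExBurn.Training.train_step(m, batch, clip_norm: 1.0) end
```

Injecting the step function keeps the loop easy to test with a stub and leaves room for per-batch logic such as logging or custom stopping rules.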