ExBurn.Training (ex_burn v0.1.0)


Training loop implementation for ExBurn models.

Provides a flexible training loop with support for:

  • Mini-batch training with real gradient computation
  • Multiple optimizers (Adam, SGD with momentum, RMSprop)
  • Learning rate scheduling (step, exponential, cosine)
  • Gradient clipping (by norm and by value)
  • Validation on a held-out dataset (:validation_data)
  • Callbacks (logging, early stopping, checkpointing)
  • GPU-accelerated gradient computation via Burn
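As a rough illustration of what the listed optimizers compute, the sketch below applies the textbook update rules for SGD with momentum, Adam, and RMSprop to a single scalar parameter in plain Elixir. The module name `OptimizerSketch` and its default hyperparameters are hypothetical (common literature defaults, not necessarily ExBurn's). A real training step applies such rules to whole tensors.

```elixir
defmodule OptimizerSketch do
  # Hypothetical scalar versions of the standard update rules.

  # SGD with momentum: v <- mu * v + g; p <- p - lr * v
  def sgd_momentum(param, grad, velocity, lr, mu \\ 0.9) do
    velocity = mu * velocity + grad
    {param - lr * velocity, velocity}
  end

  # Adam: bias-corrected first (m) and second (v) moment estimates.
  # t counts completed steps; it is incremented before the update.
  def adam(param, grad, {m, v, t}, lr, b1 \\ 0.9, b2 \\ 0.999, eps \\ 1.0e-8) do
    t = t + 1
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad * grad
    m_hat = m / (1 - :math.pow(b1, t))
    v_hat = v / (1 - :math.pow(b2, t))
    {param - lr * m_hat / (:math.sqrt(v_hat) + eps), {m, v, t}}
  end

  # RMSprop: a running average of squared gradients scales each step.
  def rmsprop(param, grad, sq_avg, lr, rho \\ 0.9, eps \\ 1.0e-8) do
    sq_avg = rho * sq_avg + (1 - rho) * grad * grad
    {param - lr * grad / (:math.sqrt(sq_avg) + eps), sq_avg}
  end
end
```

On the first Adam step the bias correction makes the update roughly `lr` times the sign of the gradient, which is why Adam is relatively insensitive to gradient scale early in training.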

Usage

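# axon_model is an Axon model; train_data and val_data are {inputs, targets} tuples (see dataset())
model = ExBurn.Model.compile(axon_model, loss: :cross_entropy, optimizer: :adam)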

opts = [
  epochs: 10,
  batch_size: 32,
  validation_data: val_data,
  lr_schedule: {:cosine, 0.001, 1.0e-5},
  clip_norm: 1.0,
  callbacks: [&ExBurn.Training.LoggingCallback.log/2]
]

trained_model = ExBurn.Training.fit(model, train_data, opts)

Summary

Functions

data_loader(arg, opts \\ [])
Creates a data loader that yields mini-batches from a dataset.

evaluate(model, arg)
Evaluates a model on a dataset.

fit(model, arg, opts \\ [])
Trains a model on the given dataset.

Types

callback()

@type callback() :: (map() -> map())
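A callback receives a map and returns a map. A minimal sketch of one that records the best validation loss seen so far follows. The `:val_loss` and `:best_val_loss` keys and the `CallbackSketch` module are hypothetical, since the fields ExBurn actually passes to callbacks are not documented here.

```elixir
defmodule CallbackSketch do
  # Hypothetical callback: keeps the lowest :val_loss seen under :best_val_loss.
  # These keys are assumptions, not ExBurn's documented callback state.
  def track_best(%{val_loss: loss} = state) do
    # Insert loss on first call; afterwards keep the minimum.
    Map.update(state, :best_val_loss, loss, &min(&1, loss))
  end

  # A state without :val_loss (e.g. no validation data) passes through unchanged.
  def track_best(state), do: state
end
```

Such a function would be passed as `callbacks: [&CallbackSketch.track_best/1]`, matching the one-argument `callback()` type.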

dataset()

@type dataset() :: {Nx.Tensor.t(), Nx.Tensor.t()}

lr_schedule()

@type lr_schedule() ::
  {:step, float(), pos_integer(), float()}
  | {:exponential, float(), float()}
  | {:cosine, float(), float()}
  | nil
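The tuple fields of `lr_schedule()` are not described in this section. The sketch below assumes the common interpretations: `{:step, initial_lr, step_size, gamma}`, `{:exponential, initial_lr, gamma}`, and `{:cosine, initial_lr, min_lr}` annealed over the whole run, with `nil` meaning the learning rate is left unchanged. The module `LrScheduleSketch` and its `lr_at/4` signature are hypothetical.

```elixir
defmodule LrScheduleSketch do
  # epoch is 0-based; total_epochs is used only by the cosine schedule.
  # The meaning of each tuple field is an assumption.

  # nil: keep the optimizer's base learning rate.
  def lr_at(nil, _epoch, _total_epochs, base_lr), do: base_lr

  # {:step, initial_lr, step_size, gamma}: multiply by gamma every step_size epochs.
  def lr_at({:step, initial_lr, step_size, gamma}, epoch, _total, _base) do
    initial_lr * :math.pow(gamma, div(epoch, step_size))
  end

  # {:exponential, initial_lr, gamma}: multiply by gamma every epoch.
  def lr_at({:exponential, initial_lr, gamma}, epoch, _total, _base) do
    initial_lr * :math.pow(gamma, epoch)
  end

  # {:cosine, initial_lr, min_lr}: cosine annealing from initial_lr to min_lr.
  def lr_at({:cosine, initial_lr, min_lr}, epoch, total_epochs, _base) do
    progress = epoch / total_epochs
    min_lr + 0.5 * (initial_lr - min_lr) * (1 + :math.cos(:math.pi() * progress))
  end
end
```

Under this reading, the usage example's `{:cosine, 0.001, 1.0e-5}` starts at 0.001 and decays smoothly to 1.0e-5 by the final epoch.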

training_opts()

@type training_opts() :: [
  epochs: pos_integer(),
  batch_size: pos_integer(),
  validation_data: dataset() | nil,
  callbacks: [callback()],
  verbose: boolean(),
  lr_schedule: lr_schedule(),
  clip_norm: float() | nil,
  clip_value: float() | nil
]

Functions

data_loader(arg, opts \\ [])

@spec data_loader(
  dataset(),
  keyword()
) :: Enumerable.t()

Creates a data loader that yields mini-batches from a dataset.
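The options accepted by `data_loader/2` are not listed here. As a sketch of the idea only, the function below pairs inputs with targets and lazily yields `{input_batch, target_batch}` tuples. It uses plain lists instead of Nx tensors, and the `:shuffle` option and `DataLoaderSketch` name are hypothetical.

```elixir
defmodule DataLoaderSketch do
  # Hypothetical: dataset is {inputs, targets} as two equal-length lists.
  # Returns a lazy Stream of {input_batch, target_batch}; the last batch may be smaller.
  def batches({inputs, targets}, opts \\ []) do
    batch_size = Keyword.get(opts, :batch_size, 32)
    shuffle? = Keyword.get(opts, :shuffle, false)

    # Zipping before shuffling keeps each input aligned with its target.
    pairs = Enum.zip(inputs, targets)
    pairs = if shuffle?, do: Enum.shuffle(pairs), else: pairs

    pairs
    |> Stream.chunk_every(batch_size)
    |> Stream.map(&Enum.unzip/1)
  end
end
```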

evaluate(model, arg)

@spec evaluate(ExBurn.Model.t(), dataset()) :: float()

Evaluates a model on a dataset.

Returns the average loss over the entire dataset.
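When a dataset is processed in mini-batches, the last batch may be smaller than the rest, so a plain mean of per-batch losses would overweight it. One reasonable way to get a per-example average is to weight each batch by its size. The sketch below assumes each batch reports its mean loss and its size; `EvaluateSketch` is hypothetical and not ExBurn's implementation.

```elixir
defmodule EvaluateSketch do
  # Hypothetical: batch_losses is a list of {mean_loss, batch_size} pairs.
  # Weighting by batch size yields the mean per-example loss over the dataset.
  # Assumes at least one non-empty batch (otherwise this divides by zero).
  def average_loss(batch_losses) do
    {total, count} =
      Enum.reduce(batch_losses, {0.0, 0}, fn {loss, n}, {sum, seen} ->
        {sum + loss * n, seen + n}
      end)

    total / count
  end
end
```

For batches `[{1.0, 2}, {4.0, 1}]` this gives 2.0, whereas averaging the two batch losses directly would give 2.5.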

fit(model, arg, opts \\ [])

@spec fit(ExBurn.Model.t(), dataset(), keyword()) :: ExBurn.Model.t()

Trains a model on the given dataset.

Options

  • :epochs — Number of training epochs (default: 10)
  • :batch_size — Mini-batch size (default: 32)
  • :validation_data — Validation dataset as an {inputs, targets} tuple
  • :callbacks — List of callback functions, each called after every epoch
  • :verbose — Print training progress (default: true)
  • :lr_schedule — Learning rate schedule, see lr_schedule() (default: nil)
  • :clip_norm — Maximum gradient norm; gradients whose norm exceeds it are rescaled (default: nil, no clipping)
  • :clip_value — Maximum absolute value of each gradient component; larger values are clamped (default: nil, no clipping)
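The two clipping options correspond to the standard techniques sketched below on a flat list of gradient values. Whether ExBurn computes the norm globally across all parameters or per tensor is not stated here; this hypothetical `ClipSketch` module treats all gradients as one vector.

```elixir
defmodule ClipSketch do
  # Gradients are a flat list of floats (a simplification of per-parameter tensors).

  # Clip by value: clamp every component into [-clip_value, clip_value].
  def clip_by_value(grads, clip_value) do
    Enum.map(grads, fn g -> g |> max(-clip_value) |> min(clip_value) end)
  end

  # Clip by norm: if the L2 norm exceeds clip_norm, rescale so it equals clip_norm.
  # The gradient direction is preserved; small gradients pass through unchanged.
  def clip_by_norm(grads, clip_norm) do
    norm = grads |> Enum.map(&(&1 * &1)) |> Enum.sum() |> :math.sqrt()

    if norm > clip_norm do
      scale = clip_norm / norm
      Enum.map(grads, &(&1 * scale))
    else
      grads
    end
  end
end
```

Clipping by norm keeps the update direction intact, while clipping by value can change it because each component is clamped independently.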

Returns

The trained ExBurn.Model struct with updated parameters.
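The overall structure of such a loop (epochs, mini-batches, one optimizer step per batch, callbacks after each epoch) can be sketched in plain Elixir on a toy problem. Everything below is hypothetical: `FitSketch` fits a scalar linear model `y = w * x + b` with hand-derived mean-squared-error gradients and plain SGD, the `:lr` option is invented for the sketch, and the `:params` in the map returned by the last callback are carried into the next epoch, which may differ from ExBurn's behavior.

```elixir
defmodule FitSketch do
  # Hypothetical loop skeleton: scalar linear model trained with mini-batch SGD on MSE.
  def fit({xs, ys}, opts \\ []) do
    epochs = Keyword.get(opts, :epochs, 10)
    batch_size = Keyword.get(opts, :batch_size, 32)
    lr = Keyword.get(opts, :lr, 0.1)
    callbacks = Keyword.get(opts, :callbacks, [])
    pairs = Enum.zip(xs, ys)
    batches = Enum.chunk_every(pairs, batch_size)

    Enum.reduce(1..epochs, %{w: 0.0, b: 0.0}, fn epoch, params ->
      # One optimizer step per mini-batch.
      params = Enum.reduce(batches, params, &sgd_step(&1, &2, lr))
      # After each epoch, thread a state map through every callback in order.
      state = %{epoch: epoch, params: params, loss: mse(params, pairs)}
      state = Enum.reduce(callbacks, state, fn cb, acc -> cb.(acc) end)
      state.params
    end)
  end

  # Gradients of the batch MSE with respect to w and b, then a plain SGD update.
  defp sgd_step(batch, %{w: w, b: b}, lr) do
    n = length(batch)

    {gw, gb} =
      Enum.reduce(batch, {0.0, 0.0}, fn {x, y}, {gw, gb} ->
        err = w * x + b - y
        {gw + 2 * err * x / n, gb + 2 * err / n}
      end)

    %{w: w - lr * gw, b: b - lr * gb}
  end

  # Mean squared error over the full dataset, reported to callbacks.
  defp mse(%{w: w, b: b}, pairs) do
    Enum.sum(for {x, y} <- pairs, do: :math.pow(w * x + b - y, 2)) / length(pairs)
  end
end
```

For example, training on five points from `y = 2x + 1` with `epochs: 300, batch_size: 2, lr: 0.1` should bring `w` and `b` close to 2.0 and 1.0.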