Nasty.Statistics.Neural.Trainer (Nasty v0.3.0)


Training utilities for neural models using Axon.Loop.

Provides a high-level interface for training neural networks with:

  • Multiple optimizer support (Adam, SGD, AdamW)
  • Learning rate scheduling
  • Early stopping
  • Checkpointing
  • Metric tracking
  • Gradient clipping
  • Regularization

Example

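alias Nasty.Statistics.Neural.Trainer

opts = [
  epochs: 10,
  batch_size: 32,
  optimizer: :adam,
  learning_rate: 0.001,
  early_stopping: [patience: 3, min_delta: 0.001]
]

# model_fn is a zero-arity function that builds the Axon model;
# train_data and valid_data are lists of {inputs, targets} tuples.
{:ok, trained_state} = Trainer.train(model_fn, train_data, valid_data, opts)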

Training Loop

The training loop follows this structure:

  1. Forward pass: Compute predictions from inputs
  2. Loss computation: Calculate loss between predictions and targets
  3. Backward pass: Compute gradients via backpropagation
  4. Optimization: Update model parameters
  5. Validation: Evaluate on the validation set (when valid_data is provided)
  6. Checkpointing: Save best model based on validation metrics
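As a rough illustration of where each of the six steps sits in an epoch, the following plain-Elixir sketch fits a one-parameter model y = w * x with hand-written gradients instead of Axon. The module name LoopSketch is hypothetical and this is not how Trainer is implemented; batch size is 1 and the "checkpoint" is simply kept in memory.

```elixir
defmodule LoopSketch do
  # Hypothetical illustration: fits y = w * x with plain floats and hand-written
  # gradients. No Axon; each epoch walks through the six training-loop steps.

  def train(train_data, valid_data, epochs, lr) do
    Enum.reduce(1..epochs, %{w: 0.0, best_w: nil, best_loss: nil}, fn _epoch, state ->
      # Steps 1-4, once per training example (batch size 1)
      w =
        Enum.reduce(train_data, state.w, fn {x, y}, w ->
          pred = w * x          # 1. forward pass
          err = pred - y        # 2. loss is err * err (squared error)
          grad = 2 * err * x    # 3. gradient d(err * err)/dw
          w - lr * grad         # 4. plain SGD update
        end)

      # 5. validation: mean squared error on held-out data
      val_loss = mse(w, valid_data)

      # 6. "checkpoint": remember the parameters with the lowest validation loss
      if state.best_loss == nil or val_loss < state.best_loss do
        %{w: w, best_w: w, best_loss: val_loss}
      else
        %{state | w: w}
      end
    end)
  end

  defp mse(w, data) do
    squared = Enum.map(data, fn {x, y} -> (w * x - y) * (w * x - y) end)
    Enum.sum(squared) / length(data)
  end
end
```

For example, training on points from y = 2x for 50 epochs with a learning rate of 0.01 drives best_w toward 2.0.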

Summary

Functions

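  • add_checkpointing(loop, opts \\ []): Adds checkpointing to a training loop.
  • add_early_stopping(loop, opts \\ []): Adds early stopping to a training loop.
  • create_training_loop(model, config): Creates a training loop with custom configuration.
  • evaluate(model, state, test_data, opts \\ []): Evaluates a trained model on test data.
  • get_optimizer(optimizer, opts \\ []): Returns an optimizer function.
  • train(model_fn, train_data, valid_data \\ nil, opts \\ []): Trains a neural model using the provided training and validation data.
  • training_config(opts \\ []): Creates a default training configuration.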

Types

training_data()

@type training_data() :: [{inputs :: map(), targets :: map()}]

validation_data()

@type validation_data() :: [{inputs :: map(), targets :: map()}]
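A value of either type is just a list of {inputs, targets} tuples of maps. The sketch below shows one way such a list might be built and then grouped into batches of :batch_size examples. The module DataSketch and the map keys :tokens and :labels are hypothetical; the types only require that inputs and targets are maps.

```elixir
defmodule DataSketch do
  # Hypothetical helpers. The keys :tokens and :labels are illustrative; the
  # training_data()/validation_data() types only require maps.

  # Pairs each input sequence with its targets as {inputs, targets} tuples.
  def to_training_data(sentences, tag_sequences) do
    sentences
    |> Enum.zip(tag_sequences)
    |> Enum.map(fn {tokens, tags} -> {%{tokens: tokens}, %{labels: tags}} end)
  end

  # Splits the examples into batches of at most batch_size; the last may be shorter.
  def batches(data, batch_size), do: Enum.chunk_every(data, batch_size)
end
```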

Functions

add_checkpointing(loop, opts \\ [])

Adds checkpointing to a training loop.

Parameters

  • loop - Training loop
  • opts - Checkpointing options

Returns

Loop with checkpointing configured.
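The checkpointing options are not spelled out here, so the following is only a sketch of the "save the best model" idea from the training-loop overview: state is written to disk only when the validation loss improves. The module CheckpointSketch, the file name best_checkpoint.bin, and the use of :erlang.term_to_binary/1 are all assumptions, not the Trainer's actual storage format.

```elixir
defmodule CheckpointSketch do
  # Hypothetical sketch: persists the state only when validation loss improves.
  # File layout and term_to_binary encoding are assumptions.

  @file_name "best_checkpoint.bin"

  # Returns {:saved, new_best} or {:skipped, best_loss}.
  def maybe_save(state, val_loss, best_loss, dir) do
    if best_loss == nil or val_loss < best_loss do
      File.mkdir_p!(dir)
      File.write!(Path.join(dir, @file_name), :erlang.term_to_binary(state))
      {:saved, val_loss}
    else
      {:skipped, best_loss}
    end
  end

  # Reads the most recently saved best state back from disk.
  def load(dir) do
    dir
    |> Path.join(@file_name)
    |> File.read!()
    |> :erlang.binary_to_term()
  end
end
```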

add_early_stopping(loop, opts \\ [])

Adds early stopping to a training loop.

Parameters

  • loop - Training loop
  • opts - Early stopping options

Returns

Loop with early stopping configured.
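The module example passes early_stopping: [patience: 3, min_delta: 0.001]. A common reading of those two parameters, assumed here rather than confirmed for this Trainer, is: an epoch counts as an improvement only if the validation loss falls more than min_delta below the best loss so far, and training stops after patience consecutive epochs without improvement. EarlyStopSketch is a hypothetical module that replays that rule over a list of per-epoch losses.

```elixir
defmodule EarlyStopSketch do
  # Hypothetical patience/min_delta rule (assumed semantics).
  # Returns the 1-based epoch at which training would stop, or :completed.
  def stop_epoch(val_losses, patience, min_delta) do
    result =
      val_losses
      |> Enum.with_index(1)
      |> Enum.reduce_while({nil, 0}, fn {loss, epoch}, {best, waited} ->
        cond do
          # Improvement: record the new best loss and reset the counter
          best == nil or best - loss > min_delta -> {:cont, {loss, 0}}
          # Patience exhausted: stop at this epoch
          waited + 1 >= patience -> {:halt, {:stopped, epoch}}
          # No improvement yet, keep waiting
          true -> {:cont, {best, waited + 1}}
        end
      end)

    case result do
      {:stopped, epoch} -> epoch
      _ -> :completed
    end
  end
end
```

With losses [1.0, 0.8, 0.7995, 0.7993, 0.7999], patience 3 and min_delta 0.001, epochs 3 to 5 each improve by less than 0.001, so training stops at epoch 5.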

create_training_loop(model, config)

Creates a training loop with custom configuration.

Parameters

  • model - Axon model
  • config - Training configuration

Returns

Axon.Loop configured for training.
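Because create_training_loop/2 returns a loop and both add_early_stopping/2 and add_checkpointing/2 take a loop and return one, they are designed to chain with the pipe operator. The sketch below models that shape with a plain map standing in for an Axon.Loop; LoopComposeSketch and its :handlers field are hypothetical and exist only to show the composition pattern.

```elixir
defmodule LoopComposeSketch do
  # Hypothetical stand-in for an Axon.Loop: a plain map with a handler list.
  # Each add_* function takes a loop and returns an extended loop.

  def create_training_loop(model, config) do
    %{model: model, config: config, handlers: []}
  end

  def add_early_stopping(loop, opts \\ []) do
    %{loop | handlers: loop.handlers ++ [{:early_stopping, opts}]}
  end

  def add_checkpointing(loop, opts \\ []) do
    %{loop | handlers: loop.handlers ++ [{:checkpoint, opts}]}
  end
end
```

Usage: create_training_loop(model, config) |> add_early_stopping(patience: 2) |> add_checkpointing() yields a loop whose handlers are applied in the order they were added.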

evaluate(model, state, test_data, opts \\ [])

@spec evaluate(Axon.t(), map(), list(), keyword()) :: {:ok, map()} | {:error, term()}

Evaluates a trained model on test data.

Parameters

  • model - Axon model
  • state - Trained model state (parameters)
  • test_data - Test dataset
  • opts - Evaluation options

Returns

  • {:ok, metrics} - Evaluation metrics
  • {:error, reason} - Evaluation error
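As an illustration of the {:ok, metrics} | {:error, reason} contract, here is a sketch that computes accuracy (the default metric listed for train/4) from {prediction, target} label pairs. EvalSketch and the :empty_test_data reason are hypothetical; the real evaluate/4 runs the Axon model on test_data first.

```elixir
defmodule EvalSketch do
  # Hypothetical sketch: accuracy over {prediction, target} pairs,
  # mirroring the {:ok, metrics} | {:error, reason} return shape.
  def evaluate([]), do: {:error, :empty_test_data}

  def evaluate(pairs) do
    correct = Enum.count(pairs, fn {pred, target} -> pred == target end)
    {:ok, %{accuracy: correct / length(pairs)}}
  end
end
```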

get_optimizer(optimizer, opts \\ [])

Returns an optimizer function.

Parameters

  • optimizer - Optimizer type: :adam, :sgd, or :adamw
  • opts - Optimizer options

Returns

Optimizer function.
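To make the difference between the three optimizers concrete, the sketch below implements scalar versions of SGD, Adam, and AdamW, each returned as an update function (param, grad, state) -> {new_param, new_state}. This is not the library's API: OptimizerSketch, the :weight_decay option and its 0.01 default are hypothetical, and the betas and epsilon are the usual Adam defaults. AdamW differs from Adam only in applying weight decay directly to the parameter rather than through the gradient.

```elixir
defmodule OptimizerSketch do
  # Hypothetical scalar optimizers; not the Trainer's or Axon's API.
  # Each returns fn param, grad, state -> {new_param, new_state} end.

  def get_optimizer(optimizer, opts \\ [])

  def get_optimizer(:sgd, opts) do
    lr = Keyword.get(opts, :learning_rate, 0.001)
    fn param, grad, state -> {param - lr * grad, state} end
  end

  def get_optimizer(:adam, opts), do: adam(opts, 0.0)

  # :weight_decay and its default are assumptions for this sketch
  def get_optimizer(:adamw, opts), do: adam(opts, Keyword.get(opts, :weight_decay, 0.01))

  defp adam(opts, weight_decay) do
    lr = Keyword.get(opts, :learning_rate, 0.001)
    {b1, b2, eps} = {0.9, 0.999, 1.0e-8}

    fn param, grad, state ->
      %{m: m, v: v, t: t} = state || %{m: 0.0, v: 0.0, t: 0}
      t = t + 1
      # Exponential moving averages of the gradient and squared gradient
      m = b1 * m + (1 - b1) * grad
      v = b2 * v + (1 - b2) * grad * grad
      # Bias-corrected moment estimates
      m_hat = m / (1 - :math.pow(b1, t))
      v_hat = v / (1 - :math.pow(b2, t))
      # Decoupled weight decay (zero for plain Adam)
      new_param = param - lr * (m_hat / (:math.sqrt(v_hat) + eps) + weight_decay * param)
      {new_param, %{m: m, v: v, t: t}}
    end
  end
end
```

On the first Adam step the bias correction makes the update roughly learning_rate * sign(grad), so a parameter of 1.0 with gradient 1.0 and learning rate 0.001 moves to about 0.999.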

train(model_fn, train_data, valid_data \\ nil, opts \\ [])

@spec train(
  model_fn :: (-> Axon.t()),
  train_data :: training_data(),
  valid_data :: validation_data() | nil,
  opts :: keyword()
) :: {:ok, map()} | {:error, term()}

Trains a neural model using the provided training and validation data.

Parameters

  • model_fn - Function that builds the Axon model
  • train_data - Training dataset (list of {inputs, targets} tuples)
  • valid_data - Validation dataset (optional)
  • opts - Training options

Options

  • :epochs - Number of training epochs (default: 10)
  • :batch_size - Batch size for training (default: 32)
  • :optimizer - Optimizer to use: :adam, :sgd, :adamw (default: :adam)
  • :learning_rate - Learning rate (default: 0.001)
  • :loss - Loss function: :cross_entropy, :mean_squared_error, :crf (default: :cross_entropy)
  • :metrics - Additional metrics to track (default: [:accuracy])
  • :early_stopping - Early stopping options, a keyword list such as [patience: 3, min_delta: 0.001] (default: nil)
  • :checkpoint_dir - Directory to save checkpoints (default: nil)
  • :gradient_clip - Gradient clipping value (default: nil)
  • :dropout - Dropout rate (default: 0.0)
  • :l2_regularization - L2 regularization coefficient, lambda (default: 0.0)
  • :lr_schedule - Learning rate schedule (default: nil)
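The options above do not state exactly how :gradient_clip, :l2_regularization, and :lr_schedule are applied. The scalar sketch below shows one common interpretation of each, all of them assumptions: clip-by-value for gradient clipping, adding lambda * w to each gradient (the derivative of (lambda / 2) * w^2) for L2 regularization, and a step-decay schedule. RegularizationSketch and step_decay/4 are hypothetical names.

```elixir
defmodule RegularizationSketch do
  # Hypothetical helpers; the Trainer's exact semantics are not documented here.

  # Clip-by-value: force each gradient into [-clip, clip]; nil disables clipping.
  def clip(grads, nil), do: grads
  def clip(grads, clip), do: Enum.map(grads, fn g -> max(-clip, min(clip, g)) end)

  # L2 penalty gradient: d/dw of (lambda / 2) * w^2 is lambda * w.
  def add_l2(grads, weights, lambda) do
    grads
    |> Enum.zip(weights)
    |> Enum.map(fn {g, w} -> g + lambda * w end)
  end

  # Step decay: multiply base_lr by factor once every `every` epochs (0-based epoch).
  def step_decay(base_lr, epoch, factor, every) do
    base_lr * :math.pow(factor, div(epoch, every))
  end
end
```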

Returns

  • {:ok, trained_state} - Trained model state with parameters
  • {:error, reason} - Training error

training_config(opts \\ [])

Creates a default training configuration.

Parameters

  • opts - Optional overrides

Returns

Map with training configuration.
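A minimal sketch of what training_config/1 might do, assuming it merges caller overrides into the defaults documented for train/4 and returns a map. ConfigSketch is a hypothetical module and the real function may store additional or differently named keys.

```elixir
defmodule ConfigSketch do
  # Hypothetical: defaults copied from the train/4 option list above.
  @defaults [
    epochs: 10,
    batch_size: 32,
    optimizer: :adam,
    learning_rate: 0.001,
    loss: :cross_entropy,
    metrics: [:accuracy],
    early_stopping: nil,
    checkpoint_dir: nil,
    gradient_clip: nil,
    dropout: 0.0,
    l2_regularization: 0.0,
    lr_schedule: nil
  ]

  # Overrides win over defaults; the result is a map keyed by option name.
  def training_config(opts \\ []) do
    @defaults |> Keyword.merge(opts) |> Map.new()
  end
end
```

For example, ConfigSketch.training_config(epochs: 5) keeps every default except :epochs.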