Nasty.Statistics.Neural.Trainer (Nasty v0.3.0)
Training utilities for neural models using Axon.Loop.
Provides a high-level interface for training neural networks with:
- Multiple optimizer support (Adam, SGD, AdamW)
- Learning rate scheduling
- Early stopping
- Checkpointing
- Metric tracking
- Gradient clipping
- Regularization
Example
opts = [
epochs: 10,
batch_size: 32,
optimizer: :adam,
learning_rate: 0.001,
early_stopping: [patience: 3, min_delta: 0.001]
]
{:ok, trained_model} = Trainer.train(model, train_data, valid_data, opts)

Training Loop
The training loop follows this structure:
- Forward pass: Compute predictions from inputs
- Loss computation: Calculate loss between predictions and targets
- Backward pass: Compute gradients via backpropagation
- Optimization: Update model parameters
- Validation: Evaluate on validation set
- Checkpointing: Save best model based on validation metrics
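The six steps above map directly onto an Axon.Loop pipeline. As a rough sketch (not Trainer's exact implementation — the model shape and data variables here are illustrative), the same loop built with plain Axon looks like this:

```elixir
# Illustrative model; Trainer accepts any Axon model.
model =
  Axon.input("features", shape: {nil, 16})
  |> Axon.dense(2, activation: :softmax)

# train_data / valid_data are enumerables of {inputs, targets} batches.
trained_state =
  model
  # Forward pass, loss computation, backward pass, and parameter
  # updates are bundled by Axon.Loop.trainer/3.
  |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
  |> Axon.Loop.metric(:accuracy)
  # Validation runs on valid_data after each epoch.
  |> Axon.Loop.validate(model, valid_data)
  |> Axon.Loop.run(train_data, %{}, epochs: 10)
```

Checkpointing and early stopping hook into the same pipeline; see the corresponding functions below.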
Summary
Functions
Adds checkpointing to a training loop.
Adds early stopping to a training loop.
Creates a training loop with custom configuration.
Evaluates a trained model on test data.
Returns an optimizer function.
Trains a neural model using the provided training and validation data.
Creates default training configuration.
Types
Functions
Adds checkpointing to a training loop.
Parameters
- loop - Training loop
- opts - Checkpointing options
Returns
Loop with checkpointing configured.
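A minimal sketch of wiring this in, assuming it delegates to Axon's built-in `Axon.Loop.checkpoint/2` (the option values shown are illustrative):

```elixir
# Save loop state at the end of every epoch; Axon writes checkpoint
# files under its default checkpoints directory unless configured.
loop
|> Axon.Loop.checkpoint(event: :epoch_completed)
```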
Adds early stopping to a training loop.
Parameters
- loop - Training loop
- opts - Early stopping options
Returns
Loop with early stopping configured.
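Assuming this wraps Axon's `Axon.Loop.early_stop/3`, a sketch of the equivalent call (the monitored metric name is an assumption):

```elixir
# Stop when "validation_loss" has not improved (decreased) for 3 epochs.
loop
|> Axon.Loop.early_stop("validation_loss", mode: :min, patience: 3)
```

This matches the `early_stopping: [patience: 3, min_delta: 0.001]` option shown in the module example above.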
Creates a training loop with custom configuration.
Parameters
- model - Axon model
- config - Training configuration
Returns
Axon.Loop configured for training.
Evaluates a trained model on test data.
Parameters
- model - Axon model
- state - Trained model state (parameters)
- test_data - Test dataset
- opts - Evaluation options
Returns
- {:ok, metrics} - Evaluation metrics
- {:error, reason} - Evaluation error
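Evaluation can be expressed with Axon's evaluator loop; a hedged sketch of what this might do internally (not necessarily Trainer's exact code):

```elixir
# Run the trained parameters over test_data and collect accuracy.
metrics =
  model
  |> Axon.Loop.evaluator()
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(test_data, state)
```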
Returns an optimizer function.
Parameters
- optimizer - Optimizer type
- opts - Optimizer options
Returns
Optimizer function.
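One plausible shape for this dispatch, assuming the optimizers come from Polaris (which ships Axon's optimizers in recent Axon versions; the clause names and defaults here are assumptions):

```elixir
# Map the optimizer atom to a Polaris update function.
defp optimizer_fn(:adam, opts),
  do: Polaris.Optimizers.adam(learning_rate: opts[:learning_rate] || 1.0e-3)

defp optimizer_fn(:sgd, opts),
  do: Polaris.Optimizers.sgd(learning_rate: opts[:learning_rate] || 1.0e-3)

defp optimizer_fn(:adamw, opts),
  do: Polaris.Optimizers.adamw(learning_rate: opts[:learning_rate] || 1.0e-3)
```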
@spec train(
  model_fn :: (-> Axon.t()),
  train_data :: training_data(),
  valid_data :: validation_data() | nil,
  opts :: keyword()
) :: {:ok, map()} | {:error, term()}
Trains a neural model using the provided training and validation data.
Parameters
- model_fn - Function that builds the Axon model
- train_data - Training dataset (list of {inputs, targets} tuples)
- valid_data - Validation dataset (optional)
- opts - Training options
Options
- :epochs - Number of training epochs (default: 10)
- :batch_size - Batch size for training (default: 32)
- :optimizer - Optimizer to use: :adam, :sgd, :adamw (default: :adam)
- :learning_rate - Learning rate (default: 0.001)
- :loss - Loss function: :cross_entropy, :mean_squared_error, :crf (default: :cross_entropy)
- :metrics - Additional metrics to track (default: [:accuracy])
- :early_stopping - Early stopping config (default: nil)
- :checkpoint_dir - Directory to save checkpoints (default: nil)
- :gradient_clip - Gradient clipping value (default: nil)
- :dropout - Dropout rate (default: 0.0)
- :l2_regularization - L2 regularization lambda (default: 0.0)
- :lr_schedule - Learning rate schedule (default: nil)
Returns
- {:ok, trained_state} - Trained model state with parameters
- {:error, reason} - Training error
Creates default training configuration.
Parameters
opts- Optional overrides
Returns
Map with training configuration.
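A sketch of what the returned configuration might look like, using the documented option defaults (the exact map shape and merge behavior are assumptions):

```elixir
# Documented defaults, overridable via opts.
defaults = %{
  epochs: 10,
  batch_size: 32,
  optimizer: :adam,
  learning_rate: 0.001,
  loss: :cross_entropy,
  metrics: [:accuracy],
  dropout: 0.0,
  l2_regularization: 0.0
}

config = Map.merge(defaults, Map.new(opts))
```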