Training Optimization Guide


Table of Contents

  1. Understanding Gradient Computation
  2. Choosing an Optimizer
  3. Learning Rate Strategies
  4. Gradient Clipping
  5. Weight Decay
  6. Gradient Accumulation
  7. Batch Size Selection
  8. Memory Optimization
  9. Common Problems and Solutions
  10. Performance Benchmarks

Understanding Gradient Computation

Current Limitation: Numerical Gradients

ExBurn v0.1.0 uses numerical differentiation (finite differences) to compute gradients. This is the main performance bottleneck.

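Central differences:  $\frac{\partial L}{\partial w} \approx \frac{L(w + \varepsilon) - L(w - \varepsilon)}{2\varepsilon}$
One-sided:            $\frac{\partial L}{\partial w} \approx \frac{L(w + \varepsilon) - L(w)}{\varepsilon}$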

Impact: For a model with N scalar parameters, central differences require 2N forward passes per mini-batch, so a 100K-parameter model needs 200K forward passes for every batch.

Choosing a Gradient Method

# Default: central differences (more accurate, slower)
grads = ExBurn.Training.compute_gradients(model, {x, y}, grad_method: :numerical)

# Faster: one-sided differences (less accurate, ~2x faster)
grads = ExBurn.Training.compute_gradients(model, {x, y}, grad_method: :numerical_batch)
| Method | Forward passes per gradient | Error order | When to use |
|---|---|---|---|
| :numerical | 2N | O(ε²) | Small models, high accuracy needed |
| :numerical_batch | N+1 | O(ε) | Larger models, speed matters more |
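To make the 2N versus N+1 cost concrete, here is a minimal sketch of both finite-difference schemes in plain Elixir, applied to a loss function over a list of scalar parameters. It does not call ExBurn: the FiniteDiffSketch module, its function names, and the default ε = 1.0e-4 are hypothetical choices for illustration, and the returned count is the number of loss evaluations each scheme performs.

```elixir
defmodule FiniteDiffSketch do
  # Hypothetical helper for illustration; not part of ExBurn.
  # Central differences: two loss evaluations per parameter (2N in total).
  # Returns {gradients, number_of_loss_evaluations}.
  def central_grad(loss_fn, params, eps \\ 1.0e-4) do
    grads =
      for {_w, i} <- Enum.with_index(params) do
        plus = List.update_at(params, i, &(&1 + eps))
        minus = List.update_at(params, i, &(&1 - eps))
        (loss_fn.(plus) - loss_fn.(minus)) / (2 * eps)
      end

    {grads, 2 * length(params)}
  end

  # One-sided differences: one shared evaluation at `params` plus one
  # evaluation per parameter (N + 1 in total).
  def one_sided_grad(loss_fn, params, eps \\ 1.0e-4) do
    base = loss_fn.(params)

    grads =
      for {_w, i} <- Enum.with_index(params) do
        plus = List.update_at(params, i, &(&1 + eps))
        (loss_fn.(plus) - base) / eps
      end

    {grads, length(params) + 1}
  end
end

# Toy loss L(w1, w2) = w1^2 + 3 * w2, whose exact gradient at (1, 2) is [2, 3].
loss = fn [w1, w2] -> w1 * w1 + 3 * w2 end
{central, central_evals} = FiniteDiffSketch.central_grad(loss, [1.0, 2.0])
{one_sided, one_sided_evals} = FiniteDiffSketch.one_sided_grad(loss, [1.0, 2.0])
IO.inspect({central, central_evals}, label: "central")
IO.inspect({one_sided, one_sided_evals}, label: "one-sided")
```

The central estimate of the w1 component is essentially exact here, while the one-sided estimate is off by about ε, matching the O(ε²) versus O(ε) error orders in the table.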

When Autodiff Arrives (v0.3.0)

Burn's Autodiff backend will compute exact gradients (up to floating-point precision) with one forward pass and one backward pass, regardless of parameter count. The number of passes no longer grows with the number of parameters:

Numerical (v0.1.0):  200K forward passes for 100K params
Autodiff (v0.3.0):   1 forward + 1 backward pass for any model size

Recommendation: For now, keep models small (< 50K params) for training. Use larger models only for inference.

Choosing an Optimizer

Adam (Default)

A strong general-purpose default. Adapts the effective learning rate of each parameter individually.

ExBurn.Model.compile(model, optimizer: :adam, learning_rate: 0.001)
# beta1=0.9, beta2=0.999, epsilon=1e-8

When to use: Default choice for most tasks. Works well with default hyperparameters.

Tips:

  • learning_rate: 0.001 is a good starting point
  • Reduce to 0.0001 if training is unstable
  • Increase to 0.01 if convergence is very slow
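To show what per-parameter adaptation means in practice, here is a sketch of a single Adam step on a list of scalar parameters, using the default hyperparameters listed above. It is the textbook Adam update with bias correction written in plain Elixir, not ExBurn's implementation; the AdamSketch module and its state map are hypothetical.

```elixir
defmodule AdamSketch do
  # Textbook Adam update with the defaults listed above.
  @beta1 0.9
  @beta2 0.999
  @epsilon 1.0e-8

  # m (first moment) and v (second moment) hold one value per parameter,
  # which is where Adam's 2x-params memory overhead comes from.
  def init(params) do
    zeros = Enum.map(params, fn _ -> 0.0 end)
    %{m: zeros, v: zeros, t: 0}
  end

  def step(params, grads, %{m: m, v: v, t: t}, lr \\ 0.001) do
    t = t + 1

    updated =
      Enum.zip([params, grads, m, v])
      |> Enum.map(fn {p, g, m_i, v_i} ->
        m_i = @beta1 * m_i + (1 - @beta1) * g
        v_i = @beta2 * v_i + (1 - @beta2) * g * g
        # Bias correction offsets the zero initialization of m and v.
        m_hat = m_i / (1 - :math.pow(@beta1, t))
        v_hat = v_i / (1 - :math.pow(@beta2, t))
        {p - lr * m_hat / (:math.sqrt(v_hat) + @epsilon), m_i, v_i}
      end)

    {Enum.map(updated, &elem(&1, 0)),
     %{m: Enum.map(updated, &elem(&1, 1)), v: Enum.map(updated, &elem(&1, 2)), t: t}}
  end
end

# Two parameters whose gradients differ by a factor of 10,000.
params = [1.0, 1.0]
{new_params, _state} = AdamSketch.step(params, [100.0, 0.01], AdamSketch.init(params))
IO.inspect(new_params)
```

On the first step both parameters move by roughly the learning rate (to about 0.999), because each update is divided by that parameter's own running gradient magnitude.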

SGD with Momentum

Can achieve better generalization than Adam with proper tuning.

ExBurn.Model.compile(model, optimizer: :sgd, learning_rate: 0.01)
# momentum=0.9

When to use: When you need maximum generalization and have time to tune.

Tips:

  • Requires higher learning rate than Adam (typically 0.01–0.1)
  • Use Nesterov momentum for faster convergence:
    ExBurn.Training.fit(model, data, nesterov: true)
  • Combine with cosine annealing LR schedule for best results
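A sketch of one SGD-with-momentum step, with and without Nesterov momentum, on a list of scalar parameters. It uses the common formulation v = momentum * v + g (the convention PyTorch uses); ExBurn's internal formulation is not documented here and may differ, and SgdMomentumSketch is a hypothetical module.

```elixir
defmodule SgdMomentumSketch do
  # One SGD-with-momentum step on a list of scalar parameters.
  # `velocity` holds one value per parameter (the 1x-params memory overhead).
  # Uses v = momentum * v + g, a PyTorch-style convention (an assumption
  # about ExBurn, which may scale or damp the gradient differently).
  def step(params, grads, velocity, lr \\ 0.01, opts \\ []) do
    momentum = Keyword.get(opts, :momentum, 0.9)
    nesterov = Keyword.get(opts, :nesterov, false)

    Enum.zip([params, grads, velocity])
    |> Enum.map(fn {p, g, v} ->
      v = momentum * v + g
      # Nesterov steps along the look-ahead direction g + momentum * v.
      direction = if nesterov, do: g + momentum * v, else: v
      {p - lr * direction, v}
    end)
    |> Enum.unzip()
  end
end

# Two classic-momentum steps with a constant gradient of 1.0.
{p1, v1} = SgdMomentumSketch.step([1.0], [1.0], [0.0])
{p2, _v2} = SgdMomentumSketch.step(p1, [1.0], v1)
# One Nesterov step from the same starting point.
{pn, _vn} = SgdMomentumSketch.step([1.0], [1.0], [0.0], 0.01, nesterov: true)
IO.inspect({p1, p2, pn})
```

With a constant gradient the velocity builds up (1.0, then 1.9), so the second step is larger than the first; Nesterov applies that look-ahead on the very first step.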

RMSprop

Good for recurrent networks and non-stationary objectives.

ExBurn.Model.compile(model, optimizer: :rmsprop, learning_rate: 0.001)
# decay=0.9, epsilon=1e-8

When to use: RNNs, LSTMs, or when Adam diverges.
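A sketch of one RMSprop step using the defaults above (decay = 0.9, epsilon = 1e-8). Like Adam, it divides each gradient by a running root-mean-square of that parameter's past gradients, but without momentum or bias correction. It is plain Elixir with a hypothetical RmspropSketch module, not ExBurn's implementation.

```elixir
defmodule RmspropSketch do
  @decay 0.9
  @epsilon 1.0e-8

  # `cache` is a running average of squared gradients, one value per parameter.
  def step(params, grads, cache, lr \\ 0.001) do
    Enum.zip([params, grads, cache])
    |> Enum.map(fn {p, g, c} ->
      c = @decay * c + (1 - @decay) * g * g
      # Divide by the root of the running average: large recent gradients
      # shrink the step, small ones enlarge it.
      {p - lr * g / (:math.sqrt(c) + @epsilon), c}
    end)
    |> Enum.unzip()
  end
end

{params, cache} = RmspropSketch.step([0.0], [2.0], [0.0])
IO.inspect({params, cache})
```

Because the cache starts at zero and is not bias-corrected, this first step is about 3x the learning rate (0.002 / sqrt(0.4) ≈ 0.0032).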

Optimizer Comparison

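| Optimizer | Convergence speed | Generalization | Tuning effort | Optimizer state memory |
|---|---|---|---|---|
| Adam | Fast | Good | Low | 2x params (m + v) |
| SGD + Momentum | Medium | Best | High | 1x params (velocity) |
| RMSprop | Medium | Good | Medium | 1x params (cache) |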

Learning Rate Strategies

Fixed Learning Rate

# No schedule — use constant learning rate
ExBurn.Model.compile(model, learning_rate: 0.001)

Step Decay

Reduce LR by a factor every N epochs. Good for long training runs.

# Halve the learning rate every 10 epochs
lr_schedule: {:step, 0.001, 10, 0.5}

Exponential Decay

Smooth decay. Good for medium-length training.

# Multiply LR by 0.95 each epoch
lr_schedule: {:exponential, 0.001, 0.95}

Cosine Annealing

Smoothly decay from base_lr to min_lr following a cosine curve. Often gives the best results.

# Decay from 0.001 to 0.00001 over the training run
lr_schedule: {:cosine, 0.001, 1.0e-5}
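The three schedule tuples can be read as formulas. The sketch below is one plausible interpretation in plain Elixir, not ExBurn's scheduler: LrScheduleSketch is hypothetical, epochs are counted from 0, and because the {:cosine, base_lr, min_lr} tuple carries no training length, the total number of epochs is passed as an extra argument (an assumption).

```elixir
defmodule LrScheduleSketch do
  # Hypothetical helpers that interpret the schedule tuples shown above.
  # `epoch` is 0-based. The cosine tuple does not carry the total number of
  # epochs, so it is passed in explicitly here (an assumption).
  def lr({:step, base, every, factor}, epoch, _total_epochs),
    do: base * :math.pow(factor, div(epoch, every))

  def lr({:exponential, base, gamma}, epoch, _total_epochs),
    do: base * :math.pow(gamma, epoch)

  def lr({:cosine, base, min_lr}, epoch, total_epochs),
    do: min_lr + 0.5 * (base - min_lr) * (1 + :math.cos(:math.pi() * epoch / total_epochs))
end

schedules = [{:step, 0.001, 10, 0.5}, {:exponential, 0.001, 0.95}, {:cosine, 0.001, 1.0e-5}]

for schedule <- schedules do
  lrs = for epoch <- [0, 10, 25, 50], do: LrScheduleSketch.lr(schedule, epoch, 50)
  IO.inspect(lrs, label: Atom.to_string(elem(schedule, 0)))
end
```

Under this reading, step decay halves at epochs 10, 20, ..., exponential decay shrinks by 5% every epoch, and cosine annealing starts at base_lr and lands exactly on min_lr at the final epoch.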

Learning Rate Schedule Comparison

[Figure: learning rate (0.001 down to 0.00001) versus epochs. Step decay shows sudden drops; cosine annealing decays smoothly.]

Tips

  • Start with Adam + cosine annealing for best results
  • If loss oscillates, reduce the base learning rate
  • If convergence is too slow, increase the base learning rate
  • Use warmup (planned) for large batch sizes

Gradient Clipping

Prevents exploding gradients, which can cause NaN loss.

Clip by Norm

Scales all gradients so their total norm doesn't exceed a threshold:

# If ||gradients||_2 > 1.0, scale them down
clip_norm: 1.0

When to use: Always enable for recurrent networks. Recommended for deep networks.
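A sketch of global norm clipping over a flat list of gradient values, in plain Elixir. ClipByNormSketch is hypothetical; a real implementation would compute the norm across all parameter tensors rather than a single list.

```elixir
defmodule ClipByNormSketch do
  # If the L2 norm of all gradients exceeds max_norm, scale every element by
  # the same factor max_norm / norm. The gradient's direction is preserved.
  def clip_by_norm(grads, max_norm) do
    norm = grads |> Enum.map(&(&1 * &1)) |> Enum.sum() |> :math.sqrt()

    if norm > max_norm do
      Enum.map(grads, &(&1 * max_norm / norm))
    else
      grads
    end
  end
end

# Norm 5.0 exceeds 1.0, so [3.0, 4.0] is scaled down to [0.6, 0.8].
IO.inspect(ClipByNormSketch.clip_by_norm([3.0, 4.0], 1.0))
# Norm 0.5 is under the threshold, so the gradients pass through unchanged.
IO.inspect(ClipByNormSketch.clip_by_norm([0.3, 0.4], 1.0))
```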

Clip by Value

Clips each gradient element to a range:

# Clip each gradient to [-5.0, 5.0]
clip_value: 5.0

When to use: As a safety net alongside norm clipping.
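Value clipping is an element-wise clamp. A sketch with a hypothetical ClipByValueSketch module, which also hints at why it is only a safety net: clamping some elements but not others changes the gradient's direction, whereas norm clipping preserves it.

```elixir
defmodule ClipByValueSketch do
  # Clamp each gradient element into [-limit, limit].
  def clip_by_value(grads, limit) do
    Enum.map(grads, fn g -> g |> max(-limit) |> min(limit) end)
  end
end

# Only the out-of-range elements change: [-5.0, 2.0, 5.0]
IO.inspect(ClipByValueSketch.clip_by_value([-10.0, 2.0, 7.5], 5.0))
```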

Tips

  • clip_norm: 1.0 is a good default
  • If you see NaN loss, enable clipping immediately
  • Clipping doesn't prevent vanishing gradients — use residual connections for that

Weight Decay

L2 regularization that penalizes large weights, improving generalization:

ExBurn.Model.compile(model, weight_decay: 1.0e-4)

This adds weight_decay * param to each gradient before the optimizer step.
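A sketch of that gradient adjustment on flat lists of gradients and parameters, in plain Elixir (hypothetical WeightDecaySketch module, not ExBurn code).

```elixir
defmodule WeightDecaySketch do
  # Adds weight_decay * param to each gradient; the optimizer then steps
  # with the adjusted gradients.
  def decay_grads(grads, params, weight_decay) do
    Enum.zip(grads, params)
    |> Enum.map(fn {g, p} -> g + weight_decay * p end)
  end
end

WeightDecaySketch.decay_grads([0.5, -0.2], [2.0, 0.0], 1.0e-4)
|> IO.inspect()
```

With plain SGD, the extra term shrinks each weight by lr * weight_decay * param on every step, which is where the name "weight decay" comes from.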

Tips

  • 1.0e-4 is a good default for most tasks
  • 1.0e-5 for small datasets (less regularization)
  • 1.0e-3 for large models that overfit
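  • AdamW (decoupled weight decay) is not yet implemented. With standard Adam, the weight_decay * param term is added to the gradient and then rescaled by Adam's per-parameter adaptive learning rate, so parameters with a large gradient history are regularized less than others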

Gradient Accumulation

Simulates a larger batch size by accumulating gradients across multiple mini-batches:

# Effective batch size = 32 * 4 = 128
ExBurn.Training.fit(model, data,
  batch_size: 32,
  accumulate_gradients: 4
)
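A sketch of what accumulation computes, assuming gradients from equal-sized micro-batches are averaged before the single optimizer step (whether ExBurn averages or sums them is not stated here). With a per-example mean loss, the average of four 32-example gradients equals the gradient of the combined 128-example batch. AccumulationSketch is hypothetical.

```elixir
defmodule AccumulationSketch do
  # `micro_batch_grads` is a list of gradient lists, one per micro-batch.
  # Returns the per-parameter average, used for one optimizer step.
  def average_grads(micro_batch_grads) do
    k = length(micro_batch_grads)

    micro_batch_grads
    |> Enum.zip()
    |> Enum.map(fn per_param -> (per_param |> Tuple.to_list() |> Enum.sum()) / k end)
  end
end

# Four micro-batches (accumulate_gradients: 4), two parameters each.
AccumulationSketch.average_grads([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
|> IO.inspect()
```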

When to Use

  • GPU memory limits your batch size
  • You want the stability of large batches but can't fit them in memory
  • Training on mobile devices with limited RAM

Tips

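  • Consider scaling the learning rate with the effective batch size: the linear scaling rule gives 4x LR for 4x accumulation, while square-root scaling (4x accumulation → 2x LR) is a more conservative choice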
  • Batch normalization (when available) will still see the small mini-batch statistics

Batch Size Selection

| Batch size | Pros | Cons |
|---|---|---|
| 8–16 | Better generalization, less memory | Noisy gradients, slower training |
| 32–64 | Good default | Balanced |
| 128–256 | Faster training, stable gradients | May generalize worse, more memory |
| 512+ | Very stable gradients | Often worse generalization, high memory |

Tips

  • Start with 32 and increase if you have memory headroom
  • If you increase batch size, increase learning rate proportionally
  • Use gradient accumulation to simulate large batches on memory-constrained devices
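The proportional rule in the tips above is simple arithmetic. A sketch in plain Elixir (hypothetical LrScalingSketch module), alongside the more conservative square-root variant, which scales the learning rate by the square root of the batch-size ratio.

```elixir
defmodule LrScalingSketch do
  # Linear rule: multiply the learning rate by the batch-size ratio.
  def linear(base_lr, base_batch, new_batch), do: base_lr * new_batch / base_batch

  # Square-root rule: multiply by the square root of the ratio instead.
  def square_root(base_lr, base_batch, new_batch),
    do: base_lr * :math.sqrt(new_batch / base_batch)
end

# Going from batch 32 to an effective batch of 128 (e.g. 32 x 4 accumulation).
IO.inspect(LrScalingSketch.linear(0.001, 32, 128), label: "linear")
IO.inspect(LrScalingSketch.square_root(0.001, 32, 128), label: "square root")
```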

Memory Optimization

On Desktop (CUDA/Metal)

# Use f16 for 2x memory reduction
# (convert parameters to f16 before training)

# Use gradient accumulation to reduce per-batch memory
accumulate_gradients: 4

On Mobile (iOS/Android)

# Keep models small (< 10M params)
# Use CPU for training (GPU autodiff is memory-intensive)
ExBurn.Model.compile(model, device: :cpu)

# Free intermediate tensors explicitly
ExBurn.Tensor.free(intermediate_tensor)

Memory-Saving Tips

  1. Reduce batch size — the single biggest lever
  2. Use gradient accumulation — same effective batch, less memory
  3. Free tensors explicitly — don't wait for GC
  4. Use f16 precision — halves memory for tensors
  5. Avoid storing all intermediate activations — use gradient checkpointing (planned)

Common Problems and Solutions

Loss is NaN

Causes: Exploding gradients, too high learning rate, numerical instability

Solutions:

# 1. Enable gradient clipping
clip_norm: 1.0

# 2. Reduce learning rate
learning_rate: 0.0001

# 3. Try the :numerical_batch gradient method
grad_method: :numerical_batch

Loss Doesn't Decrease

Causes: Too low learning rate, bad initialization, wrong loss function

Solutions:

# 1. Increase learning rate
learning_rate: 0.01

# 2. Check loss function matches task
#    Classification → :cross_entropy
#    Regression → :mse
#    Binary → :binary_cross_entropy

# 3. Verify data preprocessing (normalization, etc.)
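Normalization is one of the most common preprocessing checks. A minimal z-score standardization sketch for a single feature column in plain Elixir (hypothetical StandardizeSketch module); it assumes the column is not constant, since a zero standard deviation would raise an ArithmeticError.

```elixir
defmodule StandardizeSketch do
  # Compute mean and population standard deviation on the training data.
  def fit(values) do
    n = length(values)
    mean = Enum.sum(values) / n
    var = Enum.sum(Enum.map(values, fn x -> (x - mean) * (x - mean) end)) / n
    {mean, :math.sqrt(var)}
  end

  # Apply the training statistics to any split (train, validation, test).
  def transform(values, {mean, std}), do: Enum.map(values, &((&1 - mean) / std))
end

stats = StandardizeSketch.fit([1.0, 3.0])
IO.inspect(StandardizeSketch.transform([1.0, 3.0], stats))
```

Reusing the training statistics for validation data avoids leaking information from the validation set into preprocessing.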

Loss Oscillates

Causes: Learning rate too high, batch size too small

Solutions:

# 1. Reduce learning rate
learning_rate: 0.0005

# 2. Increase batch size or use gradient accumulation
accumulate_gradients: 4

# 3. Use learning rate schedule
lr_schedule: {:cosine, 0.001, 1.0e-6}

Overfitting

Causes: Model too complex, not enough data, no regularization

Solutions:

# 1. Add weight decay
weight_decay: 1.0e-3

# 2. Add dropout in the Axon model
|> Axon.dropout(rate: 0.5)

# 3. Freeze early layers
model = ExBurn.Model.freeze(model, ["hidden1"])

# 4. Use early stopping
callbacks: [ExBurn.Training.EarlyStoppingCallback.wait(5)]
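A sketch of patience-based early stopping in plain Elixir, assuming that EarlyStoppingCallback.wait(patience) and wait(patience, min_delta) (as used elsewhere in this guide) stop after patience epochs without an improvement larger than min_delta. That reading of the arguments is an assumption; EarlyStoppingSketch is hypothetical and works on a precomputed list of validation losses.

```elixir
defmodule EarlyStoppingSketch do
  # Returns the 0-based epoch at which training would stop, or nil if it
  # never triggers. An epoch improves only if it beats the best loss so far
  # by more than min_delta.
  def stop_epoch(val_losses, patience, min_delta \\ 0.0) do
    result =
      val_losses
      |> Enum.with_index()
      |> Enum.reduce_while({nil, 0}, fn {loss, epoch}, {best, waited} ->
        cond do
          best == nil or loss < best - min_delta -> {:cont, {loss, 0}}
          waited + 1 >= patience -> {:halt, {:stopped, epoch}}
          true -> {:cont, {best, waited + 1}}
        end
      end)

    case result do
      {:stopped, epoch} -> epoch
      _ -> nil
    end
  end
end

# Best loss at epoch 1; epochs 2, 3 and 4 do not improve, so patience 3 stops at 4.
IO.inspect(EarlyStoppingSketch.stop_epoch([1.0, 0.8, 0.81, 0.82, 0.83], 3))
```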

Training is Very Slow

Causes: Numerical gradients on a large model, too many epochs

Solutions:

# 1. Use faster gradient method
grad_method: :numerical_batch

# 2. Reduce model size
# 3. Use fewer epochs with early stopping
callbacks: [ExBurn.Training.EarlyStoppingCallback.wait(3)]

# 4. Increase batch size (fewer optimizer steps)
batch_size: 128

Performance Benchmarks

Approximate training times per epoch on synthetic data (will vary by hardware):

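| Model | Params | Batch size | Method | Time per epoch |
|---|---|---|---|---|
| Tiny MLP | 1K | 32 | :numerical | ~2s |
| Small MLP | 10K | 32 | :numerical | ~15s |
| Small MLP | 10K | 32 | :numerical_batch | ~8s |
| Medium MLP | 100K | 32 | :numerical | ~3min |
| Medium MLP | 100K | 32 | :numerical_batch | ~1.5min |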

Key takeaway: With numerical gradients, training time scales linearly with parameter count. Keep models under 50K parameters for interactive training, or switch to inference-only for larger models until autodiff arrives in v0.3.0.

For Quick Experiments

compiled = ExBurn.Model.compile(model,
  loss: :cross_entropy,
  optimizer: :adam,
  learning_rate: 0.001
)

ExBurn.Training.fit(compiled, data,
  epochs: 10,
  batch_size: 32,
  verbose: true
)

For Best Results

compiled = ExBurn.Model.compile(model,
  loss: :cross_entropy,
  optimizer: :adam,
  learning_rate: 0.001,
  weight_decay: 1.0e-4
)

ExBurn.Training.fit(compiled, data,
  epochs: 50,
  batch_size: 64,
  shuffle: true,
  validation_data: val_data,
  lr_schedule: {:cosine, 0.001, 1.0e-6},
  clip_norm: 1.0,
  accuracy: true,
  callbacks: [
    &ExBurn.Training.LoggingCallback.log/1,
    ExBurn.Training.EarlyStoppingCallback.wait(10, 1.0e-5),
    ExBurn.Training.CheckpointCallback.every(10, "/checkpoints")
  ]
)

For Memory-Constrained Devices

compiled = ExBurn.Model.compile(model,
  loss: :cross_entropy,
  optimizer: :adam,
  learning_rate: 0.0005,
  device: :cpu
)

ExBurn.Training.fit(compiled, data,
  epochs: 20,
  batch_size: 16,
  accumulate_gradients: 4,
  clip_norm: 1.0
)