# Training Models with ExBurn

## Overview

ExBurn provides a complete training pipeline: define a model with Axon, compile it with `ExBurn.Model.compile/2`, and train it with `ExBurn.Training.fit/3`. The training loop supports multiple optimizers, learning rate schedules, gradient clipping, weight decay, and callbacks.

## Defining a Model with Axon

```elixir
model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(256, activation: :relu, name: "hidden1")
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(128, activation: :relu, name: "hidden2")
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(10, name: "output")
```

The `nil` in the shape represents the batch dimension (variable size).

## Compiling a Model

```elixir
compiled = ExBurn.Model.compile(model,
  loss: :cross_entropy,       # :cross_entropy | :mse | :binary_cross_entropy
  optimizer: :adam,           # :adam | :sgd | :rmsprop
  learning_rate: 0.001,
  device: :gpu,               # :gpu | :cpu
  weight_decay: 1.0e-4        # L2 regularization (default: 0.0)
)
```

### What `compile/2` Does

1. Builds the Axon expression graph via `Axon.build/2`
2. Initializes parameters with **Glorot/Xavier uniform initialization** for weights, zeros for biases
3. Moves parameters to the GPU via `BurnBridge.to_gpu/1` when `device: :gpu` is set
4. Initializes optimizer state (first- and second-moment estimates for Adam, velocity for SGD, a running average of squared gradients for RMSprop)
5. Returns an `ExBurn.Model` struct ready for training
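Glorot/Xavier uniform initialization draws each weight from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)). The rule can be sketched in pure Elixir as follows, using `:rand.uniform/0`. The module name `GlorotSketch` is hypothetical, and this is not ExBurn's internal code (which presumably goes through Axon's initializers).

```elixir
defmodule GlorotSketch do
  # Half-width of the uniform distribution for a {fan_in, fan_out} kernel.
  def limit(fan_in, fan_out), do: :math.sqrt(6.0 / (fan_in + fan_out))

  # Kernel as a list of fan_in rows, each with fan_out entries drawn
  # uniformly from [-limit, limit). :rand.uniform/0 returns a float in [0.0, 1.0).
  def kernel(fan_in, fan_out) do
    l = limit(fan_in, fan_out)

    for _ <- 1..fan_in do
      for _ <- 1..fan_out, do: (:rand.uniform() * 2.0 - 1.0) * l
    end
  end

  # Biases start at zero.
  def bias(fan_out), do: List.duplicate(0.0, fan_out)
end

# First hidden layer of the model above: 784 inputs -> 256 units
w = GlorotSketch.kernel(784, 256)
b = GlorotSketch.bias(256)
```

Scaling the bound by both fan-in and fan-out keeps activation and gradient variance roughly constant across layers, which is why it is a common default for dense layers.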

### Inspecting a Model

```elixir
# Keras/PyTorch-style summary
IO.puts(ExBurn.Model.summary(compiled))
# ╔══════════════════════════════════════════════════════════╗
# ║                   ExBurn Model Summary                  ║
# ╠══════════════════════════════════════════════════════════╣
# ║  Layer                 Type         Output Shape        ║
# ║  hidden1               Dense        [nil, 256]          ║
# ║  dropout_1             Dropout      [nil, 256]          ║
# ║  hidden2               Dense        [nil, 128]          ║
# ║  dropout_2             Dropout      [nil, 128]          ║
# ║  output                Dense        [nil, 10]           ║
# ╠══════════════════════════════════════════════════════════╣
# ║  Total params:     235,146                               ║
# ║  Trainable params: 235,146                               ║
# ╚══════════════════════════════════════════════════════════╝

# Access components
ExBurn.Model.parameters(compiled)     # parameter map
ExBurn.Model.loss_function(compiled)  # :cross_entropy
ExBurn.Model.optimizer(compiled)      # :adam
ExBurn.Model.weight_decay(compiled)   # 0.0001
```
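The `Total params` figure can be checked by hand. A dense layer with `n_in` inputs and `n_out` units has `n_in * n_out` weights plus `n_out` biases, and dropout layers add none. A small sketch (the module `ParamCount` is hypothetical):

```elixir
defmodule ParamCount do
  # Weights plus biases for one dense layer.
  def dense(n_in, n_out), do: n_in * n_out + n_out

  # Sum over consecutive layer widths, e.g. [784, 256, 128, 10].
  def mlp(widths) do
    widths
    |> Enum.chunk_every(2, 1, :discard)
    |> Enum.map(fn [n_in, n_out] -> dense(n_in, n_out) end)
    |> Enum.sum()
  end
end

ParamCount.mlp([784, 256, 128, 10])
# => 235146
```

That is 200,960 + 32,896 + 1,290 = 235,146, matching the summary.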

## Training

```elixir
trained = ExBurn.Training.fit(compiled, {train_x, train_y},
  epochs: 10,
  batch_size: 32,
  shuffle: true,
  validation_data: {val_x, val_y},
  verbose: true
)
```

### Training Options

| Option | Type | Default | Description |
|---|---|---|---|
| `:epochs` | `pos_integer()` | `10` | Number of training epochs |
| `:batch_size` | `pos_integer()` | `32` | Mini-batch size |
| `:shuffle` | `boolean()` | `true` | Shuffle training data each epoch |
| `:validation_data` | `{tensor, tensor}` | `nil` | Validation dataset |
| `:callbacks` | `[function()]` | `[]` | Callback functions called after each epoch |
| `:verbose` | `boolean()` | `true` | Print training progress |
| `:lr_schedule` | see below | `nil` | Learning rate schedule |
| `:clip_norm` | `float()` | `nil` | Max gradient norm for clipping |
| `:clip_value` | `float()` | `nil` | Max absolute gradient value |
| `:weight_decay` | `float()` | `nil` | L2 regularization coefficient |
| `:accumulate_gradients` | `pos_integer()` | `1` | Accumulate N batches before optimizer step |
| `:accuracy` | `boolean()` | `false` | Compute classification accuracy |
| `:nesterov` | `boolean()` | `false` | Nesterov momentum (SGD only) |

### Learning Rate Schedules

```elixir
# Step decay {:step, base_lr, step_size, gamma}: multiply LR by gamma every step_size epochs
lr_schedule: {:step, 0.001, 10, 0.5}

# Exponential decay {:exponential, base_lr, gamma}: LR = base_lr * gamma^epoch
lr_schedule: {:exponential, 0.001, 0.95}

# Cosine annealing {:cosine, base_lr, min_lr}: smoothly decay from base_lr to min_lr
lr_schedule: {:cosine, 0.001, 1.0e-5}
```
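The following sketch shows how each schedule tuple could map to a per-epoch learning rate. It assumes epochs are counted from 0 and that cosine annealing spans the full `:epochs` count; ExBurn's exact conventions may differ. `LRScheduleSketch.lr/3` is hypothetical.

```elixir
defmodule LRScheduleSketch do
  # Step decay: base_lr * gamma^floor(epoch / step_size)
  def lr({:step, base_lr, step_size, gamma}, epoch, _total_epochs) do
    base_lr * :math.pow(gamma, div(epoch, step_size))
  end

  # Exponential decay: base_lr * gamma^epoch
  def lr({:exponential, base_lr, gamma}, epoch, _total_epochs) do
    base_lr * :math.pow(gamma, epoch)
  end

  # Cosine annealing: base_lr at epoch 0, min_lr at epoch total_epochs
  def lr({:cosine, base_lr, min_lr}, epoch, total_epochs) do
    progress = epoch / total_epochs
    min_lr + 0.5 * (base_lr - min_lr) * (1 + :math.cos(:math.pi() * progress))
  end
end

# Learning rates for a few epochs of a 30-epoch run with step decay
for epoch <- [0, 9, 10, 20] do
  LRScheduleSketch.lr({:step, 0.001, 10, 0.5}, epoch, 30)
end
```

With `{:step, 0.001, 10, 0.5}`, epochs 0 to 9 use 0.001, epochs 10 to 19 use 0.0005, and epochs 20 to 29 use 0.00025.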

### Gradient Clipping

```elixir
# Clip by global norm (prevents exploding gradients)
clip_norm: 1.0

# Clip by absolute value
clip_value: 5.0

# Both can be used together
```
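The two clipping modes differ as follows, sketched on a flat list of gradient values. Clipping by norm rescales the whole vector when its L2 norm exceeds `max_norm`, which preserves the gradient's direction. Clipping by value clamps each component independently, which can change the direction. `ClipSketch` is hypothetical. Treating "global" as one norm across all parameters is an assumption, as is applying norm clipping before value clipping when both are combined.

```elixir
defmodule ClipSketch do
  # Rescale all gradients so their combined L2 norm is at most max_norm.
  def clip_norm(grads, max_norm) do
    norm = grads |> Enum.map(&(&1 * &1)) |> Enum.sum() |> :math.sqrt()

    if norm > max_norm do
      scale = max_norm / norm
      Enum.map(grads, &(&1 * scale))
    else
      grads
    end
  end

  # Clamp each gradient component into [-max_value, max_value].
  def clip_value(grads, max_value) do
    Enum.map(grads, fn g -> g |> max(-max_value) |> min(max_value) end)
  end
end

ClipSketch.clip_norm([6.0, 8.0], 5.0)
# => [3.0, 4.0]   (norm 10.0 scaled down to 5.0)
ClipSketch.clip_value([-7.5, 2.0, 9.0], 5.0)
# => [-5.0, 2.0, 5.0]
```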

### Gradient Accumulation

Useful when GPU memory limits the batch size. Gradients are accumulated across N mini-batches before a single optimizer step is performed:

```elixir
# Effective batch_size = 32 * 4 = 128
ExBurn.Training.fit(compiled, data,
  batch_size: 32,
  accumulate_gradients: 4
)
```
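The bookkeeping behind `accumulate_gradients: 4` can be sketched as summing the per-batch gradients and dividing by N, then taking one optimizer step with the average. Averaging rather than summing is an assumption. It is what makes four batches of 32 behave like one batch of 128 when each batch loss is a mean. Gradients are modeled here as plain maps of floats, and `AccumSketch` is hypothetical.

```elixir
defmodule AccumSketch do
  # Average a list of gradient maps (one per mini-batch) key by key.
  def average(grad_maps) do
    n = length(grad_maps)

    grad_maps
    |> Enum.reduce(fn grads, acc -> Map.merge(acc, grads, fn _k, a, b -> a + b end) end)
    |> Map.new(fn {k, sum} -> {k, sum / n} end)
  end
end

batches = [
  %{"w" => 1.0, "b" => 0.5},
  %{"w" => 3.0, "b" => 0.5},
  %{"w" => 2.0, "b" => -1.0},
  %{"w" => 2.0, "b" => 0.0}
]

AccumSketch.average(batches)
# => %{"b" => 0.0, "w" => 2.0}
```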

## Optimizers

### Adam (default)

Adaptive learning rate with momentum. Good default for most tasks.

```elixir
ExBurn.Model.compile(model, optimizer: :adam, learning_rate: 0.001)
# Internal state: m (1st moment), v (2nd moment), t (timestep)
# beta1=0.9, beta2=0.999, epsilon=1e-8
```
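The state listed above drives the standard Adam update with bias correction. The following sketch applies it to a single scalar parameter using the defaults shown. `AdamSketch` is hypothetical; ExBurn performs the equivalent work on tensors.

```elixir
defmodule AdamSketch do
  @beta1 0.9
  @beta2 0.999
  @epsilon 1.0e-8

  # One Adam step for a scalar parameter.
  # state = %{m: first moment, v: second moment, t: timestep}
  def step(param, grad, %{m: m, v: v, t: t}, lr) do
    t = t + 1
    m = @beta1 * m + (1 - @beta1) * grad
    v = @beta2 * v + (1 - @beta2) * grad * grad

    # Bias correction compensates for m and v starting at zero.
    m_hat = m / (1 - :math.pow(@beta1, t))
    v_hat = v / (1 - :math.pow(@beta2, t))

    new_param = param - lr * m_hat / (:math.sqrt(v_hat) + @epsilon)
    {new_param, %{m: m, v: v, t: t}}
  end
end

state = %{m: 0.0, v: 0.0, t: 0}
{p1, state} = AdamSketch.step(1.0, 0.5, state, 0.001)
```

On the first step, `m_hat / sqrt(v_hat)` is about 1 in magnitude, so the parameter moves by roughly `lr`, here from 1.0 to about 0.999.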

### SGD with Momentum

```elixir
sgd_model = ExBurn.Model.compile(model, optimizer: :sgd, learning_rate: 0.01)
# momentum=0.9
```

With Nesterov momentum (often converges faster). This option only applies to models compiled with `optimizer: :sgd`:

```elixir
ExBurn.Training.fit(sgd_model, data, nesterov: true)
```
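The two SGD variants differ only in the direction of the step. The scalar sketch below uses the PyTorch-style formulation: v ← μv + g, where classical momentum steps along v and Nesterov steps along g + μv. Whether ExBurn uses exactly this formulation is an assumption, and `SGDSketch` is hypothetical.

```elixir
defmodule SGDSketch do
  @momentum 0.9

  # One SGD step for a scalar parameter; velocity is the optimizer state.
  def step(param, grad, velocity, lr, nesterov? \\ false) do
    velocity = @momentum * velocity + grad

    # Nesterov "looks ahead" by adding the momentum term to the raw gradient.
    direction = if nesterov?, do: grad + @momentum * velocity, else: velocity

    {param - lr * direction, velocity}
  end
end

# Two classical-momentum steps with a constant gradient of 1.0
{p, v} = SGDSketch.step(0.0, 1.0, 0.0, 0.1)
{p, _v} = SGDSketch.step(p, 1.0, v, 0.1)
```

After two steps, classical momentum reaches about -0.29. Passing `true` as the last argument gives about -0.461, because Nesterov applies the accumulated velocity one step earlier.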

### RMSprop

Good for recurrent networks and non-stationary objectives:

```elixir
ExBurn.Model.compile(model, optimizer: :rmsprop, learning_rate: 0.001)
# decay=0.9, epsilon=1e-8
```
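RMSprop divides each gradient by the root of a running average of its squares, so parameters with consistently large gradients take smaller steps. The scalar sketch below uses the `decay` and `epsilon` values listed above. Placing epsilon outside the square root is an assumption, and `RMSpropSketch` is hypothetical.

```elixir
defmodule RMSpropSketch do
  @decay 0.9
  @epsilon 1.0e-8

  # One RMSprop step for a scalar parameter.
  # sq_avg is the running average of squared gradients (the optimizer state).
  def step(param, grad, sq_avg, lr) do
    sq_avg = @decay * sq_avg + (1 - @decay) * grad * grad
    new_param = param - lr * grad / (:math.sqrt(sq_avg) + @epsilon)
    {new_param, sq_avg}
  end
end

{p, s} = RMSpropSketch.step(1.0, 2.0, 0.0, 0.001)
```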

## Callbacks

Callbacks are functions that receive a metrics map after each epoch and return it (possibly modified).

### Built-in Callbacks

```elixir
# Logging
callbacks: [&ExBurn.Training.LoggingCallback.log/1]

# Early stopping (patience=5 epochs, min_delta=1e-4)
callbacks: [ExBurn.Training.EarlyStoppingCallback.wait(5, 1.0e-4)]

# Checkpoint every 5 epochs
callbacks: [ExBurn.Training.CheckpointCallback.every(5, "/checkpoints")]
```
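A callback only sees the current metrics map, so early stopping has to carry its patience counter between epochs somehow. The pure sketch below shows the usual rule with the state threaded explicitly: stop after `patience` consecutive epochs in which the loss fails to improve on the best value seen by more than `min_delta`. `EarlyStopSketch` is hypothetical. How `EarlyStoppingCallback.wait/2` stores its state, and which metric it monitors, is not specified here.

```elixir
defmodule EarlyStopSketch do
  defstruct best: :infinity, wait: 0, patience: 5, min_delta: 1.0e-4

  def new(patience, min_delta), do: %__MODULE__{patience: patience, min_delta: min_delta}

  # Returns {updated_state, stop?} for this epoch's loss.
  def update(%__MODULE__{} = s, loss) do
    if s.best == :infinity or loss < s.best - s.min_delta do
      # Meaningful improvement: record it and reset the counter.
      {%{s | best: loss, wait: 0}, false}
    else
      wait = s.wait + 1
      {%{s | wait: wait}, wait >= s.patience}
    end
  end
end

# Epoch 3 improves by less than min_delta, so it counts as "no improvement".
losses = [0.50, 0.40, 0.39995, 0.41, 0.40, 0.40]

{_state, stopped_at} =
  losses
  |> Enum.with_index(1)
  |> Enum.reduce_while({EarlyStopSketch.new(3, 1.0e-4), nil}, fn {loss, epoch}, {s, _} ->
    case EarlyStopSketch.update(s, loss) do
      {s, true} -> {:halt, {s, epoch}}
      {s, false} -> {:cont, {s, nil}}
    end
  end)

# stopped_at => 5
```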

### Custom Callbacks

The metrics map has this structure:

```elixir
%{
  epoch: 5,
  loss: 0.0234,
  val_loss: 0.0312,        # if validation_data provided
  accuracy: 0.98,          # if accuracy: true
  val_accuracy: 0.95,      # if validation_data + accuracy
  model: %ExBurn.Model{}   # current model state
}
```

Return `Map.put(metrics, :stop_training, true)` to halt training early:

```elixir
custom_callback = fn
  %{loss: loss} = metrics when loss < 0.01 ->
    IO.puts("Converged!")
    Map.put(metrics, :stop_training, true)

  metrics ->
    metrics
end

# Pass it to fit/3 with: callbacks: [custom_callback]
```

## Evaluating a Model

```elixir
# Returns the average loss
loss = ExBurn.Training.evaluate(trained, {test_x, test_y})

# Returns a {loss, accuracy} tuple
{loss, accuracy} = ExBurn.Training.evaluate(trained, {test_x, test_y}, true)
```
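For classification, accuracy is typically the fraction of rows whose highest-scoring class (the argmax of the output) matches the target. The list-based sketch below assumes integer class-index targets. That ExBurn computes accuracy this way is an assumption, and `AccuracySketch` is hypothetical.

```elixir
defmodule AccuracySketch do
  # Index of the largest value in a row of scores.
  def argmax(row) do
    row
    |> Enum.with_index()
    |> Enum.max_by(fn {score, _i} -> score end)
    |> elem(1)
  end

  # Fraction of rows whose argmax equals the integer target class.
  def accuracy(predictions, targets) do
    correct =
      predictions
      |> Enum.zip(targets)
      |> Enum.count(fn {row, target} -> argmax(row) == target end)

    correct / length(targets)
  end
end

preds = [[0.1, 2.3, -0.4], [1.5, 0.2, 0.1], [0.0, 0.3, 0.9], [0.2, 0.1, 0.05]]
AccuracySketch.accuracy(preds, [1, 0, 2, 1])
# => 0.75
```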

## Inference

```elixir
# Using the model's forward pass (GPU via defn compiler)
{:ok, output} = ExBurn.Model.forward(compiled, input_tensor)

# Using Axon predict (CPU via BinaryBackend)
{:ok, output} = ExBurn.Model.predict(compiled, input_tensor)
```

## Saving and Loading

```elixir
# Save (compressed Erlang term format)
ExBurn.Model.save(trained, "model.bin")

# Load
{:ok, model} = ExBurn.Model.load(trained, "model.bin")

# Serialize to binary (for network transfer)
binary = ExBurn.Model.serialize_params(trained)
{:ok, params} = ExBurn.Model.deserialize_params(binary)
```
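"Compressed Erlang term format" refers to `:erlang.term_to_binary/2` with the `:compressed` option. A round trip of a plain parameter map illustrates the idea. ExBurn's actual file layout, and how it converts tensors before encoding, is not specified here, so the nested map below is only a stand-in.

```elixir
params = %{
  "hidden1" => %{"kernel" => [[0.1, -0.2], [0.3, 0.4]], "bias" => [0.0, 0.0]}
}

# Encode with compression, suitable for writing to disk or sending over the network
binary = :erlang.term_to_binary(params, [:compressed])

# Decode; :safe refuses to create new atoms or external function references
restored = :erlang.binary_to_term(binary, [:safe])

# The same bytes survive a trip through a file
path = Path.join(System.tmp_dir!(), "params_sketch.bin")
File.write!(path, binary)
from_disk = path |> File.read!() |> :erlang.binary_to_term([:safe])
```

The `:safe` option is worth using whenever the binary comes from an untrusted source, such as a network peer.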

## Freezing Layers (Fine-tuning)

Freeze layers to prevent them from updating during training:

```elixir
# Freeze specific layers
model = ExBurn.Model.freeze(compiled, ["hidden1"])

# Check if a layer is frozen
ExBurn.Model.frozen?(model, "hidden1")  # true

# Unfreeze
model = ExBurn.Model.unfreeze(model, ["hidden1"])

# Get all frozen layer names
ExBurn.Model.frozen_layers(model)  # #MapSet<["hidden1"]>
```
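Conceptually, freezing means the optimizer skips parameters whose layer name is in the frozen set. The sketch below shows that masking idea with a plain gradient-descent step over a map from layer name to scalar parameter. It is not ExBurn's update code, and `FreezeSketch` is hypothetical.

```elixir
defmodule FreezeSketch do
  # Apply p - lr * g to every layer except those in the frozen MapSet.
  def apply_update(params, grads, frozen, lr) do
    Map.new(params, fn {layer, p} ->
      if MapSet.member?(frozen, layer) do
        {layer, p}
      else
        {layer, p - lr * Map.fetch!(grads, layer)}
      end
    end)
  end
end

params = %{"hidden1" => 1.0, "hidden2" => 1.0, "output" => 1.0}
grads = %{"hidden1" => 0.5, "hidden2" => 0.5, "output" => 0.5}

FreezeSketch.apply_update(params, grads, MapSet.new(["hidden1"]), 0.5)
# => %{"hidden1" => 1.0, "hidden2" => 0.75, "output" => 0.75}
```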

## Device Management

```elixir
# Move model to GPU
gpu_model = ExBurn.Model.to_device(compiled, :gpu)

# Move model to CPU
cpu_model = ExBurn.Model.to_device(compiled, :cpu)

# No-op if already on target device
same_model = ExBurn.Model.to_device(cpu_model, :cpu)
```

## Custom Training Loops

For full control, use `train_step/3` directly:

```elixir
{loss, updated_model} = ExBurn.Training.train_step(model, {batch_x, batch_y},
  clip_norm: 1.0,
  grad_method: :numerical_batch
)
```

Compute gradients separately:

```elixir
grads = ExBurn.Training.compute_gradients(model, {batch_x, batch_y},
  grad_method: :numerical  # or :numerical_batch
)
```
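The `grad_method` names suggest that gradients are estimated by finite differences. Assuming that is what `:numerical` means, the central-difference estimate (f(x + h) - f(x - h)) / 2h for each parameter can be sketched as follows. The step size `h`, the list-of-floats representation, and the module `FiniteDiffSketch` are all assumptions.

```elixir
defmodule FiniteDiffSketch do
  # Central-difference gradient of loss_fn at params (a list of floats).
  def gradient(loss_fn, params, h \\ 1.0e-5) do
    params
    |> Enum.with_index()
    |> Enum.map(fn {_p, i} ->
      plus = List.update_at(params, i, &(&1 + h))
      minus = List.update_at(params, i, &(&1 - h))
      (loss_fn.(plus) - loss_fn.(minus)) / (2 * h)
    end)
  end
end

# f(x, y) = x^2 + 3y, so the exact gradient at [2.0, -1.0] is [4.0, 3.0]
loss_fn = fn [x, y] -> x * x + 3 * y end
FiniteDiffSketch.gradient(loss_fn, [2.0, -1.0])
```

Each parameter costs two extra loss evaluations, so this approach becomes expensive as the parameter count grows.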

## Loss Functions

| Loss | Use Case | Target Format |
|---|---|---|
| `:cross_entropy` | Multi-class classification | One-hot or integer class indices |
| `:mse` | Regression | Continuous values |
| `:binary_cross_entropy` | Binary classification | 0.0 or 1.0 |
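The three losses can be sketched per example on plain lists. The cross-entropy sketch takes raw logits plus an integer class index and applies a numerically stable log-softmax. Whether ExBurn's `:cross_entropy` expects logits or probabilities is not stated above, so that choice is an assumption, and `LossSketch` is hypothetical.

```elixir
defmodule LossSketch do
  # Cross-entropy from logits and an integer class index (stable log-softmax).
  def cross_entropy(logits, class) do
    top = Enum.max(logits)
    log_sum_exp = top + :math.log(Enum.sum(Enum.map(logits, &:math.exp(&1 - top))))
    log_sum_exp - Enum.at(logits, class)
  end

  # Mean squared error over a vector of predictions.
  def mse(preds, targets) do
    preds
    |> Enum.zip(targets)
    |> Enum.map(fn {p, t} -> (p - t) * (p - t) end)
    |> Enum.sum()
    |> Kernel./(length(preds))
  end

  # Binary cross-entropy for a probability p in (0, 1) and a 0.0/1.0 target.
  def binary_cross_entropy(p, target) do
    -(target * :math.log(p) + (1 - target) * :math.log(1 - p))
  end
end

LossSketch.cross_entropy([2.0, 0.5, -1.0], 0)
LossSketch.mse([1.0, 2.0], [0.0, 4.0])
LossSketch.binary_cross_entropy(0.9, 1.0)
```

Subtracting the maximum logit before exponentiating leaves the result unchanged mathematically but prevents overflow for large logits.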
