# `ExBurn.Training`
Source: [lib/ex_burn/training.ex](https://github.com/ohhi-vn/ex_burn/blob/main/lib/ex_burn/training.ex#L1)

Training loop implementation for ExBurn models.

Provides a flexible training loop with support for:

- Mini-batch training with gradient computation
- Multiple optimizers (Adam, SGD with momentum / Nesterov, RMSprop)
- Learning rate scheduling (step, exponential, cosine)
- Gradient clipping (by norm and by value)
- Weight decay (L2 regularization)
- Gradient accumulation for larger effective batch sizes
- Batch shuffling each epoch
- Validation with partial batch handling
- Training metrics tracking (loss, accuracy)
- Progress reporting with ETA and throughput
- Callbacks (logging, early stopping, checkpointing)
- Public `train_step/3` and `compute_gradients/3` for custom loops
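
The SGD optimizer supports classic and Nesterov momentum. Below is a minimal sketch of one SGD step that shows the difference between the two. It assumes the PyTorch-style momentum formulation, because ExBurn's exact update rule is not documented here. `MomentumSketch.step/6` is hypothetical, and plain float lists stand in for `Nx` tensors.

```elixir
defmodule MomentumSketch do
  # Hypothetical helper, not part of ExBurn. One SGD step on flat float lists,
  # using the PyTorch-style momentum formulation (an assumption):
  #   v      <- mu * v + g
  #   update <- g + mu * v   (Nesterov)   or   v   (classic momentum)
  #   p      <- p - lr * update
  def step(params, grads, velocity, lr, mu, nesterov) do
    new_velocity = Enum.zip_with(velocity, grads, fn v, g -> mu * v + g end)

    updates =
      if nesterov do
        # Nesterov "looks ahead" by adding the momentum term to the raw gradient.
        Enum.zip_with(grads, new_velocity, fn g, v -> g + mu * v end)
      else
        new_velocity
      end

    new_params = Enum.zip_with(params, updates, fn p, u -> p - lr * u end)
    {new_params, new_velocity}
  end
end
```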

## Usage

    model = ExBurn.Model.compile(axon_model, loss: :cross_entropy, optimizer: :adam)

    opts = [
      epochs: 10,
      batch_size: 32,
      shuffle: true,
      validation_data: val_data,
      lr_schedule: {:cosine, 0.001, 1.0e-5},
      clip_norm: 1.0,
      weight_decay: 1.0e-4,
      accumulate_gradients: 4,
      accuracy: true,
      nesterov: true,  # only has an effect with the SGD optimizer
      callbacks: [&ExBurn.Training.LoggingCallback.log/1]
    ]

    trained_model = ExBurn.Training.fit(model, train_data, opts)
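
With `batch_size: 32` and `accumulate_gradients: 4`, each optimizer step uses gradients from 4 × 32 = 128 examples. Below is a minimal sketch of that bookkeeping. It assumes the accumulated gradients are averaged rather than summed before the step. `AccumulationSketch` is hypothetical, and maps of floats stand in for gradient tensors.

```elixir
defmodule AccumulationSketch do
  # Hypothetical helper, not ExBurn's implementation. Gradients are modelled
  # as maps of parameter key => float (stand-ins for tensors).

  # Effective batch size when accumulating `k` mini-batches per optimizer step.
  def effective_batch_size(batch_size, k), do: batch_size * k

  # Averages the gradient maps from `k` mini-batches (averaging is an assumption)
  # so the optimizer applies one update per `k` mini-batches.
  def average(grad_maps) do
    k = length(grad_maps)

    grad_maps
    |> Enum.reduce(fn grads, acc -> Map.merge(acc, grads, fn _key, a, b -> a + b end) end)
    |> Map.new(fn {key, sum} -> {key, sum / k} end)
  end
end
```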

# `callback`

```elixir
@type callback() :: (map() -> map())
```
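
The type only requires a callback to take a map and return a map. Below is a minimal logging callback written against that contract. The `:epoch` and `:loss` keys are assumptions about what the training loop passes in, and `log_epoch` is a hypothetical name.

```elixir
# Hypothetical callback. The :epoch and :loss keys are assumed; the only
# contract taken from the type is "map in, map out".
log_epoch = fn state ->
  IO.puts("epoch #{state[:epoch]}: loss=#{state[:loss]}")
  # Return the (possibly modified) state so training can continue.
  state
end
```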

# `dataset`

```elixir
@type dataset() :: {Nx.Tensor.t(), Nx.Tensor.t()}
```

# `lr_schedule`

```elixir
@type lr_schedule() ::
  {:step, float(), pos_integer(), float()}
  | {:exponential, float(), float()}
  | {:cosine, float(), float()}
  | nil
```
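
Below is a sketch of how each schedule tuple could map to a learning rate for a 0-based epoch. The field meanings are assumptions inferred from the tuple shapes:

- `{:step, initial_lr, step_size, gamma}`
- `{:exponential, initial_lr, gamma}`
- `{:cosine, max_lr, min_lr}`

`nil` means no schedule. `LrScheduleSketch` is hypothetical.

```elixir
defmodule LrScheduleSketch do
  # Hypothetical interpretation of lr_schedule() tuples; field meanings are assumed.

  # Step decay: multiply by gamma once every step_size epochs.
  def lr({:step, initial, step_size, gamma}, epoch, _total),
    do: initial * :math.pow(gamma, div(epoch, step_size))

  # Exponential decay: multiply by gamma every epoch.
  def lr({:exponential, initial, gamma}, epoch, _total),
    do: initial * :math.pow(gamma, epoch)

  # Cosine annealing from max_lr (epoch 0) down to min_lr (epoch == total).
  def lr({:cosine, max_lr, min_lr}, epoch, total),
    do: min_lr + 0.5 * (max_lr - min_lr) * (1 + :math.cos(:math.pi() * epoch / total))
end
```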

# `model`

```elixir
@type model() :: ExBurn.Model.t()
```

# `training_opts`

```elixir
@type training_opts() :: [
  epochs: pos_integer(),
  batch_size: pos_integer(),
  shuffle: boolean(),
  validation_data: dataset() | nil,
  callbacks: [callback()],
  verbose: boolean(),
  lr_schedule: lr_schedule(),
  clip_norm: float() | nil,
  clip_value: float() | nil,
  weight_decay: float() | nil,
  accumulate_gradients: pos_integer(),
  accuracy: boolean(),
  nesterov: boolean(),
  warmup: pos_integer()
]
```

# `compute_gradients`

```elixir
@spec compute_gradients(model(), dataset(), keyword()) :: map()
```

Computes gradients for a given mini-batch.

Supports multiple gradient computation methods via the `:grad_method` option:

  * `:numerical` — Central finite differences (default, slow but general)
  * `:numerical_batch` — Numerical gradients computed on the full batch at once
    (more efficient, fewer forward passes)
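
Central finite differences estimate each partial derivative from two loss evaluations, (L(p + ε·eᵢ) − L(p − ε·eᵢ)) / 2ε. This is why the `:numerical` method is slow: it costs two forward passes per parameter. Below is a minimal sketch over a flat list of parameters, with ε defaulting to the documented 1.0e-5. `FiniteDiffSketch` is hypothetical and does not reproduce ExBurn's tensor-based code.

```elixir
defmodule FiniteDiffSketch do
  # Hypothetical sketch: central finite differences of a scalar loss function
  # over a flat list of float parameters (two loss calls per parameter).
  def gradient(loss_fun, params, eps \\ 1.0e-5) do
    params
    |> Enum.with_index()
    |> Enum.map(fn {_p, i} ->
      plus = List.update_at(params, i, &(&1 + eps))
      minus = List.update_at(params, i, &(&1 - eps))
      (loss_fun.(plus) - loss_fun.(minus)) / (2 * eps)
    end)
  end
end
```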

## Parameters

  * `model` — The current model state
  * `batch` — A `{inputs, targets}` tuple
  * `opts` — Options list

## Options

  * `:grad_method` — Gradient computation method (default: `:numerical`)
  * `:epsilon` — Finite difference step size (default: 1.0e-5)

## Returns

  A map from parameter key to gradient tensor (`%{param_key => gradient_tensor}`).

# `data_loader`

```elixir
@spec data_loader(
  dataset(),
  keyword()
) :: Enumerable.t()
```

Creates a data loader that yields mini-batches from a dataset.
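
Below is a minimal sketch of the batching idea: optionally shuffle the examples, then yield fixed-size chunks lazily, with a smaller final chunk when the dataset size is not a multiple of the batch size. Plain lists stand in for the `{inputs, targets}` tensors. `DataLoaderSketch.batches/3` is hypothetical, not the signature of `data_loader/2`.

```elixir
defmodule DataLoaderSketch do
  # Hypothetical sketch: the dataset is a pair of equally long lists
  # standing in for {inputs, targets} tensors.
  def batches({inputs, targets}, batch_size, shuffle \\ false) do
    # Pair inputs with targets so shuffling keeps them aligned.
    pairs = Enum.zip(inputs, targets)
    pairs = if shuffle, do: Enum.shuffle(pairs), else: pairs

    pairs
    # The last chunk is partial when the size is not a multiple of batch_size.
    |> Stream.chunk_every(batch_size)
    |> Stream.map(&Enum.unzip/1)
  end
end
```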

# `evaluate`

```elixir
@spec evaluate(model(), dataset(), boolean()) :: float() | {float(), float() | nil}
```

Evaluates a model on a dataset.

Returns the average loss over the entire dataset.
When `track_accuracy` (the third argument) is `true`, it returns `{loss, accuracy}`
instead. `accuracy` is a float, or `nil` if the model's loss function is not `:cross_entropy`.
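
Below is a sketch of the two quantities `evaluate/3` reports. It makes two assumptions. First, the dataset-level loss weights each batch's mean loss by its size, so a partial final batch does not skew the average. Second, accuracy is the fraction of rows whose arg-max matches the integer label. `EvalSketch` is hypothetical.

```elixir
defmodule EvalSketch do
  # Hypothetical sketch. Per-batch results are {mean_loss, batch_size}.
  # Weighting by batch size keeps a partial final batch from skewing the mean.
  def average_loss(batch_results) do
    {total, count} =
      Enum.reduce(batch_results, {0.0, 0}, fn {loss, n}, {sum, seen} ->
        {sum + loss * n, seen + n}
      end)

    total / count
  end

  # Fraction of rows whose highest-scoring class equals the integer label.
  def accuracy(score_rows, labels) do
    correct =
      Enum.zip(score_rows, labels)
      |> Enum.count(fn {row, label} -> argmax(row) == label end)

    correct / length(labels)
  end

  defp argmax(row) do
    row |> Enum.with_index() |> Enum.max_by(fn {p, _i} -> p end) |> elem(1)
  end
end
```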

# `fit`

```elixir
@spec fit(model(), dataset(), keyword()) :: model()
```

Trains a model on the given dataset.

## Options

  * `:epochs` — Number of training epochs (default: 10)
  * `:batch_size` — Mini-batch size (default: 32)
  * `:shuffle` — Shuffle training data each epoch (default: true)
  * `:validation_data` — Validation dataset as `{inputs, targets}` tuple
  * `:callbacks` — List of callback functions called after each epoch
  * `:verbose` — Print training progress (default: true)
  * `:lr_schedule` — Learning rate schedule (default: nil)
  * `:clip_norm` — Max gradient norm for clipping (default: nil)
  * `:clip_value` — Max absolute gradient value for clipping (default: nil)
  * `:weight_decay` — L2 regularization coefficient (default: nil)
  * `:accumulate_gradients` — Number of mini-batches to accumulate before
    an optimizer step, effectively multiplying batch size (default: 1)
  * `:accuracy` — Compute and report classification accuracy (default: false)
  * `:nesterov` — Use Nesterov momentum for SGD optimizer (default: false)
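
Below is a sketch of the three gradient transforms that `:weight_decay`, `:clip_norm`, and `:clip_value` describe. Each transform is a no-op when its option is `nil`. L2 weight decay adds `wd * param` to each gradient, which is the gradient of `(wd / 2) * ‖param‖²`. Norm clipping rescales the whole gradient vector; value clipping clamps each component. The order in which ExBurn applies these transforms is an assumption. `GradTransformSketch` is hypothetical, and flat float lists stand in for tensors.

```elixir
defmodule GradTransformSketch do
  # Hypothetical sketch on flat float lists; application order is assumed.

  # :weight_decay (L2): g <- g + wd * p.
  def weight_decay(grads, _params, nil), do: grads

  def weight_decay(grads, params, wd),
    do: Enum.zip_with(grads, params, fn g, p -> g + wd * p end)

  # :clip_norm: rescale so the global L2 norm is at most max_norm.
  def clip_norm(grads, nil), do: grads

  def clip_norm(grads, max_norm) do
    norm = :math.sqrt(Enum.reduce(grads, 0.0, fn g, acc -> acc + g * g end))
    if norm > max_norm, do: Enum.map(grads, &(&1 * max_norm / norm)), else: grads
  end

  # :clip_value: clamp each component to [-max_value, max_value].
  def clip_value(grads, nil), do: grads

  def clip_value(grads, max_value),
    do: Enum.map(grads, &(&1 |> max(-max_value) |> min(max_value)))
end
```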

## Returns

  The trained `ExBurn.Model` struct with updated parameters.

# `profile_step`

```elixir
@spec profile_step(model(), dataset(), keyword()) :: map()
```

Profiles a single training step, returning detailed timing for each phase.

Useful for identifying bottlenecks in the training pipeline.

## Parameters

  * `model` — The current model state
  * `batch` — A `{inputs, targets}` tuple
  * `opts` — Options (same as `train_step/3`)

## Returns

  A map with `:forward_ms`, `:backward_ms`, `:optimizer_ms`, `:total_ms`.

## Example

    profile = ExBurn.Training.profile_step(model, batch)
    IO.puts("Forward: #{profile.forward_ms}ms, Backward: #{profile.backward_ms}ms")
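
Below is a minimal sketch of per-phase timing that produces the same keys `profile_step/3` documents. It uses Erlang's `:timer.tc/1`, which returns `{microseconds, result}`. `ProfileSketch` is hypothetical. For simplicity, the three phases are independent zero-arity functions; in a real step, backward would consume forward's output.

```elixir
defmodule ProfileSketch do
  # Hypothetical sketch: time each phase with :timer.tc/1 (microseconds)
  # and report milliseconds under the documented keys.
  def profile(forward, backward, optimizer) do
    {fwd_us, _} = :timer.tc(forward)
    {bwd_us, _} = :timer.tc(backward)
    {opt_us, _} = :timer.tc(optimizer)

    %{
      forward_ms: fwd_us / 1000,
      backward_ms: bwd_us / 1000,
      optimizer_ms: opt_us / 1000,
      total_ms: (fwd_us + bwd_us + opt_us) / 1000
    }
  end
end
```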

# `train_step`

```elixir
@spec train_step(model(), dataset(), keyword()) :: {float(), model()}
```

Performs a single training step: forward + backward + optimizer update.

This is a public API for custom training loops. It processes a single
mini-batch and returns the updated model along with the loss value.

## Parameters

  * `model` — The current model state
  * `batch` — A `{inputs, targets}` tuple for one mini-batch
  * `opts` — Options; accepts the same `:clip_norm`, `:clip_value`, and `:weight_decay` options as `fit/3`

## Returns

  `{loss, updated_model}` where loss is a float.
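
Below is a sketch of a custom loop built on the `{loss, updated_model}` return shape. It threads the model through the batches and collects per-step losses. `CustomLoopSketch.run/3` is hypothetical. So that the sketch runs without ExBurn, it takes any two-argument step function.

```elixir
defmodule CustomLoopSketch do
  # Hypothetical sketch: drive a step function returning {loss, updated_model}.
  # Returns the final model and the per-step losses in order.
  def run(model, batches, step_fun) do
    {model, losses} =
      Enum.reduce(batches, {model, []}, fn batch, {m, acc} ->
        {loss, m} = step_fun.(m, batch)
        {m, [loss | acc]}
      end)

    {model, Enum.reverse(losses)}
  end
end
```

With ExBurn, the step function could be `fn m, batch -> ExBurn.Training.train_step(m, batch, clip_norm: 1.0) end`.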

---

*Consult [api-reference.md](api-reference.md) for the complete listing.*
