# `ExBurn.Training`
[View source](https://github.com/ohhi-vn/ex_burn/blob/main/lib/ex_burn/training.ex#L1)

Training loop implementation for ExBurn models.

Provides a flexible training loop with support for:

- Mini-batch training with real gradient computation
- Multiple optimizers (Adam, SGD with momentum, RMSprop)
- Learning rate scheduling (step, exponential, cosine)
- Gradient clipping (by norm and by value)
- Validation on a held-out dataset (`:validation_data`)
- Callbacks (logging, early stopping, checkpointing)
- GPU-accelerated gradient computation via Burn

## Usage

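    # axon_model is an Axon model definition; train_data and val_data are
    # {inputs, targets} tuples of Nx tensors, as described by the dataset() type.
    model = ExBurn.Model.compile(axon_model, loss: :cross_entropy, optimizer: :adam)

    opts = [
      epochs: 10,
      batch_size: 32,
      validation_data: val_data,
      lr_schedule: {:cosine, 0.001, 1.0e-5},
      # clip gradients whose norm exceeds 1.0
      clip_norm: 1.0,
      callbacks: [&ExBurn.Training.LoggingCallback.log/2]
    ]

    trained_model = ExBurn.Training.fit(model, train_data, opts)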

# `callback`

```elixir
@type callback() :: (map() -> map())
```
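A minimal sketch of a callback matching the `callback()` type above: a one-arity function that takes a map and returns an updated map. The keys used here (`:val_loss`, `:best_val_loss`, `:bad_epochs`, `:stop`) and the patience of 3 are assumptions made for illustration. The fields ExBurn actually puts in the map are not documented on this page, and whether `fit/3` honours a `:stop` flag is not stated.

```elixir
defmodule EarlyStopSketch do
  # Hypothetical early-stopping callback shaped like callback() :: (map() -> map()).
  # ASSUMPTION: the incoming map has a :val_loss key; :best_val_loss,
  # :bad_epochs and :stop are keys invented for this sketch.
  @patience 3

  def early_stop(%{val_loss: val_loss} = state) do
    best = Map.get(state, :best_val_loss)

    if is_nil(best) or val_loss < best do
      # Improvement (or first epoch): record the new best, reset the counter.
      Map.merge(state, %{best_val_loss: val_loss, bad_epochs: 0, stop: false})
    else
      # No improvement: count it, and flag a stop once patience runs out.
      bad_epochs = Map.get(state, :bad_epochs, 0) + 1
      Map.merge(state, %{bad_epochs: bad_epochs, stop: bad_epochs >= @patience})
    end
  end
end
```

Because each call returns the map that the next call receives, a callback of this shape can carry state (here, the best loss so far) from one epoch to the next.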

# `dataset`

```elixir
@type dataset() :: {Nx.Tensor.t(), Nx.Tensor.t()}
```

# `lr_schedule`

```elixir
@type lr_schedule() ::
  {:step, float(), pos_integer(), float()}
  | {:exponential, float(), float()}
  | {:cosine, float(), float()}
  | nil
```
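The meaning of each tuple field is not documented here. As a hedged sketch, the function below assumes the common conventions: `{:step, initial_lr, step_size, gamma}` multiplies the rate by `gamma` every `step_size` epochs, `{:exponential, initial_lr, decay}` multiplies it by `decay` every epoch, and `{:cosine, initial_lr, min_lr}` anneals from `initial_lr` down to `min_lr` along a half cosine over the run. The module name, the 0-based `epoch` and the `total_epochs` argument are hypothetical, and the `nil` case (no schedule) is not modelled.

```elixir
defmodule LRScheduleSketch do
  # ASSUMED field meanings (not documented by ExBurn):
  #   {:step, initial_lr, step_size, gamma}
  #   {:exponential, initial_lr, decay}
  #   {:cosine, initial_lr, min_lr}
  # `epoch` is 0-based; `total_epochs` is only used by :cosine.

  def lr({:step, initial, step_size, gamma}, epoch, _total_epochs) do
    # Multiply by gamma once for every completed block of step_size epochs.
    initial * :math.pow(gamma, div(epoch, step_size))
  end

  def lr({:exponential, initial, decay}, epoch, _total_epochs) do
    # Multiply by decay once per epoch.
    initial * :math.pow(decay, epoch)
  end

  def lr({:cosine, initial, min_lr}, epoch, total_epochs) do
    # Half-cosine anneal: initial at epoch 0, min_lr at epoch == total_epochs.
    progress = epoch / total_epochs
    min_lr + 0.5 * (initial - min_lr) * (1 + :math.cos(:math.pi() * progress))
  end
end
```

Under this reading, the `{:cosine, 0.001, 1.0e-5}` schedule from the usage example would start at `0.001` and finish at `1.0e-5`.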

# `training_opts`

```elixir
@type training_opts() :: [
  epochs: pos_integer(),
  batch_size: pos_integer(),
  validation_data: dataset() | nil,
  callbacks: [callback()],
  verbose: boolean(),
  lr_schedule: lr_schedule(),
  clip_norm: float() | nil,
  clip_value: float() | nil
]
```

# `data_loader`

```elixir
@spec data_loader(
  dataset(),
  keyword()
) :: Enumerable.t()
```

Creates a data loader that yields mini-batches from a dataset.
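To make "yields mini-batches" concrete, here is a sketch of the batching step that uses plain lists in place of Nx tensors (one list element per sample). It keeps inputs and targets aligned by zipping them before chunking, returns a lazy `Stream` (which satisfies `Enumerable.t()`), and keeps a short final batch instead of dropping it. The module name and the `:shuffle` option are hypothetical; the options `data_loader/2` actually accepts are not listed on this page.

```elixir
defmodule DataLoaderSketch do
  # Plain lists stand in for Nx tensors; each element is one sample.
  # ASSUMPTIONS: the last batch may be smaller than :batch_size, and
  # :shuffle is a hypothetical option invented for this sketch.
  def batches({inputs, targets}, opts \\ []) do
    batch_size = Keyword.get(opts, :batch_size, 32)

    # Pair each input with its target so shuffling cannot misalign them.
    pairs = Enum.zip(inputs, targets)
    pairs = if Keyword.get(opts, :shuffle, false), do: Enum.shuffle(pairs), else: pairs

    pairs
    |> Stream.chunk_every(batch_size)
    # Split each chunk of pairs back into {batch_inputs, batch_targets}.
    |> Stream.map(&Enum.unzip/1)
  end
end
```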

# `evaluate`

```elixir
@spec evaluate(ExBurn.Model.t(), dataset()) :: float()
```

Evaluates a model on a dataset.

Returns the average loss over the entire dataset.
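One detail behind "average loss over the entire dataset": if loss is accumulated per mini-batch, a per-sample average has to weight each batch mean by its batch size, otherwise a short final batch counts as much as a full one. A hedged sketch of that bookkeeping, assuming per-batch results arrive as hypothetical `{mean_loss, batch_size}` tuples; whether `evaluate/2` works batch by batch is not stated here.

```elixir
defmodule EvaluateSketch do
  # Input: a list of hypothetical {mean_loss, batch_size} tuples, one per batch.
  def average_loss([]), do: raise(ArgumentError, "cannot average an empty dataset")

  def average_loss(batch_results) do
    # Recover each batch's summed loss, then divide by the total sample count.
    {total_loss, total_samples} =
      Enum.reduce(batch_results, {0.0, 0}, fn {mean_loss, n}, {sum, seen} ->
        {sum + mean_loss * n, seen + n}
      end)

    total_loss / total_samples
  end
end
```

For batches `{1.0, 2}` and `{4.0, 1}` this gives `6.0 / 3`, whereas a plain mean of the two batch losses would give 2.5.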

# `fit`

```elixir
@spec fit(ExBurn.Model.t(), dataset(), keyword()) :: ExBurn.Model.t()
```

Trains a model on the given dataset.

## Options

  * `:epochs` — Number of training epochs (default: 10)
  * `:batch_size` — Mini-batch size (default: 32)
  * `:validation_data` — Validation dataset as `{inputs, targets}` tuple
  * `:callbacks` — List of callback functions called after each epoch
  * `:verbose` — Print training progress (default: true)
  * `:lr_schedule` — Learning rate schedule (default: nil)
  * `:clip_norm` — Max gradient norm for clipping (default: nil)
  * `:clip_value` — Max absolute gradient value for clipping (default: nil)
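`:clip_norm` and `:clip_value` correspond to the two usual clipping strategies. A sketch of both, on a flat list of floats standing in for gradient tensors: norm clipping rescales every value by the same factor when the L2 norm exceeds the limit (so the gradient keeps its direction), while value clipping clamps each value independently into `[-max, max]`. Computing a single norm over all values, rather than one per tensor, is an assumption; the exact scope ExBurn uses is not documented here.

```elixir
defmodule ClipSketch do
  # Gradients are modelled as a flat list of floats (stand-in for tensors).
  # ASSUMPTION: clip_by_norm uses one L2 norm over all gradient values.

  def clip_by_norm(grads, max_norm) do
    norm = grads |> Enum.map(&(&1 * &1)) |> Enum.sum() |> :math.sqrt()

    if norm > max_norm do
      # Scale uniformly so the resulting norm equals max_norm.
      scale = max_norm / norm
      Enum.map(grads, &(&1 * scale))
    else
      grads
    end
  end

  def clip_by_value(grads, max_value) do
    # Clamp each value into [-max_value, max_value].
    Enum.map(grads, &(&1 |> max(-max_value) |> min(max_value)))
  end
end
```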

## Returns

  The trained `ExBurn.Model` struct with updated parameters.
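The overall shape of an epoch, batch and callback loop can be sketched without ExBurn at all. The toy below fits a single weight `w` in `y ≈ w * x` with plain mini-batch SGD on mean squared error, then threads a state map through every callback after each epoch, mirroring the `callback()` type. Everything here is hypothetical (the module name, the `:lr` option, the `%{w: ..., epoch: ...}` state); it illustrates control flow only and is not how `fit/3` computes gradients, which this page says happens via Burn.

```elixir
defmodule FitLoopSketch do
  # Toy 1-D linear regression trained with mini-batch SGD.
  # Hypothetical options: :epochs, :batch_size, :lr, :callbacks.
  def fit({xs, ys}, opts) do
    epochs = Keyword.get(opts, :epochs, 10)
    batch_size = Keyword.get(opts, :batch_size, 32)
    lr = Keyword.get(opts, :lr, 0.01)
    callbacks = Keyword.get(opts, :callbacks, [])
    batches = Enum.chunk_every(Enum.zip(xs, ys), batch_size)

    Enum.reduce(1..epochs, %{w: 0.0}, fn epoch, state ->
      w =
        Enum.reduce(batches, state.w, fn batch, w ->
          # Gradient of mean((w * x - y)^2) with respect to w over this batch.
          grad = Enum.sum(for {x, y} <- batch, do: 2 * (w * x - y) * x) / length(batch)
          w - lr * grad
        end)

      state = %{state | w: w} |> Map.put(:epoch, epoch)
      # After the epoch, pass the state through each callback in order.
      Enum.reduce(callbacks, state, fn callback, acc -> callback.(acc) end)
    end)
  end
end
```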

