# `Dala.ML.Burn`
[🔗](https://github.com/manhvu/dala/blob/main/lib/dala/ml/burn.ex#L1)

Dala integration for the [Burn](https://burn.dev) deep learning framework.

ExBurn provides an `Nx.Backend` implementation that delegates tensor operations
to Burn via Rust NIFs, enabling GPU-accelerated ML/DL on mobile and desktop.

## Architecture

```
Axon model
   ↓
Nx.Defn graph
   ↓
ExBurn.Defn.Compiler (Nx.Defn.Compiler behaviour)
   ↓
ExBurn.Backend (Nx.Backend behaviour)
   ↓
ExBurn.Nif (Rustler NIF) ←→ ExCubecl (GPU buffers, kernels, pipelines)
   ↓
Burn Autodiff<CubeCL> (Rust)
   ↓
CubeCL kernels
   ↓
Metal (iOS) / Vulkan (Android) / CUDA → GPU
```

## Quick Start

    # Set ExBurn as the default Nx backend
    Dala.ML.Burn.configure!()

    # Create and manipulate tensors
    t = Nx.tensor([1.0, 2.0, 3.0])
    Nx.add(t, t) |> Nx.to_list()

    # Define a model with Axon
    model =
      Axon.input("input", shape: {nil, 784})
      |> Axon.dense(256, activation: :relu)
      |> Axon.dropout(rate: 0.2)
      |> Axon.dense(10)

    # Compile for training
    compiled = Dala.ML.Burn.compile(model,
      loss: :cross_entropy,
      optimizer: :adam,
      learning_rate: 0.001
    )

    # Train
    Dala.ML.Burn.fit(compiled, {train_x, train_y},
      epochs: 10,
      batch_size: 32
    )

## Platform GPU Backends

| Platform | Backend | Status |
|----------|---------|--------|
| iOS      | Metal   | ✅     |
| Android  | Vulkan  | ✅     |
| macOS    | Metal   | ✅     |
| Linux    | Vulkan  | ✅     |
| NVIDIA   | CUDA    | ✅     |

## Integration with Dala.ML

This module complements the existing Dala ML backends:

- `Dala.ML.EMLX` — MLX backend for Apple Silicon (iOS recommended)
- `Dala.ML.CoreML` — iOS-native CoreML (Neural Engine)
- `Dala.ML.ONNX` — Cross-platform ONNX Runtime
- `Dala.ML.Burn` — Burn framework via ExBurn (this module)

Use `Dala.ML.available_backends/0` to see all available backends,
and `Dala.ML.Burn.available?/0` specifically for Burn support.

## Training on Mobile — Caveats

Burn's Autodiff backend is memory-intensive. On iOS/Android with limited RAM:
- **Fine-tuning** small models (< 10M parameters) is feasible on modern devices
- **Full training** of large models is not recommended on mobile
- **Inference** is the primary use case for mobile deployment
- Minimum recommended: 4 GB RAM, A12+ chip (iOS) / Snapdragon 700+ (Android)
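
To put the 10M-parameter guideline in context, the following back-of-the-envelope sketch estimates training memory. It assumes 32-bit float parameters and an Adam-style optimizer, which keeps two moment buffers per parameter alongside the weights and gradients. The `MemoryEstimate` module is hypothetical and not part of Dala or ExBurn; activations and framework overhead are ignored, so real usage will be higher.

```elixir
defmodule MemoryEstimate do
  @moduledoc "Rough training-memory arithmetic (illustrative only)."

  # Assumption: parameters are stored as 32-bit floats.
  @bytes_per_f32 4

  # Weights + gradients + two Adam moment buffers = 4 copies of the parameters.
  # Activations and framework overhead are deliberately left out.
  def adam_training_bytes(param_count) when is_integer(param_count) and param_count >= 0 do
    param_count * @bytes_per_f32 * 4
  end

  # Converts bytes to decimal megabytes.
  def to_mb(bytes), do: bytes / 1_000_000
end

10_000_000
|> MemoryEstimate.adam_training_bytes()
|> MemoryEstimate.to_mb()
|> IO.inspect(label: "MB for 10M parameters")
```

Under these assumptions a 10M-parameter model needs about 160 MB for parameter-related state alone, before any activations are stored.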

# `abs`

```elixir
@spec abs(ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

Absolute value of a Burn tensor.

# `add`

```elixir
@spec add(ExBurn.Tensor.t(), ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

Adds two Burn tensors.

# `available?`

```elixir
@spec available?() :: boolean()
```

Checks whether ExBurn is available (loaded and functional).

# `available_backends`

```elixir
@spec available_backends() :: [atom()]
```

Returns a list of available GPU backends on this system.

# `benchmark`

```elixir
@spec benchmark(ExBurn.Model.t(), Nx.Tensor.t(), keyword()) :: map()
```

Benchmarks the model's forward pass on the given input.

## Options

* `:warmup` — Number of warmup runs (default: 3)
* `:runs` — Number of benchmarked runs (default: 10)

Returns a map with `:avg_ms`, `:min_ms`, `:max_ms`, `:median_ms`, `:std_ms`.
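
For readers who want to see how such a result map can be derived, here is a minimal sketch in plain Elixir. It is not the library's implementation: the `BenchSketch` module is hypothetical, it times an arbitrary zero-arity function rather than a model, and it assumes `:std_ms` is a population standard deviation, which the documentation does not specify.

```elixir
defmodule BenchSketch do
  # Runs `fun` a few times without timing (warmup), then times `runs` calls
  # and summarises the samples in milliseconds.
  def run(fun, opts \\ []) when is_function(fun, 0) do
    warmup = Keyword.get(opts, :warmup, 3)
    runs = Keyword.get(opts, :runs, 10)

    # Warmup results are discarded.
    for _ <- 1..warmup//1, do: fun.()

    samples =
      for _ <- 1..runs//1 do
        {micros, _result} = :timer.tc(fun)
        micros / 1000
      end

    summarize(samples)
  end

  # Builds the statistics map from a non-empty list of millisecond samples.
  def summarize([_ | _] = samples) do
    n = length(samples)
    avg = Enum.sum(samples) / n
    variance = Enum.sum(Enum.map(samples, fn s -> (s - avg) * (s - avg) end)) / n

    %{
      avg_ms: avg,
      min_ms: Enum.min(samples),
      max_ms: Enum.max(samples),
      median_ms: median(Enum.sort(samples)),
      std_ms: :math.sqrt(variance)
    }
  end

  # Median of an already sorted list; averages the two middle values when n is even.
  defp median(sorted) do
    n = length(sorted)
    mid = div(n, 2)

    if rem(n, 2) == 1 do
      Enum.at(sorted, mid)
    else
      (Enum.at(sorted, mid - 1) + Enum.at(sorted, mid)) / 2
    end
  end
end

BenchSketch.run(fn -> Enum.sum(1..100_000) end, warmup: 2, runs: 5)
|> IO.inspect(label: "timing")
```

Discarding warmup runs matters on GPU backends in particular, where the first calls often include one-off costs such as kernel compilation.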

# `buffer`

```elixir
@spec buffer(list(), [non_neg_integer()], atom()) :: ExCubecl.buffer_ref()
```

Creates a GPU buffer via ExCubecl from a list of values.

# `buffer!`

```elixir
@spec buffer!(list(), [non_neg_integer()], atom()) :: ExCubecl.buffer_ref()
```

Creates a GPU buffer via ExCubecl, raising on error.

# `buffer_shape`

```elixir
@spec buffer_shape(ExCubecl.buffer_ref()) :: [non_neg_integer()]
```

Returns the shape of an ExCubecl buffer.

# `buffer_size`

```elixir
@spec buffer_size(ExCubecl.buffer_ref()) :: non_neg_integer()
```

Returns the byte size of an ExCubecl buffer.

# `clone`

```elixir
@spec clone(ExBurn.Model.t()) :: ExBurn.Model.t()
```

Creates a deep copy of the model with identical parameters and configuration.

# `compile`

```elixir
@spec compile(
  Axon.ModelState.t(),
  keyword()
) :: ExBurn.Model.t()
```

Compiles an Axon model for training with the ExBurn backend.

## Options

* `:loss` — Loss function: `:cross_entropy`, `:mse`, `:binary_cross_entropy` (default: `:cross_entropy`)
* `:optimizer` — Optimizer: `:adam`, `:sgd`, `:rmsprop` (default: `:adam`)
* `:learning_rate` — Learning rate (default: 0.001)
* `:device` — Device: `:cpu` or `:gpu` (default: auto-detected)
* `:weight_decay` — L2 regularization coefficient (default: 0.0)

# `compute_gradients`

```elixir
@spec compute_gradients(ExBurn.Model.t(), {Nx.Tensor.t(), Nx.Tensor.t()}, keyword()) ::
  map()
```

Computes gradients for a given mini-batch.

## Options

* `:grad_method` — `:numerical` (central differences) or `:numerical_batch` (one-sided, faster)
* `:epsilon` — Finite difference step size (default: 1.0e-5)
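
The difference between the two gradient methods can be illustrated on a scalar function. This is a sketch of the general finite-difference technique, not of ExBurn's internals; the `FiniteDiff` module and its function names are hypothetical.

```elixir
defmodule FiniteDiff do
  # Central difference: (f(x + h) - f(x - h)) / 2h.
  # Truncation error shrinks proportionally to h squared.
  def central(f, x, eps \\ 1.0e-5), do: (f.(x + eps) - f.(x - eps)) / (2 * eps)

  # One-sided (forward) difference: (f(x + h) - f(x)) / h.
  # Truncation error shrinks proportionally to h.
  def forward(f, x, eps \\ 1.0e-5), do: (f.(x + eps) - f.(x)) / eps
end

square = fn x -> x * x end

# The exact derivative of x * x at x = 3.0 is 6.0.
IO.inspect(FiniteDiff.central(square, 3.0), label: "central")
IO.inspect(FiniteDiff.forward(square, 3.0), label: "one-sided")
```

For a model with `n` parameters, central differences need `2n` loss evaluations per gradient, while the one-sided form needs `n + 1` because the unperturbed loss is shared. That is the usual reason the one-sided variant is faster and slightly less accurate.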

# `compute_loss`

```elixir
@spec compute_loss(ExBurn.Model.t(), Nx.Tensor.t(), Nx.Tensor.t()) ::
  {:ok, Nx.Tensor.t()} | {:error, term()}
```

Computes the loss between predictions and targets.

# `configure!`

```elixir
@spec configure!(keyword()) :: :ok
```

Configures ExBurn for the current platform with Dala-specific defaults.

## Options

* `:device` — Override device (`:cpu` or `:gpu`). Auto-detected by default.
* `:backend` — Override GPU backend (`:metal`, `:vulkan`, `:cuda`). Auto-detected.

# `cross_entropy`

```elixir
@spec cross_entropy(ExBurn.Tensor.t(), ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

Cross-entropy loss between predictions and targets.

# `cuda_available?`

```elixir
@spec cuda_available?() :: boolean()
```

Checks whether an NVIDIA CUDA GPU is available.

# `data_loader`

```elixir
@spec data_loader({Nx.Tensor.t(), Nx.Tensor.t()}, keyword()) :: Enumerable.t()
```

Creates a data loader that yields mini-batches from a dataset.

## Options

* `:batch_size` — Mini-batch size (default: 32)
* `:shuffle` — Shuffle data each iteration (default: true)
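
The batching behaviour described by these options can be sketched with plain lists standing in for tensors. The `BatchSketch` module is hypothetical, and unlike the documented `:shuffle` option it shuffles only once, when the stream is built, rather than on every iteration.

```elixir
defmodule BatchSketch do
  # Pairs each input with its target, optionally shuffles the pairs,
  # and lazily chunks them into {inputs, targets} mini-batches.
  def batches(inputs, targets, opts \\ []) when length(inputs) == length(targets) do
    batch_size = Keyword.get(opts, :batch_size, 32)
    shuffle? = Keyword.get(opts, :shuffle, true)

    pairs = Enum.zip(inputs, targets)
    pairs = if shuffle?, do: Enum.shuffle(pairs), else: pairs

    pairs
    |> Stream.chunk_every(batch_size)
    |> Stream.map(&Enum.unzip/1)
  end
end

BatchSketch.batches([1, 2, 3, 4, 5], [:a, :b, :c, :d, :e], batch_size: 2, shuffle: false)
|> Enum.to_list()
|> IO.inspect()
# => [{[1, 2], [:a, :b]}, {[3, 4], [:c, :d]}, {[5], [:e]}]
```

Shuffling the zipped pairs rather than the two lists separately keeps every input aligned with its target. The final batch may be smaller than `:batch_size` when the dataset size is not a multiple of it.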

# `default_device`

```elixir
@spec default_device() :: :cpu | :gpu
```

Returns the default device for tensor operations.

# `deserialize_params`

```elixir
@spec deserialize_params(binary()) :: {:ok, map()} | {:error, String.t()}
```

Deserializes model parameters from a binary.

# `device_info`

```elixir
@spec device_info() :: map()
```

Returns a map with device information including GPU availability,
backend name, and available backends.

# `device_name`

```elixir
@spec device_name() :: String.t()
```

Returns the name of the active compute device (e.g., "CUDA (NVIDIA GPU)").

# `device_summary`

```elixir
@spec device_summary() :: String.t()
```

Returns a human-readable summary of the GPU device.

# `div`

```elixir
@spec div(ExBurn.Tensor.t(), ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

Divides two Burn tensors element-wise.

# `dropout`

```elixir
@spec dropout(ExBurn.Tensor.t(), float(), boolean()) :: ExBurn.Tensor.t()
```

Dropout (identity during inference).

# `enable_defn_compiler!`

```elixir
@spec enable_defn_compiler!() :: :ok
```

Enables the ExBurn defn compiler for GPU-accelerated `Nx.Defn` expressions.

After calling this, all `defn` functions will be compiled through
`ExBurn.Defn.Compiler` and executed on the GPU via Burn.

## Example

    Dala.ML.Burn.enable_defn_compiler!()

    defmodule MyMath do
      import Nx.Defn

      defn add_and_scale(x, y, scale) do
        x |> Nx.add(y) |> Nx.multiply(scale)
      end
    end

    # Runs on GPU via Burn
    MyMath.add_and_scale(Nx.tensor([1.0]), Nx.tensor([2.0]), Nx.tensor(3.0))

# `error`

```elixir
@spec error(keyword()) :: ExBurn.Error.t()
```

Creates an `ExBurn.Error` struct (non-raising).

## Example

    Dala.ML.Burn.error(op: :forward, reason: "shape mismatch")

# `error_from_tuple`

```elixir
@spec error_from_tuple({:error, String.t()}, keyword()) :: ExBurn.Error.t()
```

Wraps an error tuple in an `ExBurn.Error`.

## Example

    Dala.ML.Burn.error_from_tuple({:error, "failed"}, op: :predict)

# `error_to_log_string`

```elixir
@spec error_to_log_string(ExBurn.Error.t()) :: String.t()
```

Converts an `ExBurn.Error` to a log string.

# `evaluate`

```elixir
@spec evaluate(ExBurn.Model.t(), {Nx.Tensor.t(), Nx.Tensor.t()}, keyword()) ::
  float() | {float(), float() | nil}
```

Evaluates a model on a dataset.

Returns the average loss, or `{loss, accuracy}` when `track_accuracy: true`.

# `exp`

```elixir
@spec exp(ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

Exponential of a Burn tensor.

# `export`

```elixir
@spec export(ExBurn.Model.t(), Path.t(), keyword()) :: :ok | {:error, String.t()}
```

Exports model parameters to a file.

## Options

* `:format` — `:elixir_terms` (default, compressed) or `:json` (human-readable)

# `fit`

```elixir
@spec fit(ExBurn.Model.t(), {Nx.Tensor.t(), Nx.Tensor.t()}, keyword()) ::
  ExBurn.Model.t()
```

Trains a model on the given dataset.

## Options

* `:epochs` — Number of training epochs (default: 10)
* `:batch_size` — Mini-batch size (default: 32)
* `:shuffle` — Shuffle training data each epoch (default: true)
* `:validation_data` — Validation dataset as `{inputs, targets}` tuple
* `:callbacks` — List of callback functions called after each epoch
* `:verbose` — Print training progress (default: true)
* `:lr_schedule` — Learning rate schedule (default: nil)
* `:clip_norm` — Max gradient norm for clipping (default: nil)
* `:clip_value` — Max absolute gradient value for clipping (default: nil)
* `:weight_decay` — L2 regularization coefficient (default: nil)
* `:accumulate_gradients` — Gradient accumulation steps (default: 1)
* `:accuracy` — Track and report classification accuracy (default: false)
* `:nesterov` — Use Nesterov momentum for SGD (default: false)
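
The two clipping options behave differently, which the sketch below illustrates on a flat list of gradient values. It shows the standard techniques the option names suggest and is an assumption about their semantics, not ExBurn's code; `ClipSketch` is a hypothetical module.

```elixir
defmodule ClipSketch do
  # Norm clipping: rescales the whole gradient vector so that its L2 norm
  # is at most `max_norm`. The direction of the gradient is preserved.
  def clip_norm(grads, max_norm) do
    norm = grads |> Enum.map(&(&1 * &1)) |> Enum.sum() |> :math.sqrt()

    if norm > max_norm do
      scale = max_norm / norm
      Enum.map(grads, &(&1 * scale))
    else
      grads
    end
  end

  # Value clipping: clamps every component independently to
  # [-max_value, max_value]. This can change the gradient's direction.
  def clip_value(grads, max_value) do
    Enum.map(grads, fn g -> g |> max(-max_value) |> min(max_value) end)
  end
end

# The norm of [3.0, 4.0] is 5.0, so the vector is scaled to roughly [0.6, 0.8].
IO.inspect(ClipSketch.clip_norm([3.0, 4.0], 1.0), label: "clip_norm")

IO.inspect(ClipSketch.clip_value([3.0, -4.0, 0.5], 1.0), label: "clip_value")
# => [1.0, -1.0, 0.5]
```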

# `format_error`

```elixir
@spec format_error(ExBurn.Error.t()) :: String.t()
```

Formats an `ExBurn.Error` for logging or display.

# `forward`

```elixir
@spec forward(ExBurn.Model.t(), Nx.Tensor.t()) ::
  {:ok, Nx.Tensor.t()} | {:error, term()}
```

Runs a forward pass through the model using the ExBurn GPU defn compiler.

This is the GPU-accelerated path. Requires the model to be compiled.

Returns `{:ok, output_tensor}` or `{:error, reason}`.

# `forward_pattern`

```elixir
@spec forward_pattern(ExBurn.Model.t()) :: %{
  output_shape: tuple() | nil,
  output_type: atom()
}
```

Returns the output shape and type information from the Axon model.

# `free`

```elixir
@spec free(ExBurn.Tensor.t()) :: :ok
```

Frees a Burn tensor's underlying GPU/CPU memory.

# `freeze`

```elixir
@spec freeze(ExBurn.Model.t(), [atom() | String.t()]) :: ExBurn.Model.t()
```

Freezes the specified layers so their parameters are not updated during training.
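
Conceptually, freezing means the optimizer skips the listed layers. The sketch below shows that idea with a plain SGD update over maps of lists. The `FreezeSketch` module, the layer names, and the parameter layout are all hypothetical and do not reflect how ExBurn stores parameters.

```elixir
defmodule FreezeSketch do
  # Applies one SGD step to every layer except those listed in `frozen`.
  def sgd_step(params, grads, frozen, lr) do
    Map.new(params, fn {layer, weights} ->
      if MapSet.member?(frozen, layer) do
        # Frozen layers are passed through unchanged.
        {layer, weights}
      else
        updated =
          weights
          |> Enum.zip(Map.fetch!(grads, layer))
          |> Enum.map(fn {w, g} -> w - lr * g end)

        {layer, updated}
      end
    end)
  end
end

params = %{"dense_0" => [1.0, 2.0], "dense_1" => [3.0]}
grads = %{"dense_0" => [0.5, 0.5], "dense_1" => [1.0]}

FreezeSketch.sgd_step(params, grads, MapSet.new(["dense_0"]), 0.5)
|> IO.inspect()
# => %{"dense_0" => [1.0, 2.0], "dense_1" => [2.5]}
```

This pattern is typical for fine-tuning: earlier layers stay fixed while only the last layers adapt to new data.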

# `from_binary`

```elixir
@spec from_binary(binary(), [non_neg_integer()], atom()) ::
  {:ok, ExBurn.Tensor.t()} | {:error, term()}
```

Creates a Burn tensor from raw binary data.

# `from_nx`

```elixir
@spec from_nx(Nx.Tensor.t()) :: {:ok, ExBurn.Tensor.t()} | {:error, term()}
```

Converts an Nx tensor to a Burn tensor.

# `from_nx_batch`

```elixir
@spec from_nx_batch([Nx.Tensor.t()]) :: {:ok, [ExBurn.Tensor.t()]} | {:error, term()}
```

Batch converts a list of Nx tensors to Burn tensors.


# `frozen?`

```elixir
@spec frozen?(ExBurn.Model.t(), atom() | String.t()) :: boolean()
```

Checks whether a layer is frozen.

# `frozen_layers`

```elixir
@spec frozen_layers(ExBurn.Model.t()) :: MapSet.t()
```

Returns the set of frozen layer names.

# `gpu?`

```elixir
@spec gpu?() :: boolean()
```

Checks whether a GPU device is available for Burn operations.

# `gpu_available?`

```elixir
@spec gpu_available?() :: boolean()
```

Checks whether a GPU is available via the Burn bridge.

# `gpu_memory_info`

```elixir
@spec gpu_memory_info() :: {:ok, map()} | {:error, term()}
```

Returns GPU memory info as `%{total: bytes, used: bytes, free: bytes}`.

# `import_params`

```elixir
@spec import_params(ExBurn.Model.t(), Path.t(), keyword()) ::
  {:ok, ExBurn.Model.t()} | {:error, String.t()}
```

Imports model parameters from a file saved with `export/3`.

## Options

* `:format` — `:elixir_terms` (default) or `:json`

# `info`

```elixir
@spec info(ExBurn.Model.t()) :: map()
```

Returns a map with model information (param count, layer count, device, memory estimate).

# `layer_norm`

```elixir
@spec layer_norm(ExBurn.Tensor.t(), non_neg_integer(), float()) :: ExBurn.Tensor.t()
```

Layer normalization.

# `load`

```elixir
@spec load(ExBurn.Model.t(), Path.t()) :: {:ok, ExBurn.Model.t()} | {:error, term()}
```

Loads model parameters from a file.

# `log`

```elixir
@spec log(ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

Natural logarithm of a Burn tensor.

# `matmul`

```elixir
@spec matmul(ExBurn.Tensor.t(), ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

Matrix multiplication of two Burn tensors.

# `mean`

```elixir
@spec mean(ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

Mean of all elements in a Burn tensor.

# `mse`

```elixir
@spec mse(ExBurn.Tensor.t(), ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

Mean squared error between predictions and targets.

# `mul`

```elixir
@spec mul(ExBurn.Tensor.t(), ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

Multiplies two Burn tensors element-wise.

# `neg`

```elixir
@spec neg(ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

Negates a Burn tensor.

# `new_model`

```elixir
@spec new_model(keyword()) :: ExBurn.Model.t()
```

Creates a new empty model struct. Useful for incremental model building.

# `nif_function_count`

```elixir
@spec nif_function_count() :: non_neg_integer()
```

Returns the number of NIF functions registered by the Rust library.
Useful for debugging NIF loading issues.

# `nif_loaded?`

```elixir
@spec nif_loaded?() :: boolean()
```

Checks whether the NIF library is loaded and responds to calls.

# `ones`

```elixir
@spec ones([non_neg_integer()], atom()) :: ExBurn.Tensor.t()
```

Creates a tensor filled with ones via Burn.

# `parameters`

```elixir
@spec parameters(ExBurn.Model.t()) :: map()
```

Returns the current model parameters.

# `predict`

```elixir
@spec predict(ExBurn.Model.t(), Nx.Tensor.t()) ::
  {:ok, Nx.Tensor.t()} | {:error, term()}
```

Runs a forward pass through the model using Axon's default backend.

For GPU execution, use `forward/2` which uses the ExBurn defn compiler.

Returns `{:ok, output_tensor}` or `{:error, reason}`.

# `profile_step`

```elixir
@spec profile_step(ExBurn.Model.t(), {Nx.Tensor.t(), Nx.Tensor.t()}, keyword()) ::
  map()
```

Profiles a single training step, returning detailed timing for each phase.

Returns a map with `:forward_ms`, `:backward_ms`, `:optimizer_ms`, `:total_ms`.

# `quantize`

```elixir
@spec quantize(ExBurn.Model.t(), :f16 | :bf16) :: ExBurn.Model.t()
```

Quantizes model parameters to a lower precision type (`:f16` or `:bf16`).

Useful for reducing model size and speeding up inference on devices
with limited compute.
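
The size and precision trade-off can be seen with Elixir's own binary syntax, independent of ExBurn. The sketch assumes the original parameters are 32-bit floats and requires a runtime with 16-bit float support in bitstrings (Elixir 1.12+ on OTP 24+). It covers IEEE half precision only; `:bf16` has a different bit layout and is not shown.

```elixir
# Encoded size per element: 4 bytes as f32, 2 bytes as f16.
IO.inspect(byte_size(<<0.1::float-32>>), label: "f32 bytes")
IO.inspect(byte_size(<<0.1::float-16>>), label: "f16 bytes")

# Round-trips a float through a 16-bit encoding to expose the rounding
# that a lower-precision cast introduces.
to_f16 = fn x ->
  <<y::float-16>> = <<x::float-16>>
  y
end

# 1.0 is exactly representable in half precision; 0.1 is not.
IO.inspect(to_f16.(1.0), label: "1.0 as f16")
IO.inspect(to_f16.(0.1), label: "0.1 as f16")
```

Halving the bytes per parameter halves the stored model size, at the cost of a small rounding error on most values. Whether that error is acceptable is best checked by comparing accuracy before and after quantization.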

# `rand`

```elixir
@spec rand([non_neg_integer()], atom(), float(), float()) :: ExBurn.Tensor.t()
```

Creates a random tensor with uniform distribution via Burn.

# `read_buffer`

```elixir
@spec read_buffer(ExCubecl.buffer_ref()) :: binary()
```

Reads data from an ExCubecl buffer.

# `relu`

```elixir
@spec relu(ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

ReLU activation of a Burn tensor.

# `reshape`

```elixir
@spec reshape(ExBurn.Tensor.t(), [non_neg_integer()]) :: ExBurn.Tensor.t()
```

Reshapes a Burn tensor.

# `save`

```elixir
@spec save(ExBurn.Model.t(), Path.t()) :: :ok | {:error, term()}
```

Saves the model parameters to a file using compressed Erlang term format.
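
As an illustration of what a compressed Erlang term file looks like in practice, the following sketch round-trips a parameter map through disk using only the standard library. It is an assumption about the general approach, not the exact layout that `save/2` writes; the parameter map and file name are hypothetical.

```elixir
# Hypothetical parameter map; real ExBurn parameters are likely tensors, not lists.
params = %{"dense_0" => %{"kernel" => [0.1, 0.2], "bias" => [0.0]}}

path =
  Path.join(System.tmp_dir!(), "params_sketch_#{System.unique_integer([:positive])}.bin")

# Encode the term with compression and write it to disk.
File.write!(path, :erlang.term_to_binary(params, [:compressed]))

# `:safe` makes decoding refuse to create new atoms from untrusted data.
restored = path |> File.read!() |> :erlang.binary_to_term([:safe])

File.rm!(path)

IO.inspect(restored == params, label: "round trip ok")
```

Files in this format are only readable from the BEAM, so a text format such as the `:json` option of `export/3` is the better fit when other tools need to read the parameters.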

# `serialize_params`

```elixir
@spec serialize_params(ExBurn.Model.t()) :: binary()
```

Serializes model parameters to a binary for network transfer or storage.

# `sigmoid`

```elixir
@spec sigmoid(ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

Sigmoid activation of a Burn tensor.

# `smoke_test`

```elixir
@spec smoke_test() :: :ok | {:error, String.t()}
```

Performs a quick smoke test of the ExBurn pipeline.

Returns `:ok` on success or `{:error, reason}` on failure.

# `softmax`

```elixir
@spec softmax(ExBurn.Tensor.t(), non_neg_integer()) :: ExBurn.Tensor.t()
```

Softmax along a dimension.
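
For reference, the computation along a single dimension can be sketched on a plain list. This shows the standard numerically stable formulation and makes no claim about how the Burn kernel is written.

```elixir
softmax = fn xs ->
  # Subtracting the maximum leaves the result unchanged
  # but keeps exp/1 from overflowing on large inputs.
  m = Enum.max(xs)
  exps = Enum.map(xs, &:math.exp(&1 - m))
  total = Enum.sum(exps)
  Enum.map(exps, &(&1 / total))
end

IO.inspect(softmax.([0.0, 0.0]))
# => [0.5, 0.5]

# Without the max subtraction, :math.exp(1000.0) would raise an ArithmeticError.
IO.inspect(softmax.([1000.0, 1000.0]))
# => [0.5, 0.5]
```

The outputs are non-negative and sum to 1 along the chosen dimension, which is why softmax is commonly paired with a cross-entropy loss for classification.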

# `sqrt`

```elixir
@spec sqrt(ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

Square root of a Burn tensor.

# `sub`

```elixir
@spec sub(ExBurn.Tensor.t(), ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

Subtracts two Burn tensors.

# `sum`

```elixir
@spec sum(ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

Sum of all elements in a Burn tensor.

# `summary`

```elixir
@spec summary() :: String.t()
```

Returns a summary of the ExBurn environment.

# `summary`

```elixir
@spec summary(ExBurn.Model.t()) :: String.t()
```

Returns a summary of the model architecture including parameter count.

# `tensor_numel`

```elixir
@spec tensor_numel(ExBurn.Tensor.t()) :: non_neg_integer()
```

Returns the total number of elements in a Burn tensor.

# `tensor_rank`

```elixir
@spec tensor_rank(ExBurn.Tensor.t()) :: non_neg_integer()
```

Returns the rank (number of dimensions) of a Burn tensor.

# `tensor_shape`

```elixir
@spec tensor_shape(ExBurn.Tensor.t()) :: [non_neg_integer()]
```

Returns the shape of a Burn tensor.

# `tensor_type`

```elixir
@spec tensor_type(ExBurn.Tensor.t()) :: atom()
```

Returns the element type of a Burn tensor.

# `to_cpu`

```elixir
@spec to_cpu(ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

Moves a Burn tensor to the CPU.

# `to_device`

```elixir
@spec to_device(ExBurn.Model.t(), :cpu | :gpu) :: ExBurn.Model.t()
```

Moves all model parameters to the specified device (`:gpu` or `:cpu`).

# `to_gpu`

```elixir
@spec to_gpu(ExBurn.Tensor.t()) :: ExBurn.Tensor.t()
```

Moves a Burn tensor to the GPU.

# `to_nx`

```elixir
@spec to_nx(ExBurn.Tensor.t()) :: {:ok, Nx.Tensor.t()} | {:error, term()}
```

Converts a Burn tensor to an Nx tensor.

# `to_nx_batch`

```elixir
@spec to_nx_batch([ExBurn.Tensor.t()]) :: {:ok, [Nx.Tensor.t()]} | {:error, term()}
```

Converts a list of Burn tensors to Nx tensors in a single batch.

# `train_step`

```elixir
@spec train_step(ExBurn.Model.t(), {Nx.Tensor.t(), Nx.Tensor.t()}, keyword()) ::
  {float(), ExBurn.Model.t()}
```

Performs a single training step: forward + backward + optimizer update.

Useful for custom training loops. Returns `{loss, updated_model}`.
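
Because each step returns `{loss, updated_model}`, a custom loop can thread the model through the batches with `Enum.map_reduce/3`. The sketch below takes the step function as an argument so that it runs without a GPU or the library. `LoopSketch` and `fake_step` are hypothetical stand-ins; in real code the step would be `&Dala.ML.Burn.train_step/3`.

```elixir
defmodule LoopSketch do
  # Runs `step_fun` once per batch, passing the updated model forward.
  # Returns {losses_in_batch_order, final_model}.
  def run_epoch(model, batches, step_fun) when is_function(step_fun, 3) do
    Enum.map_reduce(batches, model, fn batch, current ->
      # step_fun returns {loss, updated_model}, which is exactly the
      # {mapped_value, accumulator} shape that map_reduce expects.
      step_fun.(current, batch, [])
    end)
  end
end

# Stand-in step: the "model" only counts steps, and the loss shrinks each step.
fake_step = fn model, _batch, _opts ->
  {1.0 / (model.steps + 1), %{model | steps: model.steps + 1}}
end

LoopSketch.run_epoch(%{steps: 0}, [:batch_a, :batch_b], fake_step)
|> IO.inspect()
# => {[1.0, 0.5], %{steps: 2}}
```

Owning the loop makes it straightforward to add custom logging, early stopping, or per-batch learning-rate changes between steps.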

# `transpose`

```elixir
@spec transpose(ExBurn.Tensor.t(), non_neg_integer(), non_neg_integer()) ::
  ExBurn.Tensor.t()
```

Transposes a Burn tensor.

# `unfreeze`

```elixir
@spec unfreeze(ExBurn.Model.t(), [atom() | String.t()]) :: ExBurn.Model.t()
```

Unfreezes the specified layers so their parameters are updated during training.

# `update_params`

```elixir
@spec update_params(ExBurn.Model.t(), map()) :: ExBurn.Model.t()
```

Returns a new model with updated parameters.

# `version`

```elixir
@spec version() :: String.t()
```

Returns the current ExBurn version.

# `zeros`

```elixir
@spec zeros([non_neg_integer()], atom()) :: ExBurn.Tensor.t()
```

Creates a tensor filled with zeros via Burn.

---

*Consult [api-reference.md](api-reference.md) for the complete listing.*
