# `ExBurn.Model`
[🔗](https://github.com/ohhi-vn/ex_burn/blob/main/lib/ex_burn/model.ex#L1)

Model definition and training orchestration for ExBurn.

This module provides a high-level API for defining, compiling, and
training neural network models using Axon with the ExBurn GPU backend.

## Usage

    # Define a model
    model =
      Axon.input("input", shape: {nil, 784})
      |> Axon.dense(256)
      |> Axon.relu()
      |> Axon.dropout(rate: 0.2)
      |> Axon.dense(128)
      |> Axon.relu()
      |> Axon.dense(10)

    # Compile with ExBurn backend
    compiled = ExBurn.Model.compile(model, loss: :cross_entropy, optimizer: :adam)

    # Forward pass on GPU
    {:ok, output} = ExBurn.Model.forward(compiled, input_tensor)

    # Train
    ExBurn.Model.fit(compiled, train_data, epochs: 10, batch_size: 32)

## GPU Compilation

The `forward/2` function uses `Nx.Defn.jit` with `ExBurn.Defn.Compiler`
to execute the model on the GPU via Burn. Parameters are bound to the
expression graph and compiled through the defn compiler for optimal
GPU kernel fusion.

# `t`

```elixir
@type t() :: %ExBurn.Model{
  axon_model: Axon.ModelState.t(),
  compiled: boolean(),
  device: :cpu | :gpu,
  frozen_layers: MapSet.t(),
  loss_fn: atom(),
  optimizer: atom(),
  optimizer_state: map(),
  params: map(),
  weight_decay: float()
}
```

# `benchmark`

```elixir
@spec benchmark(t(), Nx.Tensor.t(), keyword()) :: map()
```

Benchmarks the model's forward pass on the given input.

Runs the forward pass multiple times and returns timing statistics.

## Parameters

  * `model` — A compiled `ExBurn.Model` struct
  * `input` — An `Nx.Tensor` input batch
  * `opts` — Options
    * `:warmup` — Number of warmup runs (default: 3)
    * `:runs` — Number of benchmarked runs (default: 10)

## Returns

  A map with `:avg_ms`, `:min_ms`, `:max_ms`, `:median_ms`, `:std_ms`.

## Example

    result = ExBurn.Model.benchmark(model, input, warmup: 5, runs: 20)
    IO.puts("Average: #{result.avg_ms}ms")
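
The statistics in the returned map can be illustrated with a small standalone sketch. `BenchSketch` is a hypothetical helper, not part of ExBurn, and the use of the population standard deviation is an assumption; the library may compute `:std_ms` differently.

```elixir
defmodule BenchSketch do
  # Hypothetical helper: times a zero-arity function and summarizes the runs.
  def run(fun, opts \\ []) when is_function(fun, 0) do
    warmup = Keyword.get(opts, :warmup, 3)
    runs = Keyword.get(opts, :runs, 10)

    # Warmup runs are executed but not recorded.
    for _ <- 1..warmup//1, do: fun.()

    # :timer.tc/1 returns {microseconds, result}; convert to milliseconds.
    times =
      for _ <- 1..runs//1 do
        {micros, _result} = :timer.tc(fun)
        micros / 1000
      end

    summarize(times)
  end

  # Builds the statistics map from a non-empty list of timings in milliseconds.
  def summarize(times) when times != [] do
    n = length(times)
    avg = Enum.sum(times) / n
    sorted = Enum.sort(times)

    # Population variance (assumption; a sample variance would divide by n - 1).
    variance = Enum.sum(Enum.map(times, fn t -> (t - avg) * (t - avg) end)) / n

    %{
      avg_ms: avg,
      min_ms: hd(sorted),
      max_ms: List.last(sorted),
      median_ms: median(sorted, n),
      std_ms: :math.sqrt(variance)
    }
  end

  defp median(sorted, n) when rem(n, 2) == 1, do: Enum.at(sorted, div(n, 2))

  defp median(sorted, n) do
    (Enum.at(sorted, div(n, 2) - 1) + Enum.at(sorted, div(n, 2))) / 2
  end
end

stats = BenchSketch.run(fn -> Enum.sum(1..10_000) end, warmup: 2, runs: 5)
IO.puts("Median: #{stats.median_ms}ms")
```

Warmup runs matter for JIT-compiled code because the first calls typically include compilation time, which would otherwise skew the average.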

# `clone`

```elixir
@spec clone(t()) :: t()
```

Creates a deep copy of the model with identical parameters and configuration.

Useful for creating model snapshots during training or for ensemble methods.

## Example

    snapshot = ExBurn.Model.clone(model)

# `compile`

```elixir
@spec compile(
  Axon.ModelState.t() | Axon.t(),
  keyword()
) :: t()
```

Compiles an Axon model for training with the ExBurn backend.

## Options

  * `:loss` — Loss function: `:cross_entropy`, `:mse`, `:binary_cross_entropy` (default: `:cross_entropy`)
  * `:optimizer` — Optimizer: `:adam`, `:sgd`, `:rmsprop` (default: `:adam`)
  * `:learning_rate` — Learning rate (default: 0.001)
  * `:device` — Device: `:cpu` or `:gpu` (default: `:gpu`)
  * `:weight_decay` — L2 regularization coefficient (default: 0.0)

## Returns

  An `ExBurn.Model` struct ready for training.

# `compute_loss`

```elixir
@spec compute_loss(t(), Nx.Tensor.t(), Nx.Tensor.t()) ::
  {:ok, Nx.Tensor.t()} | {:error, String.t()}
```

Computes the loss between predictions and targets.

Supports `:cross_entropy` (computed via log-softmax for numerical
stability), `:mse`, and `:binary_cross_entropy`.

When `:weight_decay` is set on the model, L2 regularization is added.
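
To show why log-softmax helps, here is a minimal sketch in plain Elixir that works on lists of floats rather than tensors. `LossSketch` is a hypothetical module, and the exact scaling of the L2 term (no `1/2` factor, no averaging) is an assumption rather than ExBurn's confirmed behavior.

```elixir
defmodule LossSketch do
  # Log-softmax via the log-sum-exp trick: subtracting the largest logit
  # keeps :math.exp/1 from overflowing on large inputs.
  def log_softmax(logits) do
    max = Enum.max(logits)
    log_sum = :math.log(Enum.sum(Enum.map(logits, fn x -> :math.exp(x - max) end)))
    Enum.map(logits, fn x -> x - max - log_sum end)
  end

  # Cross-entropy for a single example with a one-hot target.
  def cross_entropy(logits, one_hot) do
    logits
    |> log_softmax()
    |> Enum.zip(one_hot)
    |> Enum.reduce(0.0, fn {log_p, t}, acc -> acc - t * log_p end)
  end

  # L2 penalty: weight_decay times the sum of squared weights
  # (the scaling is an assumption).
  def l2_penalty(weights, weight_decay) do
    weight_decay * Enum.sum(Enum.map(weights, fn w -> w * w end))
  end
end

# A naive softmax would call :math.exp(1000.0), which raises ArithmeticError.
loss = LossSketch.cross_entropy([1000.0, 0.0], [1.0, 0.0])
total = loss + LossSketch.l2_penalty([1.0, 2.0], 0.01)
IO.puts("Loss with L2 penalty: #{total}")
```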

# `deserialize_params`

```elixir
@spec deserialize_params(binary()) :: {:ok, map()} | {:error, String.t()}
```

Deserializes parameters from a binary.

# `export`

```elixir
@spec export(t(), Path.t(), keyword()) :: :ok | {:error, String.t()}
```

Exports the model parameters to a portable format.

The `:format` option currently supports:
  * `:elixir_terms` — Compressed Erlang term format (default, portable)
  * `:json` — JSON format (human-readable, larger)

## Example

    ExBurn.Model.export(model, "/tmp/model.json", format: :json)

# `forward`

```elixir
@spec forward(t(), Nx.Tensor.t()) :: {:ok, Nx.Tensor.t()} | {:error, String.t()}
```

Runs a forward pass through the model using the ExBurn GPU defn compiler.

This function binds the model parameters and input to the Axon expression
graph, then executes it via `Nx.Defn.jit` with `ExBurn.Defn.Compiler`,
which compiles the computation to run on the GPU through Burn.

## Parameters

  * `model` — A compiled `ExBurn.Model` struct
  * `input` — An `Nx.Tensor` input batch

## Returns

  `{:ok, output_tensor}` or `{:error, reason}`

# `forward_pattern`

```elixir
@spec forward_pattern(t()) :: %{output_shape: tuple() | nil, output_type: atom()}
```

Returns the output shape and type of the Axon model for inspection.

This is useful for debugging and understanding the model architecture
without running actual data through it.

## Returns

  A map with `:output_shape` and `:output_type` keys.

# `freeze`

```elixir
@spec freeze(t(), [atom() | String.t()]) :: t()
```

Freezes the specified layers so their parameters are not updated during training.

## Parameters

  * `model` — A compiled `ExBurn.Model` struct
  * `layer_names` — List of layer name strings or atoms to freeze

## Returns

  A new `ExBurn.Model` struct with the specified layers frozen.

## Example

    frozen_model = ExBurn.Model.freeze(model, ["dense_0", "dense_1"])
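
The effect of freezing on a training step can be sketched with a `MapSet` of layer names. `FreezeSketch` is hypothetical and uses plain numbers in place of tensors; how ExBurn applies the frozen set internally is not documented here, so treat this as an illustration of the idea only.

```elixir
defmodule FreezeSketch do
  # Hypothetical update step: `params` and `updates` map layer names to
  # plain numbers here; a real model would hold tensors.
  def apply_updates(params, updates, frozen) do
    Map.new(params, fn {name, value} ->
      if MapSet.member?(frozen, name) do
        # Frozen layers keep their current value.
        {name, value}
      else
        {name, value + Map.get(updates, name, 0.0)}
      end
    end)
  end
end

frozen = MapSet.new(["dense_0", "dense_1"])
params = %{"dense_0" => 1.0, "dense_1" => 2.0, "dense_2" => 3.0}
updates = %{"dense_0" => -0.1, "dense_1" => -0.1, "dense_2" => -0.5}

new_params = FreezeSketch.apply_updates(params, updates, frozen)
IO.inspect(new_params)
```

Only `"dense_2"` changes in this sketch; the two frozen layers pass through untouched.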

# `frozen?`

```elixir
@spec frozen?(t(), atom() | String.t()) :: boolean()
```

Checks whether a layer is frozen.

# `frozen_layers`

```elixir
@spec frozen_layers(t()) :: MapSet.t()
```

Returns the set of frozen layer names.

# `import_params`

```elixir
@spec import_params(t(), Path.t(), keyword()) :: {:ok, t()} | {:error, String.t()}
```

Imports model parameters from a file saved with `export/3`.

## Example

    {:ok, model} = ExBurn.Model.import_params(model, "/tmp/model.etf")

# `info`

```elixir
@spec info(t()) :: map()
```

Returns a map with model information.

Includes parameter count, layer count, loss function, optimizer,
device, and memory estimate.

## Example

    info = ExBurn.Model.info(model)
    IO.puts("Parameters: #{info.total_params}")
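
As a rough cross-check of the parameter count and memory estimate, the figures for the model in the Usage section can be worked out by hand. The sketch below assumes each dense layer has a bias and that parameters are stored as 32-bit floats; the exact keys and units that `info/1` reports may differ.

```elixir
# Dense layer shapes taken from the Usage model: {inputs, units}.
layers = [{784, 256}, {256, 128}, {128, 10}]

# Each dense layer holds an inputs x units kernel plus one bias per unit.
# The ReLU and dropout layers contribute no parameters.
total_params =
  layers
  |> Enum.map(fn {inputs, units} -> inputs * units + units end)
  |> Enum.sum()

# Assumption: parameters are stored as 32-bit floats (4 bytes each).
memory_bytes = total_params * 4

IO.puts("Parameters: #{total_params}")
IO.puts("Approx. memory: #{memory_bytes} bytes")
```

For this model the count comes to 235,146 parameters, or 940,584 bytes at 4 bytes per parameter.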

# `load`

```elixir
@spec load(t(), Path.t()) :: {:ok, t()} | {:error, String.t()}
```

Loads model parameters from a file.

# `loss_function`

```elixir
@spec loss_function(t()) :: atom()
```

Returns the model's loss function.

# `new`

# `optimizer`

```elixir
@spec optimizer(t()) :: atom()
```

Returns the model's optimizer.

# `parameters`

```elixir
@spec parameters(t()) :: map()
```

Returns the current model parameters.

# `predict`

```elixir
@spec predict(t(), Nx.Tensor.t()) :: {:ok, Nx.Tensor.t()} | {:error, String.t()}
```

Runs a forward pass through the model using Axon's default backend.

This is the CPU/Elixir fallback. For GPU execution, use `forward/2`.

# `quantize`

```elixir
@spec quantize(t(), :f16 | :bf16) :: t()
```

Quantizes model parameters to a lower precision type.

Useful for reducing model size and speeding up inference on devices
with limited compute. Currently supports `:f16` (half precision) and
`:bf16` (brain float 16).

## Parameters

  * `model` — A compiled `ExBurn.Model` struct
  * `dtype` — Target dtype: `:f16` or `:bf16`

## Returns

  A new `ExBurn.Model` struct with quantized parameters.

## Examples

    quantized_model = ExBurn.Model.quantize(model, :f16)
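
The trade-off behind `:f16` can be seen without any tensors by round-tripping a single value through Elixir's 16-bit float binary syntax (available on OTP 24 and later). This is only an illustration of half-precision rounding; it assumes the original parameters are 32-bit floats and says nothing about how ExBurn performs the conversion.

```elixir
# Round-trip a value through a 16-bit float encoding (requires OTP 24+).
original = 0.1
<<half::float-16>> = <<original::float-16>>

# Half precision keeps roughly three significant decimal digits,
# so the decoded value differs slightly from the original.
error = abs(half - original)

# Storage per element: 4 bytes as f32, 2 bytes as f16.
f32_bytes = byte_size(<<original::float-32>>)
f16_bytes = byte_size(<<original::float-16>>)

IO.puts("f16 value: #{half}, rounding error: #{error}")
IO.puts("Bytes per element: #{f32_bytes} -> #{f16_bytes}")
```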

# `save`

```elixir
@spec save(t(), Path.t()) :: :ok | {:error, String.t()}
```

Saves the model parameters to a file using compressed Erlang term format.

# `serialize_params`

```elixir
@spec serialize_params(t()) :: binary()
```

Serializes parameters to a binary for network transfer or storage.
Uses compressed Erlang term format.
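
Compressed Erlang term format is available through the standard `:erlang` module. The round trip below is a sketch under the assumption that serialization is a thin wrapper around these calls; the parameter map is hypothetical, and real parameters would be `Nx` tensors.

```elixir
# Hypothetical parameter map; real parameters would be Nx tensors.
params = %{
  "dense_0" => %{"kernel" => [[0.1, 0.2], [0.3, 0.4]], "bias" => [0.0, 0.0]}
}

# Encode to compressed External Term Format.
binary = :erlang.term_to_binary(params, [:compressed])

# Decode; the :safe option refuses to create new atoms from untrusted input.
decoded = :erlang.binary_to_term(binary, [:safe])

IO.puts("Serialized size: #{byte_size(binary)} bytes")
IO.puts("Round trip matches: #{decoded == params}")
```

When the binary comes from the network or another untrusted source, decoding with `:safe` is a sensible precaution, since plain `binary_to_term/1` can be abused to exhaust the atom table.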

# `summary`

```elixir
@spec summary(t()) :: String.t()
```

Returns a detailed summary of the model architecture.

Shows a Keras/PyTorch-style table with layer names, types,
output shapes, and parameter counts.

# `to_device`

```elixir
@spec to_device(t(), :cpu | :gpu) :: t()
```

Moves all model parameters to the specified device.

## Parameters

  * `model` — A compiled `ExBurn.Model` struct
  * `device` — `:gpu` or `:cpu`

## Returns

  A new `ExBurn.Model` struct with parameters on the target device.

# `unfreeze`

```elixir
@spec unfreeze(t(), [atom() | String.t()]) :: t()
```

Unfreezes the specified layers so their parameters are updated again during training.

## Parameters

  * `model` — A compiled `ExBurn.Model` struct
  * `layer_names` — List of layer name strings or atoms to unfreeze

## Returns

  A new `ExBurn.Model` struct with the specified layers unfrozen.

# `update_params`

```elixir
@spec update_params(t(), map()) :: t()
```

Returns a new model with updated parameters.

Useful for optimizer steps, parameter averaging, or loading
externally-computed parameters.

## Example

    updated_model = ExBurn.Model.update_params(model, new_params)

# `weight_decay`

```elixir
@spec weight_decay(t()) :: float()
```

Returns the model's weight decay coefficient.

