ExBurn.Model (ex_burn v0.3.0)

Model definition and training orchestration for ExBurn.

This module provides a high-level API for defining, compiling, and training neural network models using Axon with the ExBurn GPU backend.

Usage

# Define a model
model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(256)
  |> Axon.relu()
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(128)
  |> Axon.relu()
  |> Axon.dense(10)

# Compile with ExBurn backend
compiled = ExBurn.Model.compile(model, loss: :cross_entropy, optimizer: :adam)

# Forward pass on GPU
{:ok, output} = ExBurn.Model.forward(compiled, input_tensor)

# Train
ExBurn.Model.fit(compiled, train_data, epochs: 10, batch_size: 32)

GPU Compilation

The forward/2 function uses Nx.Defn.jit with ExBurn.Defn.Compiler to execute the model on the GPU via Burn. Parameters are bound to the expression graph, which is then compiled through the defn compiler so that GPU kernels can be fused.

Summary

Functions

benchmark(model, input, opts \\ [])

Benchmarks the model's forward pass on the given input.

clone(model)

Creates a deep copy of the model with identical parameters and configuration.

compile(axon_model, opts \\ [])

Compiles an Axon model for training with the ExBurn backend.

compute_loss(model, pred, target)

Computes the loss between predictions and targets.

deserialize_params(binary)

Deserializes parameters from a binary.

export(model, path, opts \\ [])

Exports the model parameters to a portable format.

forward(model, input)

Runs a forward pass through the model using the ExBurn GPU defn compiler.

forward_pattern(model)

Returns the output shape information from the Axon model for inspection.

freeze(model, layer_names)

Freezes the specified layers so their parameters are not updated during training.

frozen?(model, layer_name)

Checks whether a layer is frozen.

frozen_layers(model)

Returns the set of frozen layer names.

import_params(model, path, opts \\ [])

Imports model parameters from a file saved with export/2.

info(model)

Returns a map with model information.

load(model, path)

Loads model parameters from a file.

loss_function(model)

Returns the model's loss function.

optimizer(model)

Returns the model's optimizer.

parameters(model)

Returns the current model parameters.

predict(model, input)

Runs a forward pass through the model using Axon's default backend.

quantize(model, dtype)

Quantizes model parameters to a lower precision type.

save(model, path)

Saves the model parameters to a file using compressed Erlang term format.

serialize_params(model)

Serializes parameters to a binary for network transfer or storage. Uses compressed Erlang term format.

summary(model)

Returns a detailed summary of the model architecture.

to_device(model, device)

Moves all model parameters to the specified device.

unfreeze(model, layer_names)

Unfreezes the specified layers so their parameters are updated during training.

update_params(model, new_params)

Returns a new model with updated parameters.

weight_decay(model)

Returns the model's weight decay coefficient.

Types

t()

@type t() :: %ExBurn.Model{
  axon_model: Axon.ModelState.t(),
  compiled: boolean(),
  device: :cpu | :gpu,
  frozen_layers: MapSet.t(),
  loss_fn: atom(),
  optimizer: atom(),
  optimizer_state: map(),
  params: map(),
  weight_decay: float()
}

Functions

benchmark(model, input, opts \\ [])

@spec benchmark(t(), Nx.Tensor.t(), keyword()) :: map()

Benchmarks the model's forward pass on the given input.

Runs the forward pass multiple times and returns timing statistics.

Parameters

  • model — A compiled ExBurn.Model struct
  • input — An Nx.Tensor input batch
  • opts — Options
    • :warmup — Number of warmup runs (default: 3)
    • :runs — Number of benchmarked runs (default: 10)

Returns

A map with the keys :avg_ms, :min_ms, :max_ms, :median_ms, and :std_ms, all in milliseconds.

Example

result = ExBurn.Model.benchmark(model, input, warmup: 5, runs: 20)
IO.puts("Average: #{result.avg_ms}ms")
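
The statistics described above can be reproduced for any zero-arity function with a few lines of plain Elixir. This is a minimal sketch, not ExBurn's implementation: the BenchSketch module and its functions are hypothetical, and the choice of population standard deviation (dividing by the number of runs) is an assumption, since the documentation does not say which variant :std_ms uses.

```elixir
defmodule BenchSketch do
  # Hypothetical helper, not part of ExBurn: times `fun` and summarises the runs.
  def run(fun, opts \\ []) when is_function(fun, 0) do
    warmup = Keyword.get(opts, :warmup, 3)
    runs = Keyword.get(opts, :runs, 10)

    # Warmup runs are executed but not timed.
    for _ <- 1..warmup//1, do: fun.()

    # :timer.tc/1 returns {microseconds, result}; convert to milliseconds.
    times_ms =
      for _ <- 1..runs//1 do
        {micros, _result} = :timer.tc(fun)
        micros / 1000
      end

    stats(times_ms)
  end

  # Summarises a non-empty list of timings given in milliseconds.
  def stats([_ | _] = times_ms) do
    n = length(times_ms)
    avg = Enum.sum(times_ms) / n
    sorted = Enum.sort(times_ms)

    # Assumption: population variance (divide by n, not n - 1).
    variance = Enum.sum(Enum.map(times_ms, fn t -> (t - avg) * (t - avg) end)) / n

    %{
      avg_ms: avg,
      min_ms: List.first(sorted),
      max_ms: List.last(sorted),
      median_ms: median(sorted, n),
      std_ms: :math.sqrt(variance)
    }
  end

  # Odd count: middle element. Even count: mean of the two middle elements.
  defp median(sorted, n) when rem(n, 2) == 1, do: Enum.at(sorted, div(n, 2))

  defp median(sorted, n),
    do: (Enum.at(sorted, div(n, 2) - 1) + Enum.at(sorted, div(n, 2))) / 2
end
```

Calling BenchSketch.run(fn -> Enum.sum(1..1000) end, warmup: 5, runs: 20) returns a map with the same five keys, which makes it easy to compare a CPU-only baseline against the numbers reported by benchmark/3.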

clone(model)

@spec clone(t()) :: t()

Creates a deep copy of the model with identical parameters and configuration.

Useful for creating model snapshots during training or for ensemble methods.

Example

snapshot = ExBurn.Model.clone(model)

compile(axon_model, opts \\ [])

@spec compile(
  Axon.ModelState.t() | Axon.t(),
  keyword()
) :: t()

Compiles an Axon model for training with the ExBurn backend.

Options

  • :loss — Loss function: :cross_entropy, :mse, :binary_cross_entropy (default: :cross_entropy)
  • :optimizer — Optimizer: :adam, :sgd, :rmsprop (default: :adam)
  • :learning_rate — Learning rate (default: 0.001)
  • :device — Device: :cpu or :gpu (default: :gpu)
  • :weight_decay — L2 regularization coefficient (default: 0.0)

Returns

An ExBurn.Model struct ready for training.

compute_loss(model, pred, target)

@spec compute_loss(t(), Nx.Tensor.t(), Nx.Tensor.t()) ::
  {:ok, Nx.Tensor.t()} | {:error, String.t()}

Computes the loss between predictions and targets.

Supports :cross_entropy (computed via log-softmax for numerical stability), :mse, and :binary_cross_entropy.

When :weight_decay is set on the model, L2 regularization is added.
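
To illustrate why the log-softmax formulation matters, the sketch below computes cross-entropy for a single example on plain Elixir lists. It is an illustration only: LossSketch and its functions are hypothetical, ExBurn itself operates on Nx tensors, and the exact scaling of the L2 term (here weight_decay times the sum of squared weights) is an assumption.

```elixir
defmodule LossSketch do
  # Hypothetical illustration on plain lists; not ExBurn's implementation.

  # log_softmax(x)_i = x_i - max(x) - log(sum_j exp(x_j - max(x)))
  # Subtracting the maximum keeps every exponent <= 0, so exp/1 cannot overflow.
  def log_softmax(logits) do
    max = Enum.max(logits)
    log_sum = :math.log(Enum.sum(Enum.map(logits, fn x -> :math.exp(x - max) end)))
    Enum.map(logits, fn x -> x - max - log_sum end)
  end

  # Cross-entropy for one example; `target` is a one-hot (or soft) distribution.
  def cross_entropy(logits, target) do
    logits
    |> log_softmax()
    |> Enum.zip(target)
    |> Enum.map(fn {log_p, t} -> -t * log_p end)
    |> Enum.sum()
  end

  # Assumed form of the L2 penalty: weight_decay * sum of squared weights.
  def l2_penalty(weights, weight_decay) do
    weight_decay * Enum.sum(Enum.map(weights, fn w -> w * w end))
  end
end
```

A naive softmax would call :math.exp(1000.0) for a logit of 1000.0, which raises an ArithmeticError in Erlang; LossSketch.cross_entropy([1000.0, 0.0], [1.0, 0.0]) evaluates without error because the largest exponent is always zero.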

deserialize_params(binary)

@spec deserialize_params(binary()) :: {:ok, map()} | {:error, String.t()}

Deserializes parameters from a binary.

export(model, path, opts \\ [])

@spec export(t(), Path.t(), keyword()) :: :ok | {:error, String.t()}

Exports the model parameters to a portable format.

The :format option currently supports:

  • :elixir_terms — Compressed Erlang term format (default, portable)
  • :json — JSON format (human-readable, larger)

Example

ExBurn.Model.export(model, "/tmp/model.json", format: :json)

forward(model, input)

@spec forward(t(), Nx.Tensor.t()) :: {:ok, Nx.Tensor.t()} | {:error, String.t()}

Runs a forward pass through the model using the ExBurn GPU defn compiler.

This function binds the model parameters and input to the Axon expression graph, then executes it via Nx.Defn.jit with ExBurn.Defn.Compiler, which compiles the computation to run on the GPU through Burn.

Parameters

  • model — A compiled ExBurn.Model struct
  • input — An Nx.Tensor input batch

Returns

{:ok, output_tensor} or {:error, reason}

forward_pattern(model)

@spec forward_pattern(t()) :: %{output_shape: tuple() | nil, output_type: atom()}

Returns the output shape information from the Axon model for inspection.

This is useful for debugging and understanding the model architecture without running actual data through it.

Returns

A map with :output_shape and :output_type keys.

freeze(model, layer_names)

@spec freeze(t(), [atom() | String.t()]) :: t()

Freezes the specified layers so their parameters are not updated during training.

Parameters

  • model — A compiled ExBurn.Model struct
  • layer_names — List of layer name strings or atoms to freeze

Returns

A new ExBurn.Model struct with the specified layers frozen.

Example

frozen_model = ExBurn.Model.freeze(model, ["dense_0", "dense_1"])

frozen?(model, layer_name)

@spec frozen?(t(), atom() | String.t()) :: boolean()

Checks whether a layer is frozen.

frozen_layers(model)

@spec frozen_layers(t()) :: MapSet.t()

Returns the set of frozen layer names.
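
The freeze, unfreeze, and lookup behaviour described in this group of functions maps naturally onto a MapSet. The sketch below is a rough model of that idea, not ExBurn's code: FreezeSketch is hypothetical, normalising atom names to strings is an assumption, and apply_updates/3 only illustrates how a training step might skip frozen layers.

```elixir
defmodule FreezeSketch do
  # Hypothetical stand-in for the :frozen_layers field of the model struct.

  # Assumption: atoms and strings naming the same layer are treated as equal.
  defp normalize(name) when is_atom(name), do: Atom.to_string(name)
  defp normalize(name) when is_binary(name), do: name

  def freeze(frozen, names),
    do: Enum.reduce(names, frozen, fn name, acc -> MapSet.put(acc, normalize(name)) end)

  def unfreeze(frozen, names),
    do: Enum.reduce(names, frozen, fn name, acc -> MapSet.delete(acc, normalize(name)) end)

  def frozen?(frozen, name), do: MapSet.member?(frozen, normalize(name))

  # Only layers outside the frozen set receive their updated value.
  def apply_updates(params, updates, frozen) do
    Map.new(params, fn {layer, value} ->
      if MapSet.member?(frozen, layer) do
        {layer, value}
      else
        {layer, Map.get(updates, layer, value)}
      end
    end)
  end
end
```

Because every function returns a new set or map, the original value stays untouched, which matches the documented behaviour of freeze/2 and unfreeze/2 returning a new struct.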

import_params(model, path, opts \\ [])

@spec import_params(t(), Path.t(), keyword()) :: {:ok, t()} | {:error, String.t()}

Imports model parameters from a file saved with export/2.

Example

{:ok, model} = ExBurn.Model.import_params(model, "/tmp/model.etf")

info(model)

@spec info(t()) :: map()

Returns a map with model information.

Includes parameter count, layer count, loss function, optimizer, device, and memory estimate.

Example

info = ExBurn.Model.info(model)
IO.puts("Parameters: #{info.total_params}")

load(model, path)

@spec load(t(), Path.t()) :: {:ok, t()} | {:error, String.t()}

Loads model parameters from a file.

loss_function(model)

@spec loss_function(t()) :: atom()

Returns the model's loss function.

new(fields \\ [])

optimizer(model)

@spec optimizer(t()) :: atom()

Returns the model's optimizer.

parameters(model)

@spec parameters(t()) :: map()

Returns the current model parameters.

predict(model, input)

@spec predict(t(), Nx.Tensor.t()) :: {:ok, Nx.Tensor.t()} | {:error, String.t()}

Runs a forward pass through the model using Axon's default backend.

This is the CPU/Elixir fallback. For GPU execution, use forward/2.

quantize(model, dtype)

@spec quantize(t(), :f16 | :bf16) :: t()

Quantizes model parameters to a lower precision type.

Useful for reducing model size and speeding up inference on devices with limited compute. Currently supports :f16 (half precision) and :bf16 (brain float 16).

Parameters

  • model — A compiled ExBurn.Model struct
  • dtype — Target dtype: :f16 or :bf16

Returns

A new ExBurn.Model struct with quantized parameters.

Example

quantized_model = ExBurn.Model.quantize(model, :f16)
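
The precision cost of each target type can be seen without a GPU by round-tripping a single float through the corresponding bit layout. This is a sketch under stated assumptions: PrecisionSketch is hypothetical, the float-16 binary segment requires Erlang/OTP 24 or later, and the bfloat16 conversion truncates the low bits rather than rounding to nearest, which may differ from what ExBurn or Nx do.

```elixir
defmodule PrecisionSketch do
  # Round-trips a float through IEEE 754 half precision (needs Erlang/OTP 24+).
  # Values outside the f16 range (magnitude above 65504) raise an ArgumentError.
  def to_f16(x) when is_float(x) do
    <<y::float-16>> = <<x::float-16>>
    y
  end

  # bfloat16 is the top 16 bits of a 32-bit float: same exponent range as f32,
  # but only 7 explicit mantissa bits. This sketch truncates instead of rounding.
  def to_bf16(x) when is_float(x) do
    <<hi::16, _lo::16>> = <<x::float-32>>
    <<y::float-32>> = <<hi::16, 0::16>>
    y
  end
end
```

Values such as 0.5 survive both conversions exactly, while 0.1 comes back slightly changed in each case. The trade-off is that :f16 keeps more mantissa bits and :bf16 keeps the wider exponent range of 32-bit floats.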

save(model, path)

@spec save(t(), Path.t()) :: :ok | {:error, String.t()}

Saves the model parameters to a file using compressed Erlang term format.

serialize_params(model)

@spec serialize_params(t()) :: binary()

Serializes parameters to a binary for network transfer or storage. Uses compressed Erlang term format.
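
"Compressed Erlang term format" most likely refers to the external term format produced by :erlang.term_to_binary/2 with the :compressed option. The round trip below is a sketch under that assumption: ParamsCodecSketch is hypothetical, the parameter map uses plain lists instead of Nx tensors, and the use of the :safe option when decoding is a suggestion for untrusted input, not a documented ExBurn behaviour.

```elixir
defmodule ParamsCodecSketch do
  # Hypothetical codec mirroring the shapes of serialize_params/1 and
  # deserialize_params/1; not ExBurn's implementation.

  # :compressed applies zlib compression to the external term format.
  def serialize(params) when is_map(params) do
    :erlang.term_to_binary(params, [:compressed])
  end

  # :safe refuses to create new atoms, which matters for binaries received
  # over the network. Invalid input raises ArgumentError, mapped to an error tuple.
  def deserialize(binary) when is_binary(binary) do
    case :erlang.binary_to_term(binary, [:safe]) do
      params when is_map(params) -> {:ok, params}
      other -> {:error, "expected a map, got: #{inspect(other)}"}
    end
  rescue
    ArgumentError -> {:error, "invalid or unsafe binary"}
  end
end

# Hypothetical parameter map used only for the round trip.
params = %{"dense_0" => %{"kernel" => [[0.1, 0.2], [0.3, 0.4]], "bias" => [0.0, 0.0]}}
{:ok, decoded} = ParamsCodecSketch.deserialize(ParamsCodecSketch.serialize(params))
IO.inspect(decoded == params)
```

Returning {:ok, map} or {:error, reason} keeps the sketch in line with the spec of deserialize_params/1, so callers can pattern match on the result instead of rescuing exceptions.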

summary(model)

@spec summary(t()) :: String.t()

Returns a detailed summary of the model architecture.

Shows a Keras/PyTorch-style table with layer names, types, output shapes, and parameter counts.

to_device(model, device)

@spec to_device(t(), :cpu | :gpu) :: t()

Moves all model parameters to the specified device.

Parameters

  • model — A compiled ExBurn.Model struct
  • device — Target device: :gpu or :cpu

Returns

A new ExBurn.Model struct with parameters on the target device.

unfreeze(model, layer_names)

@spec unfreeze(t(), [atom() | String.t()]) :: t()

Unfreezes the specified layers so their parameters are updated during training.

Parameters

  • model — A compiled ExBurn.Model struct
  • layer_names — List of layer name strings or atoms to unfreeze

Returns

A new ExBurn.Model struct with the specified layers unfrozen.

update_params(model, new_params)

@spec update_params(t(), map()) :: t()

Returns a new model with updated parameters.

Useful for optimizer steps, parameter averaging, or loading externally-computed parameters.

Example

updated_model = ExBurn.Model.update_params(model, new_params)
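
For the parameter-averaging use case, the new_params argument has to be built first. The sketch below averages two parameter maps of identical structure. It is illustrative only: AverageSketch is hypothetical, and it works on nested maps of lists and numbers, whereas real ExBurn parameters are presumably Nx tensors that would be averaged with tensor operations.

```elixir
defmodule AverageSketch do
  # Hypothetical helper: element-wise mean of two identically shaped
  # parameter structures (nested maps, lists, and numbers).

  # Maps: average value by value; Map.fetch!/2 raises if a key is missing in `b`.
  def average(a, b) when is_map(a) and is_map(b) do
    Map.new(a, fn {key, value} -> {key, average(value, Map.fetch!(b, key))} end)
  end

  # Lists: average position by position (Enum.zip_with/3 needs Elixir 1.12+).
  def average(a, b) when is_list(a) and is_list(b) do
    Enum.zip_with(a, b, &average/2)
  end

  # Leaves: plain arithmetic mean.
  def average(a, b) when is_number(a) and is_number(b), do: (a + b) / 2
end
```

The result of AverageSketch.average(params_a, params_b) has the same keys and nesting as its inputs, so under these assumptions it could be passed as new_params to update_params/2.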

weight_decay(model)

@spec weight_decay(t()) :: float()

Returns the model's weight decay coefficient.