Model definition and training orchestration for ExBurn.
This module provides a high-level API for defining, compiling, and training neural network models using Axon with the ExBurn GPU backend.
Usage
    # Define a model
    model =
      Axon.input("input", shape: {nil, 784})
      |> Axon.dense(256)
      |> Axon.relu()
      |> Axon.dropout(rate: 0.2)
      |> Axon.dense(128)
      |> Axon.relu()
      |> Axon.dense(10)

    # Compile with ExBurn backend
    compiled = ExBurn.Model.compile(model, loss: :cross_entropy, optimizer: :adam)

    # Forward pass on GPU
    {:ok, output} = ExBurn.Model.forward(compiled, input_tensor)

    # Train
    ExBurn.Model.fit(compiled, train_data, epochs: 10, batch_size: 32)

GPU Compilation
The forward/2 function uses Nx.Defn.jit with ExBurn.Defn.Compiler
to execute the model on the GPU via Burn. Parameters are bound to the
expression graph, which is then compiled through the defn compiler so
that GPU kernels can be fused.
Functions
@spec benchmark(t(), Nx.Tensor.t(), keyword()) :: map()
Benchmarks the model's forward pass on the given input.
Runs the forward pass multiple times and returns timing statistics.
Parameters
model - A compiled ExBurn.Model struct
input - An Nx.Tensor input batch
opts - Options:
  :warmup - Number of warmup runs (default: 3)
  :runs - Number of benchmarked runs (default: 10)
Returns
A map with :avg_ms, :min_ms, :max_ms, :median_ms, :std_ms.
Example
result = ExBurn.Model.benchmark(model, input, warmup: 5, runs: 20)
IO.puts("Average: #{result.avg_ms}ms")
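The warmup-then-measure procedure and the returned statistics can be sketched in plain Elixir. This is only an illustration under stated assumptions: the module name BenchSketch is hypothetical, and the use of the population standard deviation is a guess, so ExBurn.Model.benchmark/3 may compute its figures differently.

```elixir
defmodule BenchSketch do
  # Hypothetical helper, not part of ExBurn: times a zero-arity function.
  def run(fun, opts \\ []) when is_function(fun, 0) do
    warmup = Keyword.get(opts, :warmup, 3)
    runs = Keyword.get(opts, :runs, 10)

    # Warmup runs are executed but not timed.
    for _ <- 1..warmup//1, do: fun.()

    # :timer.tc/1 returns {microseconds, result}; convert to milliseconds.
    timings =
      for _ <- 1..runs//1 do
        {micros, _result} = :timer.tc(fun)
        micros / 1000
      end

    stats(timings)
  end

  # Summarizes a non-empty list of timings (in milliseconds).
  def stats(timings) when timings != [] do
    n = length(timings)
    avg = Enum.sum(timings) / n
    sorted = Enum.sort(timings)

    # Assumption: population variance (divide by n, not n - 1).
    variance = Enum.sum(Enum.map(timings, fn t -> (t - avg) * (t - avg) end)) / n

    %{
      avg_ms: avg,
      min_ms: hd(sorted),
      max_ms: List.last(sorted),
      median_ms: median(sorted, n),
      std_ms: :math.sqrt(variance)
    }
  end

  defp median(sorted, n) when rem(n, 2) == 1, do: Enum.at(sorted, div(n, 2))

  defp median(sorted, n) do
    (Enum.at(sorted, div(n, 2) - 1) + Enum.at(sorted, div(n, 2))) / 2
  end
end
```

With a compiled model at hand, the timed function could wrap the forward pass, for example `BenchSketch.run(fn -> ExBurn.Model.forward(model, input) end, warmup: 5, runs: 20)`.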
Creates a deep copy of the model with identical parameters and configuration.
Useful for creating model snapshots during training or for ensemble methods.
Example
snapshot = ExBurn.Model.clone(model)
Compiles an Axon model for training with the ExBurn backend.
Options
:loss - Loss function: :cross_entropy, :mse, or :binary_cross_entropy (default: :cross_entropy)
:optimizer - Optimizer: :adam, :sgd, or :rmsprop (default: :adam)
:learning_rate - Learning rate (default: 0.001)
:device - Device: :cpu or :gpu (default: :gpu)
:weight_decay - L2 regularization coefficient (default: 0.0)
Returns
An ExBurn.Model struct ready for training.
@spec compute_loss(t(), Nx.Tensor.t(), Nx.Tensor.t()) :: {:ok, Nx.Tensor.t()} | {:error, String.t()}
Computes the loss between predictions and targets.
Supports :cross_entropy (computed via log-softmax for numerical stability),
:mse, and :binary_cross_entropy.
When :weight_decay is set on the model, an L2 regularization term is added to the loss.
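To make the log-softmax formulation and the L2 term concrete, here is a minimal plain-Elixir sketch for a single example. It works on lists rather than Nx tensors, the module name LossSketch is hypothetical, and the exact form of the penalty (weight decay times the sum of squared weights) is an assumption rather than something taken from ExBurn.

```elixir
defmodule LossSketch do
  # Hypothetical illustration; ExBurn.Model.compute_loss/3 operates on Nx tensors.

  # log_softmax(x)_i = x_i - max(x) - log(sum_j exp(x_j - max(x)))
  # Subtracting max(x) keeps :math.exp/1 from overflowing on large logits.
  def log_softmax(logits) do
    max = Enum.max(logits)
    log_sum = :math.log(Enum.sum(Enum.map(logits, fn x -> :math.exp(x - max) end)))
    Enum.map(logits, fn x -> x - max - log_sum end)
  end

  # Cross-entropy for one example against a one-hot (or soft) target list.
  def cross_entropy(logits, targets) do
    logits
    |> log_softmax()
    |> Enum.zip(targets)
    |> Enum.reduce(0.0, fn {log_p, t}, acc -> acc - t * log_p end)
  end

  # Assumed penalty form: weight_decay * sum of squared weights.
  def l2_penalty(weights, weight_decay) do
    weight_decay * Enum.sum(Enum.map(weights, fn w -> w * w end))
  end
end
```

Calling `:math.exp/1` directly on a logit such as 1000.0 would raise an arithmetic error, which is why the maximum is subtracted first.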
Deserializes parameters from a binary.
Exports the model parameters to a portable format.
Currently supports:
:elixir_terms - Compressed Erlang term format (default, portable)
:json - JSON format (human-readable, larger)
Example
ExBurn.Model.export(model, "/tmp/model.json", format: :json)
@spec forward(t(), Nx.Tensor.t()) :: {:ok, Nx.Tensor.t()} | {:error, String.t()}
Runs a forward pass through the model using the ExBurn GPU defn compiler.
This function binds the model parameters and input to the Axon expression
graph, then executes it via Nx.Defn.jit with ExBurn.Defn.Compiler,
which compiles the computation to run on the GPU through Burn.
Parameters
model - A compiled ExBurn.Model struct
input - An Nx.Tensor input batch
Returns
{:ok, output_tensor} or {:error, reason}
Returns the output shape information from the Axon model for inspection.
This is useful for debugging and understanding the model architecture without running actual data through it.
Returns
A map with :output_shape and :output_type keys.
Freezes the specified layers so their parameters are not updated during training.
Parameters
model - A compiled ExBurn.Model struct
layer_names - List of layer name strings or atoms to freeze
Returns
A new ExBurn.Model struct with the specified layers frozen.
Example
frozen_model = ExBurn.Model.freeze(model, ["dense_0", "dense_1"])
Checks whether a layer is frozen.
Returns the set of frozen layer names.
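The freezing functions can be pictured as bookkeeping over a set of layer names. The sketch below is one possible shape for that bookkeeping, not ExBurn's implementation: the module FreezeSketch, the use of MapSet, and the normalization of atoms to strings are all assumptions made for illustration.

```elixir
defmodule FreezeSketch do
  # Hypothetical illustration of frozen-layer bookkeeping with a MapSet.
  # Layer names may arrive as strings or atoms, so they are normalized to strings.
  def freeze(frozen, layer_names) do
    Enum.reduce(layer_names, frozen, fn name, acc -> MapSet.put(acc, to_string(name)) end)
  end

  def unfreeze(frozen, layer_names) do
    Enum.reduce(layer_names, frozen, fn name, acc -> MapSet.delete(acc, to_string(name)) end)
  end

  def frozen?(frozen, name), do: MapSet.member?(frozen, to_string(name))

  # Applies `update_fun` only to layers that are not frozen;
  # frozen layers keep their current value.
  def apply_updates(params, frozen, update_fun) do
    Map.new(params, fn {name, value} ->
      if frozen?(frozen, name), do: {name, value}, else: {name, update_fun.(value)}
    end)
  end
end
```

Keeping the frozen names in a set makes the membership check cheap, and skipping the update for those names is what leaves their parameters unchanged during a training step.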
Imports model parameters from a file saved with export/2.
Example
{:ok, model} = ExBurn.Model.import_params(model, "/tmp/model.etf")
Returns a map with model information.
Includes parameter count, layer count, loss function, optimizer, device, and memory estimate.
Example
info = ExBurn.Model.info(model)
IO.puts("Parameters: #{info.total_params}")
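As a rough cross-check of the reported parameter count, the number of parameters in a stack of dense layers can be computed by hand. The sketch below does this for the 784 → 256 → 128 → 10 model from the Usage section, assuming each dense layer has a bias (Axon's default) and that parameters are stored as 32-bit floats. ParamCountSketch is a hypothetical helper, and how info/1 actually estimates memory is not specified here.

```elixir
defmodule ParamCountSketch do
  # Hypothetical helper: counts parameters of a stack of dense layers
  # given the input width and the list of layer output widths.
  def dense_params(input_size, layer_sizes) do
    {total, _last_width} =
      Enum.reduce(layer_sizes, {0, input_size}, fn out, {acc, inp} ->
        # Each dense layer has an inp x out kernel plus an out-element bias.
        {acc + inp * out + out, out}
      end)

    total
  end

  # Assumes every parameter uses the same byte width (4 bytes for f32).
  def memory_bytes(param_count, bytes_per_param \\ 4), do: param_count * bytes_per_param
end

total = ParamCountSketch.dense_params(784, [256, 128, 10])
IO.puts("Parameters: #{total}")
IO.puts("Approx. f32 memory: #{ParamCountSketch.memory_bytes(total)} bytes")
```

For this model the count is 200,960 + 32,896 + 1,290 = 235,146 parameters, or about 940 KB at four bytes each. The ReLU and dropout layers contribute no parameters.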
Loads model parameters from a file.
Returns the model's loss function.
Returns the model's optimizer.
Returns the current model parameters.
@spec predict(t(), Nx.Tensor.t()) :: {:ok, Nx.Tensor.t()} | {:error, String.t()}
Runs a forward pass through the model using Axon's default backend.
This is the CPU/Elixir fallback. For GPU execution, use forward/2.
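Because forward/2 and predict/2 share the same return shape, a caller can try the GPU path first and fall back to the CPU path on error. The helper below is a hypothetical sketch of that pattern. It takes zero-arity functions so that it stays independent of ExBurn itself.

```elixir
defmodule FallbackSketch do
  # Hypothetical helper: runs `primary` and, if it returns {:error, reason},
  # runs `fallback` instead. Both are zero-arity functions returning
  # {:ok, result} | {:error, reason}, matching the specs of forward/2 and predict/2.
  def run(primary, fallback) when is_function(primary, 0) and is_function(fallback, 0) do
    case primary.() do
      {:ok, _result} = ok -> ok
      {:error, _reason} -> fallback.()
    end
  end
end
```

A call might look like `FallbackSketch.run(fn -> ExBurn.Model.forward(model, input) end, fn -> ExBurn.Model.predict(model, input) end)`.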
Quantizes model parameters to a lower precision type.
Useful for reducing model size and speeding up inference on devices
with limited compute. Currently supports :f16 (half precision) and
:bf16 (brain float 16).
Parameters
model - A compiled ExBurn.Model struct
dtype - Target dtype: :f16 or :bf16
Returns
A new ExBurn.Model struct with quantized parameters.
Example
quantized_model = ExBurn.Model.quantize(model, :f16)
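The effect of converting to :f16 can be seen without any tensor library by round-tripping a float through a 16-bit binary. This is a minimal sketch, assuming OTP 24 or later (which added 16-bit floats to the bit syntax). HalfSketch is a hypothetical module, it covers only :f16 and not :bf16, and it does not handle values outside the finite half-precision range.

```elixir
defmodule HalfSketch do
  # Hypothetical illustration: rounds a float to IEEE 754 half precision
  # by encoding it as a 16-bit float and decoding it again.
  def to_f16(x) when is_float(x) do
    <<half::float-size(16)>> = <<x::float-size(16)>>
    half
  end

  # Applies the rounding to a flat list of parameter values.
  def quantize(params) when is_list(params), do: Enum.map(params, &to_f16/1)
end
```

Values such as 0.5 survive the round trip unchanged, while a value like 0.1 comes back slightly altered. Each stored value takes two bytes instead of the four used by a 32-bit float.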
Saves the model parameters to a file using compressed Erlang term format.
Serializes parameters to a binary for network transfer or storage. Uses compressed Erlang term format.
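The compressed Erlang term format mentioned above is available through the standard library. The sketch below shows a round trip to a binary and to a file. SerializeSketch is a hypothetical module, and a plain map of lists stands in for real model parameters, so the actual ExBurn functions may add steps of their own.

```elixir
defmodule SerializeSketch do
  # Hypothetical illustration of the compressed Erlang term format.
  def serialize(params), do: :erlang.term_to_binary(params, [:compressed])

  # :safe makes decoding fail rather than create new atoms,
  # which matters when the binary comes from an untrusted source.
  def deserialize(binary) when is_binary(binary) do
    :erlang.binary_to_term(binary, [:safe])
  end

  # Writes the serialized parameters to `path`; returns :ok or {:error, reason}.
  def save(params, path), do: File.write(path, serialize(params))

  # Reads and decodes the parameters; returns {:ok, params} or {:error, reason}.
  def load(path) do
    with {:ok, binary} <- File.read(path), do: {:ok, deserialize(binary)}
  end
end
```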
Returns a detailed summary of the model architecture.
Shows a Keras/PyTorch-style table with layer names, types, output shapes, and parameter counts.
Moves all model parameters to the specified device.
Parameters
model - A compiled ExBurn.Model struct
device - :gpu or :cpu
Returns
A new ExBurn.Model struct with parameters on the target device.
Unfreezes the specified layers so their parameters are updated during training.
Parameters
model - A compiled ExBurn.Model struct
layer_names - List of layer name strings or atoms to unfreeze
Returns
A new ExBurn.Model struct with the specified layers unfrozen.
Returns a new model with updated parameters.
Useful for optimizer steps, parameter averaging, or loading externally-computed parameters.
Example
updated_model = ExBurn.Model.update_params(model, new_params)
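Parameter averaging, one of the uses mentioned above, amounts to combining two parameter sets element by element. The sketch below is a simplified illustration: AverageSketch is hypothetical, and plain maps of lists stand in for the real parameter containers, which would hold Nx tensors.

```elixir
defmodule AverageSketch do
  # Hypothetical illustration of parameter averaging over two parameter maps
  # that share the same layer names and value lengths.
  def average(params_a, params_b) do
    # Map.merge/3 calls the function only for keys present in both maps.
    Map.merge(params_a, params_b, fn _layer, a, b ->
      Enum.zip_with(a, b, fn x, y -> (x + y) / 2 end)
    end)
  end
end
```

The averaged result could then be handed to update_params/2, provided it has the same structure as the model's existing parameters.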
Returns the model's weight decay coefficient.