ExBurn (Burn) Integration Guide

Dala integrates ExBurn, a bridge between Nx and the Burn deep learning framework (Rust). This enables GPU-accelerated ML/DL training and inference on iOS, Android, and desktop.

Status

v0.3.0: full Nx backend, defn compiler, training loop, serving, and model management. Training currently uses numerical gradients (central and batch modes); integration with Burn's autodiff is planned for a future release.

Architecture

Axon model
   ↓
Nx.Defn graph
   ↓
ExBurn.Defn.Compiler (Nx.Defn.Compiler behaviour)
   ↓
ExBurn.Backend (Nx.Backend behaviour)
   ↓
ExBurn.Nif (Rustler NIF)  |  ExCubecl (GPU buffers, kernels, pipelines)
   ↓
Burn Autodiff<CubeCL> (Rust)
   ↓
CubeCL kernels
   ↓
Metal (iOS) / Vulkan (Android) / CUDA → GPU

GPU Backends

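| Platform | Backend | Status       |
|----------|---------|--------------|
| iOS      | Metal   | Supported    |
| Android  | Vulkan  | Supported    |
| macOS    | Metal   | Supported    |
| Linux    | Vulkan  | Supported    |
| NVIDIA   | CUDA    | 🔜 (planned) |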

Quick Start

1. Check Availability

# Is ExBurn loaded?
Dala.ML.Burn.available?()
# true

# Is a GPU available?
Dala.ML.Burn.gpu?()
# true on iOS/Android with GPU support

# What device will be used?
Dala.ML.Burn.default_device()
# :gpu or :cpu

2. Configure

# Set ExBurn as the default Nx backend
Dala.ML.Burn.configure!()

# Or with options
Dala.ML.Burn.configure!(device: :gpu)

Dala.ML.setup/0 auto-configures Burn when available — no manual setup needed in most cases.

3. Tensors via Burn

# All Nx operations now run through Burn
t = Nx.tensor([1.0, 2.0, 3.0])
Nx.add(t, t) |> Nx.to_list()
# [2.0, 4.0, 6.0]

# Direct Burn tensor creation (bypasses Nx for performance)
bt = Dala.ML.Burn.zeros([3, 3], :f32)
bt = Dala.ML.Burn.ones([2, 4], :f32)

# Convert between Nx and Burn
{:ok, bt} = Dala.ML.Burn.from_nx(tensor)
{:ok, tensor} = Dala.ML.Burn.to_nx(bt)

4. Define and Compile a Model

model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(256, activation: :relu)
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10)

compiled = Dala.ML.Burn.compile(model,
  loss: :cross_entropy,
  optimizer: :adam,
  learning_rate: 0.001
)

5. Train

trained = Dala.ML.Burn.fit(compiled, {train_x, train_y},
  epochs: 10,
  batch_size: 32,
  validation_data: {val_x, val_y}
)

6. Inference

{:ok, predictions} = Dala.ML.Burn.predict(trained, input_tensor)

7. Save / Load

:ok = Dala.ML.Burn.save(trained, "my_model.model")
{:ok, loaded} = Dala.ML.Burn.load(trained, "my_model.model")

Training

Basic Training

model = Dala.ML.Burn.compile(axon_model,
  loss: :cross_entropy,
  optimizer: :adam,
  learning_rate: 0.001
)

trained = Dala.ML.Burn.fit(model, {inputs, targets},
  epochs: 10,
  batch_size: 32
)

Training with Validation

trained = Dala.ML.Burn.fit(model, {train_x, train_y},
  epochs: 50,
  batch_size: 64,
  validation_data: {val_x, val_y},
  verbose: true
)

Training with Callbacks

callbacks = [
  # Log metrics after each epoch
  Dala.ML.Burn.Training.logging_callback(),

  # Stop if val loss doesn't improve for 5 epochs
  Dala.ML.Burn.Training.early_stopping_callback(5, 1.0e-4),

  # Save checkpoint every 10 epochs
  Dala.ML.Burn.Training.checkpoint_callback(10, "checkpoints/"),

  # Report progress to a LiveView screen via handle_info
  Dala.ML.Burn.Training.screen_callback(self())
]

trained = Dala.ML.Burn.fit(model, {train_x, train_y},
  epochs: 100,
  batch_size: 32,
  validation_data: {val_x, val_y},
  callbacks: callbacks
)

Handle screen progress updates in your LiveView or GenServer:

def handle_info({:training_progress, epoch, loss, val_loss}, socket) do
  {:noreply, assign(socket,
    epoch: epoch,
    loss: loss,
    val_loss: val_loss
  )}
end
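The early-stopping rule used above can be sketched in plain Elixir. This is a minimal illustration, not ExBurn's implementation: it assumes the two arguments of early_stopping_callback/2 are a patience (epochs without improvement) and a minimum improvement threshold, and the module name EarlyStopSketch is hypothetical.

```elixir
defmodule EarlyStopSketch do
  # Walks a list of per-epoch validation losses and reports the 1-based
  # epoch at which training would stop, or :completed if it never stops.
  # ASSUMPTION: "improvement" means the loss dropped by more than min_delta.
  def run(val_losses, patience, min_delta) do
    val_losses
    |> Enum.with_index(1)
    |> Enum.reduce_while({nil, 0}, fn {loss, epoch}, {best, waited} ->
      cond do
        # First epoch, or a real improvement: record it and reset the counter.
        best == nil or best - loss > min_delta -> {:cont, {loss, 0}}
        # No improvement and patience is exhausted: stop here.
        waited + 1 >= patience -> {:halt, {:stopped, epoch}}
        # No improvement yet, but still within patience.
        true -> {:cont, {best, waited + 1}}
      end
    end)
    |> case do
      {:stopped, epoch} -> {:stopped, epoch}
      _ -> :completed
    end
  end
end

# Loss plateaus after epoch 2, so a patience of 2 stops at epoch 4.
EarlyStopSketch.run([1.0, 0.9, 0.9, 0.9], 2, 1.0e-4)
# => {:stopped, 4}
```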

Training with History Tracking

{trained, history} = Dala.ML.Burn.Training.fit_with_progress(
  model, {train_x, train_y},
  epochs: 50,
  batch_size: 32,
  validation_data: {val_x, val_y}
)

# history => [%{epoch: 1, loss: 0.5, val_loss: 0.4}, ...]
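Because the history is an ordinary list of maps, it can be inspected with Enum. The sketch below uses made-up sample values in the shape shown above (only the first entry comes from this guide) to pick the epoch with the lowest validation loss.

```elixir
# Sample data in the documented shape; epochs 2 and 3 are invented for illustration.
history = [
  %{epoch: 1, loss: 0.50, val_loss: 0.40},
  %{epoch: 2, loss: 0.35, val_loss: 0.31},
  %{epoch: 3, loss: 0.28, val_loss: 0.33}
]

# Epoch with the lowest validation loss.
best = Enum.min_by(history, & &1.val_loss)
IO.puts("best epoch: #{best.epoch}, val_loss: #{best.val_loss}")

# Validation loss minus training loss per epoch; a widening gap can hint at overfitting.
gaps = Enum.map(history, &{&1.epoch, Float.round(&1.val_loss - &1.loss, 3)})
IO.inspect(gaps, label: "val - train")
```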

Learning Rate Schedules

# Step decay: halve LR every 10 epochs
Dala.ML.Burn.fit(model, data,
  lr_schedule: {:step, 0.001, 10, 0.5}
)

# Exponential decay
Dala.ML.Burn.fit(model, data,
  lr_schedule: {:exponential, 0.001, 0.95}
)

# Cosine annealing
Dala.ML.Burn.fit(model, data,
  lr_schedule: {:cosine, 0.001, 1.0e-5}
)
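To make the three schedules concrete, here is a rough sketch of the usual formulas behind them. The interpretation of each tuple is an assumption read off the comments above ({:step, initial, step_size, gamma}, {:exponential, initial, gamma}, {:cosine, initial, min_lr}), epochs are taken as 0-based, and LrScheduleSketch is a hypothetical module, not part of ExBurn.

```elixir
defmodule LrScheduleSketch do
  # Step decay: multiply by gamma once every step_size epochs.
  def lr({:step, initial, step_size, gamma}, epoch, _total_epochs) do
    initial * :math.pow(gamma, div(epoch, step_size))
  end

  # Exponential decay: multiply by gamma every epoch.
  def lr({:exponential, initial, gamma}, epoch, _total_epochs) do
    initial * :math.pow(gamma, epoch)
  end

  # Cosine annealing: follow half a cosine wave from initial down to min_lr.
  def lr({:cosine, initial, min_lr}, epoch, total_epochs) do
    min_lr + 0.5 * (initial - min_lr) * (1.0 + :math.cos(:math.pi() * epoch / total_epochs))
  end
end

# After 10 epochs the step schedule has halved the rate once.
LrScheduleSketch.lr({:step, 0.001, 10, 0.5}, 10, 50)
# => 0.0005
```

Note that cosine annealing needs the total number of epochs; how ExBurn obtains it (presumably from the epochs option) is not stated in this guide.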

Gradient Clipping

Dala.ML.Burn.fit(model, data,
  clip_norm: 1.0,    # Clip by max norm
  clip_value: 0.5    # Clip by max absolute value
)
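The two clipping modes differ in what they preserve: clipping by norm rescales the whole gradient and keeps its direction, while clipping by value clamps each component on its own. The sketch below shows both on a flat list of floats; ClipSketch is a hypothetical helper, and whether ExBurn applies the two options in this form or in a particular order is not specified here.

```elixir
defmodule ClipSketch do
  # Rescale the gradient so its L2 norm is at most max_norm (direction unchanged).
  def clip_by_norm(grads, max_norm) do
    norm = :math.sqrt(Enum.reduce(grads, 0.0, fn g, acc -> acc + g * g end))

    if norm > max_norm do
      scale = max_norm / norm
      Enum.map(grads, &(&1 * scale))
    else
      grads
    end
  end

  # Clamp every component independently to [-max_value, max_value].
  def clip_by_value(grads, max_value) do
    Enum.map(grads, fn g -> g |> max(-max_value) |> min(max_value) end)
  end
end

# [3.0, 4.0] has norm 5.0, so it is scaled by 1/5 to roughly [0.6, 0.8].
ClipSketch.clip_by_norm([3.0, 4.0], 1.0)

ClipSketch.clip_by_value([3.0, -0.2, -4.0], 0.5)
# => [0.5, -0.2, -0.5]
```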

Loss Functions

Supported loss functions:

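| Loss                  | Description                                            |
|-----------------------|--------------------------------------------------------|
| :cross_entropy        | Categorical cross-entropy (with log-softmax stability) |
| :mse                  | Mean squared error                                     |
| :binary_cross_entropy | Binary cross-entropy (with numerical clamping)         |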
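The stability notes in the table refer to two standard tricks, sketched below for a single sample in plain Elixir. This is an illustration under assumptions, not ExBurn's code: LossSketch is a hypothetical module and the clamping epsilon of 1.0e-7 is chosen only for the example.

```elixir
defmodule LossSketch do
  # Log-softmax via log-sum-exp: subtracting the largest logit first
  # keeps :math.exp/1 from overflowing on large inputs.
  def log_softmax(logits) do
    m = Enum.max(logits)
    lse = m + :math.log(Enum.reduce(logits, 0.0, fn x, acc -> acc + :math.exp(x - m) end))
    Enum.map(logits, &(&1 - lse))
  end

  # Categorical cross-entropy for one sample; target is a one-hot (or soft) list.
  def cross_entropy(logits, target) do
    logits
    |> log_softmax()
    |> Enum.zip(target)
    |> Enum.reduce(0.0, fn {log_p, t}, acc -> acc - t * log_p end)
  end

  # Binary cross-entropy with the prediction clamped away from 0 and 1,
  # so neither logarithm is ever taken of zero.
  def binary_cross_entropy(p, y, eps \\ 1.0e-7) do
    p = p |> max(eps) |> min(1.0 - eps)
    -(y * :math.log(p) + (1.0 - y) * :math.log(1.0 - p))
  end

  # Mean squared error over two equal-length lists.
  def mse(pred, target) do
    sum = pred |> Enum.zip(target) |> Enum.reduce(0.0, fn {p, t}, acc -> acc + (p - t) * (p - t) end)
    sum / length(pred)
  end
end

# A logit of 1000.0 would overflow a naive exp/1; the shifted version stays finite.
LossSketch.cross_entropy([1000.0, 0.0], [1.0, 0.0])

# A prediction of exactly 0.0 for a positive label gives a large but finite loss.
LossSketch.binary_cross_entropy(0.0, 1.0)
```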

Optimizers

| Optimizer | Options                                     |
|-----------|---------------------------------------------|
| :adam     | beta1: 0.9, beta2: 0.999, epsilon: 1.0e-8   |
| :sgd      | momentum: 0.9                               |
| :rmsprop  | decay: 0.9, epsilon: 1.0e-8                 |
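To show where the :adam options enter, here is a sketch of one textbook Adam update for a single scalar parameter, using the defaults from the table. It is the standard algorithm written out for illustration; AdamSketch is a hypothetical module and nothing here is taken from ExBurn's internals.

```elixir
defmodule AdamSketch do
  # One Adam step for a scalar parameter.
  # state is {m, v, t}: first-moment estimate, second-moment estimate, step count.
  def step(param, grad, {m, v, t}, opts \\ []) do
    lr = Keyword.get(opts, :learning_rate, 0.001)
    b1 = Keyword.get(opts, :beta1, 0.9)
    b2 = Keyword.get(opts, :beta2, 0.999)
    eps = Keyword.get(opts, :epsilon, 1.0e-8)

    t = t + 1
    # Exponential moving averages of the gradient and its square.
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad * grad
    # Bias correction compensates for the zero initialisation of m and v.
    m_hat = m / (1 - :math.pow(b1, t))
    v_hat = v / (1 - :math.pow(b2, t))

    {param - lr * m_hat / (:math.sqrt(v_hat) + eps), {m, v, t}}
  end
end

# On the first step the update size is close to the learning rate,
# regardless of the gradient's magnitude.
{new_param, _state} = AdamSketch.step(1.0, 2.0, {0.0, 0.0, 0})
```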

Evaluation

avg_loss = Dala.ML.Burn.evaluate(model, {test_x, test_y})
# 0.234

Model Summary

IO.puts(Dala.ML.Burn.summary(model))
# ╔══════════════════════════════════════════════════════════╗
# ║                   ExBurn Model Summary                  ║
# ╠══════════════════════════════════════════════════════════╣
# ║  Total params:                                    235146 ║
# ║  Trainable params:                                235146 ║
# ║  Non-trainable:                                        0 ║
# ║  Formatted:                                      235.1K ║
# ╠══════════════════════════════════════════════════════════╣
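The total of 235146 matches the Quick Start model (784 → 256 → 128 → 10), assuming that is the model being summarized and that the dropout layer contributes no parameters. The arithmetic can be checked with a few lines:

```elixir
# Layer widths of the Quick Start model: input 784, then dense 256, 128, 10.
sizes = [784, 256, 128, 10]

# A dense layer with n_in inputs and n_out outputs holds
# n_in * n_out weights plus n_out biases.
params =
  sizes
  |> Enum.chunk_every(2, 1, :discard)
  |> Enum.map(fn [n_in, n_out] -> n_in * n_out + n_out end)

total = Enum.sum(params)

IO.inspect(params, label: "per layer")
# per layer: [200960, 32896, 1290]
IO.inspect(total, label: "total")
# total: 235146
```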

Serving (Production Inference)

For production use, wrap your model in an Nx.Serving for batched, concurrent inference:

# Build a serving
serving = Dala.ML.Burn.Serving.build(trained_model,
  batch_size: 16,
  batch_timeout: 100
)

# Run single inference
output = Dala.ML.Burn.Serving.run(serving, input_tensor)

# Or supervise it in your app tree
children = [
  {Nx.Serving,
   serving: Dala.ML.Burn.Serving.build(trained_model, batch_size: 32),
   name: :my_model_serving}
]

# Or use the convenience helper
{:ok, _pid} = Dala.ML.Burn.Serving.supervise(trained_model,
  name: :my_model_serving,
  supervisor: MyApp.DynamicSupervisor
)

# Then use it from anywhere
output = Nx.Serving.run(:my_model_serving, input_tensor)

Unified API

Dala.ML.predict/2 dispatches to Burn when given an ExBurn.Model:

# CoreML model (string identifier on iOS)
Dala.ML.predict("my_model", %{"input" => [1.0, 2.0]})

# ONNX session (integer session ID)
Dala.ML.predict(session_id, input_binary)

# Axon model ({model, params} tuple)
Dala.ML.predict({axon_model, params}, input_tensor)

# ExBurn model (ExBurn.Model struct)
Dala.ML.predict(exburn_model, input_tensor)
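The dispatch described above maps naturally onto pattern matching on the first argument. The sketch below only illustrates that idea: PredictDispatchSketch and FakeBurnModel are hypothetical stand-ins (the latter for ExBurn.Model), and the actual clauses inside Dala.ML.predict/2 may be organised differently.

```elixir
defmodule PredictDispatchSketch do
  # Stand-in struct so the sketch runs without ExBurn installed.
  defmodule FakeBurnModel do
    defstruct [:name]
  end

  # Each clause picks a backend from the shape of the model argument.
  def backend_for(%FakeBurnModel{}), do: :burn
  def backend_for({_axon_model, _params}), do: :axon
  def backend_for(id) when is_binary(id), do: :coreml
  def backend_for(session_id) when is_integer(session_id), do: :onnx
end

PredictDispatchSketch.backend_for("my_model")
# => :coreml
PredictDispatchSketch.backend_for(%PredictDispatchSketch.FakeBurnModel{name: "mnist"})
# => :burn
```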

Benchmarking

# Benchmark current backend
Dala.ML.benchmark(size: 100, iterations: 10)
# %{
#   time_ms: 1.234,
#   gflops: 0.857,
#   backend: {EMLX.Backend, [device: :gpu]},
#   burn: %{time_ms: 0.567, gflops: 1.234}  # if ExBurn available
# }

Platform Notes

iOS

  • Uses Metal GPU backend via Burn's CubeCL
  • No JIT required (unlike EMLX on devices)
  • Training small models (< 10M params) is feasible
  • Inference is the primary use case

Android

  • Uses Vulkan GPU backend via Burn's CubeCL
  • Same training/inference capabilities as iOS

Desktop (Development)

  • Uses Metal (macOS) or Vulkan (Linux)
  • CUDA support planned

Training on Mobile — Caveats

Burn's Autodiff backend is memory-intensive. On iOS/Android with limited RAM:

  • Fine-tuning small models (< 10M parameters) is feasible on modern devices
  • Full training of large models is not recommended on mobile
  • Inference is the primary use case for mobile deployment
  • Minimum recommended: 4GB RAM, A12+ chip (iOS) / Snapdragon 700+ (Android)
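A back-of-the-envelope estimate helps explain the 10M-parameter guideline. The sketch below counts only f32 parameter-sized buffers (weights, gradients, and two Adam moment buffers); it deliberately ignores activations and framework overhead, which depend on batch size and architecture, and it makes no claim about how ExBurn actually lays out memory. TrainMemSketch is a hypothetical helper.

```elixir
defmodule TrainMemSketch do
  @bytes_per_f32 4

  # copies = number of parameter-sized f32 buffers held at once, e.g.
  #   1 for inference (weights only)
  #   4 for Adam training (weights + gradients + two moment buffers)
  def megabytes(param_count, copies) do
    param_count * @bytes_per_f32 * copies / 1_000_000
  end
end

TrainMemSketch.megabytes(10_000_000, 1)
# => 40.0  (weights only)
TrainMemSketch.megabytes(10_000_000, 4)
# => 160.0 (weights, gradients, and Adam moments)
```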

The training loop in ExBurn currently uses numerical gradients. Burn's autodiff integration is planned for a future release.

Comparison with Other Dala ML Backends

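| Backend     | Best For             | GPU          | Training | iOS | Android |
|-------------|----------------------|--------------|----------|-----|---------|
| EMLX        | iOS inference        | Metal (MLX)  |          |     |         |
| CoreML      | iOS Neural Engine    | ANE          |          |     |         |
| ONNX        | Cross-platform       | NNAPI/CoreML |          |     |         |
| GPU Compute | Custom kernels       | CubeCL       | N/A      |     |         |
| ExBurn      | Training + inference | CubeCL       |          |     |         |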

Error Handling

On failure, operations raise ExBurn.Error with structured context (the operation, a reason, and details):

raise ExBurn.Error,
  op: :matmul,
  reason: "shape mismatch",
  details: %{lhs: [3, 4], rhs: [5, 6]}
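Callers can rescue such an exception and use the structured fields instead of parsing a message string. To keep the sketch runnable without ExBurn, it defines a hypothetical DemoBurnError with the same three fields shown above; with the real library one would rescue ExBurn.Error instead, assuming its fields are exposed under these names.

```elixir
defmodule DemoBurnError do
  # Hypothetical stand-in mirroring the fields shown above (op, reason, details).
  defexception [:op, :reason, :details]

  @impl true
  def message(%{op: op, reason: reason}), do: "#{op} failed: #{reason}"
end

result =
  try do
    raise DemoBurnError,
      op: :matmul,
      reason: "shape mismatch",
      details: %{lhs: [3, 4], rhs: [5, 6]}
  rescue
    # Turn the exception into a tagged tuple the caller can match on.
    e in DemoBurnError -> {:error, e.op, e.details}
  end

IO.inspect(result)
# {:error, :matmul, %{lhs: [3, 4], rhs: [5, 6]}}
```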

Troubleshooting

available?() returns false

  • Ensure ex_burn is in your deps: {:ex_burn, "~> 0.3"}
  • Run mix deps.get && mix compile
  • The Rust NIF will be compiled automatically via Rustler

gpu?() returns false

  • ExBurn checks ExCubecl availability for GPU detection
  • On iOS/Android, ensure the GPU compute libraries are linked
  • On desktop, GPU may not be available — falls back to CPU

Training is slow

  • Current training uses numerical gradients (finite differences)
  • For faster training, use EMLX or cloud training and deploy to device
  • Reduce batch size and model size for mobile
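The cost of finite differences is easy to see in a small sketch: a central-difference gradient evaluates the loss twice per parameter, so one gradient for a model with N parameters needs 2N forward passes. The code below is an illustrative stand-alone version over a flat list of floats; NumGradSketch and the step size 1.0e-4 are assumptions for the example, not ExBurn's implementation.

```elixir
defmodule NumGradSketch do
  # Central-difference gradient of a scalar function f over a flat parameter list.
  # f is called 2 * length(params) times per gradient.
  def gradient(f, params, eps \\ 1.0e-4) do
    params
    |> Enum.with_index()
    |> Enum.map(fn {p, i} ->
      plus = f.(List.replace_at(params, i, p + eps))
      minus = f.(List.replace_at(params, i, p - eps))
      (plus - minus) / (2 * eps)
    end)
  end
end

# Toy loss f(w) = w0^2 + 3 * w1, whose exact gradient is [2 * w0, 3].
loss = fn [w0, w1] -> w0 * w0 + 3.0 * w1 end

# Approximately [4.0, 3.0] at w = [2.0, -1.0].
grad = NumGradSketch.gradient(loss, [2.0, -1.0])
IO.inspect(grad)
```

For the 235146-parameter Quick Start model, this scheme would need roughly 470000 loss evaluations per gradient, which is why reverse-mode autodiff is expected to be much faster once integrated.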

Out of memory during training

See Also