Axon.Loop (Axon v0.6.1)

Abstraction for modeling a reduction of a dataset with an accumulated state for a number of epochs.

Inspired heavily by PyTorch Ignite.

The main abstraction is the %Axon.Loop{} struct, which controls a nested reduction of the form:

Enum.reduce(1..max_epochs, state, fn epoch, state ->
  Enum.reduce(data, state, &batch_step/2)
end)
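Here is the same nested reduction as a runnable plain-Elixir sketch, with no Axon or Nx involved. The module name NestedReduceSketch and its toy batch_step/2 are hypothetical. The step treats each batch as a number and adds it to a running total.

```elixir
defmodule NestedReduceSketch do
  # Toy step function: the "state" is a running total and each batch is a number.
  def batch_step(batch, state), do: state + batch

  # Mirrors the nested reduction above: an outer reduce over epochs and an
  # inner reduce over every batch in `data`.
  def run(data, state, max_epochs) do
    Enum.reduce(1..max_epochs, state, fn _epoch, state ->
      Enum.reduce(data, state, &batch_step/2)
    end)
  end
end

# Two passes over [1, 2, 3] starting from 0: (1 + 2 + 3) * 2
NestedReduceSketch.run([1, 2, 3], 0, 2)
```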

data is assumed to be an Enumerable or Stream of input data which is handled by a processing function, batch_step. The purpose of the loop abstraction is to take away much of the boilerplate code used in solving machine learning tasks. Tasks such as normalizing a dataset, hyperparameter optimization, or training machine learning models boil down to writing one function:

defn batch_step(batch, state) do
  # ...do something with batch...
  updated_state
end

For tasks such as training a neural network, state will encapsulate things such as model and optimizer state. For supervised learning tasks, batch_step might look something like:

defn batch_step({inputs, targets}, state) do
  %{parameters: params, optimizer_state: optim_state} = state

  gradients = grad(params, &objective_fn.(&1, inputs, targets))
  {updates, new_optim_state} = optimizer.(gradients, optim_state, params)

  new_params = apply_updates(params, updates)

  %{parameters: new_params, optimizer_state: new_optim_state}
end

batch_step takes a batch of {input, target} pairs and the current state, and updates the model parameters using the gradients of some arbitrary objective function. This function runs in a nested loop, iterating over the entire dataset for N epochs before returning the trained model state. By defining one function, we've created a training loop that works for most machine learning models.
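The following sketch shows the same shape without Nx. It assumes a one-parameter model y_pred = w * x and a mean squared error objective whose gradient is written out by hand. Plain gradient descent stands in for the optimizer. SupervisedStepSketch and its functions are illustrative names, not part of Axon.

```elixir
defmodule SupervisedStepSketch do
  # Gradient of mean squared error for y_pred = w * x with respect to w:
  # (2 / n) * sum((w * x - y) * x)
  def gradient({xs, ys}, w) do
    terms = Enum.zip_with(xs, ys, fn x, y -> (w * x - y) * x end)
    2 / length(xs) * Enum.sum(terms)
  end

  # One supervised step: compute the gradient on the batch, then update w.
  def batch_step(batch, %{w: w, lr: lr} = state) do
    %{state | w: w - lr * gradient(batch, w)}
  end

  # Nested reduction over epochs and batches, as in the loop abstraction.
  def train(data, state, epochs) do
    Enum.reduce(1..epochs, state, fn _epoch, state ->
      Enum.reduce(data, state, &batch_step/2)
    end)
  end
end

# Data generated from y = 2x, split into two {inputs, targets} batches.
data = [{[1.0, 2.0], [2.0, 4.0]}, {[3.0, 4.0], [6.0, 8.0]}]
%{w: w} = SupervisedStepSketch.train(data, %{w: 0.0, lr: 0.01}, 200)
```

After 200 epochs, w should be very close to 2.0, the slope that generated the data.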

In practice, the loop abstraction accumulates a struct, %Axon.Loop.State{}, which looks like this (assuming container is a generic Elixir container of tensors, e.g. a map or tuple):

%Axon.Loop.State{
  epoch: integer(),
  max_epoch: integer(),
  iteration: integer(),
  max_iteration: integer(),
  metrics: map(string(), container()),
  times: map(integer(), integer()),
  step_state: container()
}

batch_step takes the batch and the step_state field and returns an updated step_state, a generic container of state accumulated at each iteration. The remaining fields of the state struct are updated automatically behind the scenes.

The loop must start from some initial step state, thus most tasks must also provide an additional initialization function to provide some starting point for the step state. For machine learning tasks, the initialization function will return things like initial model parameters and optimizer state.

Typically, the final output of the loop is the accumulated final state; however, you may optionally apply an output transform to extract specific values at the end of the loop. For example, Axon.Loop.trainer/4 by default extracts trained model state:

output_transform = fn state ->
  state.step_state[:model_state]
end
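Below is a minimal plain-Elixir model of this bookkeeping. It uses a hypothetical LoopStateSketch struct with only a few of the fields listed above. The user step sees only step_state, the loop advances epoch and iteration itself, and an output transform is applied to the final state. This is a sketch of the idea, not Axon's implementation.

```elixir
defmodule LoopStateSketch do
  # Hypothetical stand-in for %Axon.Loop.State{}; only a few fields are modeled.
  defstruct epoch: 0, max_epoch: 1, iteration: 0, step_state: nil

  def run(data, init_fn, step_fn, output_transform, max_epoch) do
    state = %__MODULE__{max_epoch: max_epoch, step_state: init_fn.()}

    final_state =
      Enum.reduce(0..(max_epoch - 1)//1, state, fn epoch, state ->
        # The loop, not the user, resets and advances the bookkeeping fields.
        state = %{state | epoch: epoch, iteration: 0}

        Enum.reduce(data, state, fn batch, state ->
          # The user's step function only ever sees and returns step_state.
          new_step_state = step_fn.(batch, state.step_state)
          %{state | step_state: new_step_state, iteration: state.iteration + 1}
        end)
      end)

    output_transform.(final_state)
  end
end

init_fn = fn -> %{count: 0} end
step_fn = fn _batch, %{count: count} -> %{count: count + 1} end
# Like the trainer's default transform, pull a single field out of the final state.
output_transform = fn state -> state.step_state[:count] end

LoopStateSketch.run([:a, :b, :c], init_fn, step_fn, output_transform, 2)
```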

Initialize and Step

The core of an Axon loop is its init and step functions. The initialization function is an arity-0 function which provides an initial step state:

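init = fn ->
  # Axon.build/2 returns {init_fn, predict_fn}; the {1, 32} input template is assumed
  {init_fn, _predict_fn} = Axon.build(model)
  %{params: init_fn.(Nx.template({1, 32}, :f32), %{})}
end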

The step function is the batch_step function described earlier:

step = fn _data, state ->
  # ...compute an updated step state from the batch and the current state...
  new_state = state
  new_state
end

Any anonymous functions that batch_step needs for optimization or training can be passed as extra arguments and then fixed with a capture. For example:

step_with_training_arguments = fn data, state, optimizer_update_fn, state_update_fn ->
  # ...do something...
end

step = &(step_with_training_arguments.(&1, &2, actual_optimizer_update_fn, actual_state_update_fn))
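Outside of Axon, the capture above simply fixes the trailing arguments so the loop receives an ordinary arity-2 function. In this sketch, optimizer_update_fn and state_update_fn are hypothetical stand-ins that work on plain floats.

```elixir
# Hypothetical helpers standing in for real optimizer/state update functions.
optimizer_update_fn = fn grad, lr -> -lr * grad end
state_update_fn = fn state, update -> state + update end

step_with_training_arguments = fn data, state, opt_fn, upd_fn ->
  # Treat each batch as a gradient just to exercise both helpers.
  upd_fn.(state, opt_fn.(data, 0.1))
end

# Fix the extra arguments so the loop sees a plain step(batch, state).
step = &step_with_training_arguments.(&1, &2, optimizer_update_fn, state_update_fn)

# Enum.reduce/3 calls step.(batch, acc) just like the inner loop does.
Enum.reduce([1.0, 2.0], 0.0, step)
```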

Metrics

Often you want to compute metrics associated with your training iterations. To accomplish this, you can attach metrics to an Axon.Loop. Assuming a batch_step function which looks like:

defn batch_step({inputs, targets}, state) do
  %{parameters: params, optimizer_state: optim_state} = state

  gradients = grad(params, &objective_fn.(&1, inputs, targets))
  {updates, new_optim_state} = optimizer.(gradients, optim_state, params)

  new_params = apply_updates(params, updates)

  # Shown for simplicity, you can optimize this by calculating preds
  # along with the gradient calculation
  preds = model_fn.(params, inputs)

  %{
    y_true: targets,
    y_pred: preds,
    parameters: new_params,
    optimizer_state: new_optim_state
  }
end

You can attach metrics to this by using Axon.Loop.metric/5:

Axon.Loop.loop(&batch_step/2)
|> Axon.Loop.metric(:accuracy, "Accuracy", :running_average, fn %{y_true: y_, y_pred: y} -> [y_, y] end)
|> Axon.Loop.run(data)

Because metrics work directly on step_state, you typically need to provide an output transform to indicate which values should be passed to your metric function. By default, Axon assumes a supervised training task with the fields :y_true and :y_pred present in the step state. See Axon.Loop.metric/5 for more information.

Metrics will be tracked in the loop state using the user-provided key. Metrics integrate seamlessly with the supervised metrics defined in Axon.Metrics. You can also use metrics to keep running averages of some values in the original dataset.

Events and Handlers

You can instrument several points in the loop using event handlers. By default, several events are fired when running a loop:

events = [
  :started,             # After loop state initialization
  :epoch_started,       # On epoch start
  :iteration_started,   # On iteration start
  :iteration_completed, # On iteration complete
  :epoch_completed,     # On epoch complete
  :epoch_halted,        # On epoch halt, if early halted
]

You can attach event handlers to events using Axon.Loop.handle_event/4:

loop
|> Axon.Loop.handle_event(:iteration_completed, &log_metrics/1, every: 100)
|> Axon.Loop.run(data)

The above will trigger log_metrics/1 once every 100 times the :iteration_completed event fires. Event handlers must return a tuple {status, state}, where status is one of the following atoms:

:continue   # Continue epoch, continue looping
:halt_epoch # Halt the epoch, continue looping
:halt_loop  # Halt looping

And state is an updated Axon.Loop.State struct. Handler functions take as input the current loop state.

It's important to note that event handlers are triggered in the order they are attached to the loop. If you have two handlers on the same event, they will trigger in order:

loop
|> Axon.Loop.handle_event(:epoch_completed, &normalize_state/1) # Runs first
|> Axon.Loop.handle_event(:epoch_completed, &log_state/1) # Runs second
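The ordering rule can be modeled in plain Elixir as a reduction over the handlers in attachment order, threading the state through each one. HandlerOrderSketch, normalize_state, and log_state are hypothetical. Stopping at the first non-:continue status is an assumption of this sketch, not documented Axon behavior.

```elixir
defmodule HandlerOrderSketch do
  # Runs handlers in the order they were attached; each returns {status, state}.
  # ASSUMPTION: the first non-:continue status stops the remaining handlers.
  def fire(handlers, state) do
    Enum.reduce_while(handlers, {:continue, state}, fn handler, {_status, state} ->
      case handler.(state) do
        {:continue, new_state} -> {:cont, {:continue, new_state}}
        {status, new_state} -> {:halt, {status, new_state}}
      end
    end)
  end
end

# Hypothetical handlers: normalize the value, then record what a logger would see.
normalize_state = fn state -> {:continue, %{state | value: state.value / 10}} end
log_state = fn state -> {:continue, Map.put(state, :logged, state.value)} end

# Attached in this order, the logger sees the normalized value (5.0, not 50).
HandlerOrderSketch.fire([normalize_state, log_state], %{value: 50})
```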

You may provide filters to control when event handlers trigger. See Axon.Loop.handle_event/4 for more details on valid filters.

Factories

Axon loops are typically created from one of the factory functions provided in this module:

* `Axon.Loop.loop/3` - Creates a loop from step function and optional initialization
functions and output transform functions.

* `Axon.Loop.trainer/3` - Creates a supervised training loop from model, loss, and
optimizer.

* `Axon.Loop.evaluator/1` - Creates a supervised evaluator loop from model.

Running loops

To execute a loop, use Axon.Loop.run/4:

Axon.Loop.run(loop, data, %{}, epochs: 10)

Resuming loops

At times you may want to resume a loop from some previous state. You can accomplish this with Axon.Loop.from_state/2:

loop
|> Axon.Loop.from_state(state)
|> Axon.Loop.run(data)

Summary

Functions

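checkpoint(loop, opts \\ [])
Adds a handler function which saves loop checkpoints on a given event, optionally with metric-based criteria.

deserialize_state(serialized, opts \\ [])
Deserializes loop state from a binary.

early_stop(loop, monitor, opts \\ [])
Adds a handler function which halts a loop if the given metric does not improve between events.

eval_step(model)
Creates a supervised evaluation step from a model and model state.

evaluator(model)
Creates a supervised evaluator from a model.

from_state(loop, state)
Attaches state to the given loop in order to resume looping from a previous state.

handle_event(loop, event, handler, filter \\ :always)
Adds a handler function to the loop which will be triggered on event with an optional filter.

kino_vega_lite_plot(loop, plot, metric, opts \\ [])
Adds a handler function which updates a Kino.VegaLite plot.

log(loop, message_fn, opts \\ [])
Adds a handler function which logs the given message produced by message_fn to the given IO device every event satisfying filter.

loop(step_fn, init_fn \\ &default_init/2, output_transform \\ & &1)
Creates a loop from step_fn, an optional init_fn, and an optional output_transform.

metric(loop, metric, name \\ nil, accumulate \\ :running_average, transform_or_fields \\ [:y_true, :y_pred])
Adds a metric of the given name to the loop.

monitor(loop, metric, fun, name, opts \\ [])
Adds a handler function which monitors the given metric and fires some action when the given metric meets some criteria.

reduce_lr_on_plateau(loop, monitor, opts \\ [])
Adds a handler function which reduces the learning rate by the given factor if the given metric does not improve between events.

run(loop, data, init_state \\ %{}, opts \\ [])
Runs the given loop on data with the given options.

serialize_state(state, opts \\ [])
Serializes loop state to a binary for saving and loading loop from previous states.

train_step(model, loss, optimizer, opts \\ [])
Creates a supervised train step from a model, loss function, and optimizer.

trainer(model, loss, optimizer, opts \\ [])
Creates a supervised training loop from a model, loss function, and optimizer.

validate(loop, model, validation_data, opts \\ [])
Adds a handler function which tests the performance of model against the given validation set.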

Functions

checkpoint(loop, opts \\ [])

Adds a handler function which saves loop checkpoints on a given event, optionally with metric-based criteria.

By default, loop checkpoints will be saved at the end of every epoch in the current working directory under the checkpoint/ path. Checkpoints are serialized representations of loop state obtained from Axon.Loop.serialize_state/2. Serialization options will be forwarded to Axon.Loop.serialize_state/2.

You can customize checkpoint events by passing :event and :filter options:

loop
|> Axon.Loop.checkpoint(event: :iteration_completed, filter: [every: 50])

Checkpoints are saved under the checkpoint/ directory with a pattern of checkpoint_{epoch}_{iteration}.ckpt. You can customize the path and pattern with the :path and :file_pattern options:

my_file_pattern =
  fn %Axon.Loop.State{epoch: epoch, iteration: iter} ->
    "checkpoint_#{epoch}_#{iter}"
  end

loop
|> Axon.Loop.checkpoint(path: "my_checkpoints", file_pattern: my_file_pattern)

If you'd like to only save checkpoints based on some metric criteria, you can specify the :criteria option. :criteria must be a valid key in metrics:

loop
|> Axon.Loop.checkpoint(criteria: "validation_loss")

The default criteria mode is :min, meaning the lowest metric value is considered "best" when deciding whether to save on a given event. Valid modes are :min and :max:

loop
|> Axon.Loop.checkpoint(criteria: "validation_accuracy", mode: :max)

Options

  • :event - event to fire handler on. Defaults to :epoch_completed.

  • :filter - event filter to attach to handler. Defaults to :always.

  • :patience - number of given events to wait for improvement. Defaults to 3.

  • :mode - whether given metric is being minimized or maximized. One of :min, :max or an arity-1 function which returns true or false. Defaults to :min.

  • :path - path to directory to save checkpoints. Defaults to checkpoint.

  • :file_pattern - arity-1 function which returns a string file pattern based on the current loop state. Defaults to saving checkpoints to files checkpoint_#{epoch}_#{iteration}.ckpt.
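The following plain-Elixir sketch covers two ideas from this section. The first is the default checkpoint_{epoch}_{iteration}.ckpt naming. The second is criteria-based saving, where a checkpoint is written only when the monitored metric beats its best value so far under :min or :max. CheckpointSketch is hypothetical and models only this decision logic, not Axon's handler.

```elixir
defmodule CheckpointSketch do
  # Default naming described in the options: checkpoint_{epoch}_{iteration}.ckpt
  def default_file_pattern(%{epoch: epoch, iteration: iteration}),
    do: "checkpoint_#{epoch}_#{iteration}.ckpt"

  # Event indices at which a checkpoint would be saved when :criteria is set:
  # only when the metric beats the best value seen so far.
  def saved_events(metric_values, mode \\ :min) do
    {saved, _best} =
      metric_values
      |> Enum.with_index()
      |> Enum.reduce({[], nil}, fn {value, idx}, {saved, best} ->
        if best == nil or better?(value, best, mode),
          do: {[idx | saved], value},
          else: {saved, best}
      end)

    Enum.reverse(saved)
  end

  defp better?(value, best, :min), do: value < best
  defp better?(value, best, :max), do: value > best
end

Path.join("checkpoint", CheckpointSketch.default_file_pattern(%{epoch: 2, iteration: 150}))
CheckpointSketch.saved_events([0.9, 0.7, 0.8, 0.6], :min)
```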

deserialize_state(serialized, opts \\ [])

Deserializes loop state from a binary.

It is the opposite of Axon.Loop.serialize_state/2.

By default, the step state is deserialized using Nx.deserialize/2; however, this behavior can be changed if step state is an application specific container. For example, if you introduce your own data structure into step_state and customized the serialization logic, Nx.deserialize/2 will not be sufficient for deserialization; you must pass custom logic with :deserialize_step_state.

early_stop(loop, monitor, opts \\ [])

Adds a handler function which halts a loop if the given metric does not improve between events.

By default, this will run after each epoch and track the improvement of a given metric.

You must specify a metric to monitor and the metric must be present in the loop state. Typically, this will be a validation metric:

model
|> Axon.Loop.trainer(loss, optim)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, val_data)
|> Axon.Loop.early_stop("validation_accuracy", mode: :max)

It's important to remember that handlers are executed in the order they are added to the loop. For example, if you'd like to checkpoint a loop after every epoch and use early stopping, most likely you want to add the checkpoint handler before the early stopping handler:

model
|> Axon.Loop.trainer(loss, optim)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.checkpoint()
|> Axon.Loop.early_stop("accuracy", mode: :max)

This ensures the checkpoint handler fires even on the event where early stopping halts the loop.

eval_step(model)

Creates a supervised evaluation step from a model and model state.

This function is intended for more fine-grained control over the loop creation process. It returns a tuple of {init_fn, step_fn} where init_fn returns an initial step state and step_fn performs a single evaluation step.

evaluator(model)

Creates a supervised evaluator from a model.

An evaluator can be used for things such as testing and validation of models after or during training. It assumes model is an Axon struct, container of structs, or a tuple of init / apply functions. model_state must be a container usable from within model.

The evaluator returns a step state of the form:

%{
  y_true: labels,
  y_pred: predictions
}

This lets you attach any number of supervised metrics to the evaluation loop:

model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy, "Accuracy")

You must pass a compatible trained model state to Axon.Loop.run/4 when using supervised evaluation loops. For example, if you've bound the result of a training run to trained_model_state, you can run the trained model through an evaluation run like this:

model
|> Axon.Loop.evaluator()
|> Axon.Loop.run(data, trained_model_state, compiler: EXLA)

This function applies an output transform which returns the map of metrics accumulated over the given loop.

from_state(loop, state)

Attaches state to the given loop in order to resume looping from a previous state.

It's important to note that a loop's attached state takes precedence over defined initialization functions. Given an initialization function:

defn init_state(), do: %{foo: 1, bar: 2}

And an attached state:

state = %Axon.Loop.State{step_state: %{foo: 2, bar: 3}}

init_state/0 will never execute, and instead the initial step state of %{foo: 2, bar: 3} will be used.
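The precedence rule amounts to choosing the attached step state whenever one exists. In the hypothetical FromStateSketch below, a plain map stands in for %Axon.Loop.State{}. The init function raises, so the sketch would fail loudly if it were ever called while a state is attached.

```elixir
defmodule FromStateSketch do
  # No attached state: fall back to the initialization function.
  def initial_step_state(nil, init_fn), do: init_fn.()

  # An attached state wins, and init_fn is never called.
  def initial_step_state(%{step_state: step_state}, _init_fn), do: step_state
end

init_fn = fn -> raise "init_fn should not run when a state is attached" end
FromStateSketch.initial_step_state(%{step_state: %{foo: 2, bar: 3}}, init_fn)
```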

handle_event(loop, event, handler, filter \\ :always)

Adds a handler function to the loop which will be triggered on event with an optional filter.

Events take place at different points during loop execution. The default events are:

events = [
  :started,             # After loop state initialization
  :epoch_started,       # On epoch start
  :iteration_started,   # On iteration start
  :iteration_completed, # On iteration complete
  :epoch_completed,     # On epoch complete
  :epoch_halted,        # On epoch halt, if early halted
]

Generally, event handlers are side-effecting operations which provide some sort of inspection into the loop's progress. It's important to note that if you define multiple handlers to be triggered on the same event, they will execute in the order in which they were attached to the loop:

loop
|> Axon.Loop.handle_event(:epoch_started, &normalize_step_state/1) # executes first
|> Axon.Loop.handle_event(:epoch_started, &log_step_state/1) # executes second

Thus, if you have separate handlers which alter or depend on loop state, you need to ensure they are ordered correctly, or combined into a single event handler for maximum control over execution.

event must be an atom representing the event to trigger handler or a list of atoms indicating handler should be triggered on multiple events. event may be :all which indicates the handler should be triggered on every event during loop processing.

handler must be an arity-1 function which takes as input loop state and returns {status, state}, where status is an atom with one of the following values:

:continue   # Continue epoch, continue looping
:halt_epoch # Halt the epoch, continue looping
:halt_loop  # Halt looping

filter is an atom representing a valid filter predicate, a keyword of predicate-value pairs, or a function which takes loop state and returns true, indicating the handler should run, or false, indicating the handler should not run. Valid predicates are:

:always # Always trigger event
:once   # Trigger on first event firing

Valid predicate-value pairs are:

every: N # Trigger on every Nth event
only: N  # Trigger only on the Nth event
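These predicates can be sketched as a function of how many times the event has fired. FilterSketch is hypothetical and counts firings starting from 1, which is an assumption. It also omits the function-filter form, which in Axon receives the loop state rather than a count.

```elixir
defmodule FilterSketch do
  # `count` is how many times the event has fired so far, starting from 1.
  def fire?(:always, _count), do: true
  def fire?(:once, count), do: count == 1
  def fire?([every: n], count), do: rem(count, n) == 0
  def fire?([only: n], count), do: count == n
end

# Which of the first ten firings would run a handler with filter every: 3
Enum.filter(1..10, &FilterSketch.fire?([every: 3], &1))
```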

Warning: If you modify the step state in an event handler, it will trigger potentially excessive recompilation and result in significant additional overhead during loop execution.

kino_vega_lite_plot(loop, plot, metric, opts \\ [])

Adds a handler function which updates a Kino.VegaLite plot.

By default, this will run after every iteration.

You must specify a plot to push to and a metric to track. The :x axis will be the iteration count, labeled "step". The metric must match the name given to the :y axis in your VegaLite plot:

plot =
  Vl.new()
  |> Vl.mark(:line)
  |> Vl.encode_field(:x, "step", type: :quantitative)
  |> Vl.encode_field(:y, "loss", type: :quantitative)
  |> Kino.VegaLite.new()
  |> Kino.render()

model
|> Axon.Loop.trainer(loss, optim)
|> Axon.Loop.kino_vega_lite_plot(plot, "loss")

Options

  • :event - event to fire handler on. Defaults to :iteration_completed.

  • :filter - event filter to attach to handler. Defaults to :always.

log(loop, message_fn, opts \\ [])

Adds a handler function which logs the given message produced by message_fn to the given IO device every event satisfying filter.

In most cases, this is useful for inspecting the contents of the loop state at intermediate stages. For example, the default trainer loop factory attaches IO logging of epoch, batch, loss and metrics.

It's also possible to log loop state to files by changing the given IO device. By default, the IO device is :stdio.

message_fn should take the loop state and return a binary representing the message to be written to the IO device.

loop(step_fn, init_fn \\ &default_init/2, output_transform \\ & &1)

Creates a loop from step_fn, an optional init_fn, and an optional output_transform.

step_fn is an arity-2 function which takes a batch and state and returns an updated step state:

defn batch_step(batch, step_state) do
  step_state + 1
end

init_fn by default is an identity function which forwards its initial arguments as the model state. You should define a custom initialization function if you require a different behavior:

defn init_step_state(state) do
  Map.merge(%{foo: 1}, state)
end

You may use state in conjunction with initialization functions in init_fn. For example, train_step/3 uses initial state as initial model parameters to allow initializing models from partial parameterizations.

batch_step/2 and init_step_state/1 are typically called from within Nx.Defn.jit/2. While JIT-compilation will work with anonymous functions, def, and defn, it is recommended that you use the stricter defn to define both functions in order to avoid bugs or cryptic errors.

output_transform/1 applies a transformation on the final accumulated loop state. This is useful for extracting specific fields from a loop and piping them into additional functions.

metric(loop, metric, name \\ nil, accumulate \\ :running_average, transform_or_fields \\ [:y_true, :y_pred])

Adds a metric of the given name to the loop.

A metric is a function which tracks or measures some value with respect to values in the step state. For example, when training classification models, it's common to track the model's accuracy during training:

loop
|> Axon.Loop.metric(:accuracy, "Accuracy")

By default, metrics assume a supervised learning task and extract the fields [:y_true, :y_pred] from the step state. If you wish to work on a different value, you can use an output transform. An output transform is a list of keys to extract from the step state, or a function which returns a flattened list of values to pass to the given metric function. Values received from output transforms are passed to the given metric using:

value = output_transform.(step_state)
apply(metric, value)

Thus, even if you want your metric to work on a container, your output transform must return a list.
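The field-extraction rule can be sketched in plain Elixir. MetricInputSketch and the list-based accuracy function are hypothetical. Real Axon metrics operate on Nx tensors, so this only mirrors the apply(metric, value) pattern described above.

```elixir
defmodule MetricInputSketch do
  # A list of fields becomes the list of their values, in order.
  def metric_args(step_state, fields) when is_list(fields),
    do: Enum.map(fields, &Map.fetch!(step_state, &1))

  # A function transform must itself return the list of arguments.
  def metric_args(step_state, transform) when is_function(transform, 1),
    do: transform.(step_state)

  def compute(metric_fn, step_state, transform_or_fields \\ [:y_true, :y_pred]) do
    apply(metric_fn, metric_args(step_state, transform_or_fields))
  end
end

# Hypothetical metric over plain lists: fraction of exact matches.
accuracy = fn y_true, y_pred ->
  matches = Enum.zip_with(y_true, y_pred, fn t, p -> if t == p, do: 1, else: 0 end)
  Enum.sum(matches) / length(matches)
end

step_state = %{y_true: [1, 0, 1, 1], y_pred: [1, 1, 1, 0], loss: 0.42}
MetricInputSketch.compute(accuracy, step_state)
```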

metric must be an atom which matches the name of a metric in Axon.Metrics, or an arbitrary function which returns a tensor or container.

name must be a string or atom used to store the computed metric in the loop state. If names conflict, the last attached metric will take precedence:

loop
|> Axon.Loop.metric(:mean_squared_error, "Error") # Will be overwritten
|> Axon.Loop.metric(:mean_absolute_error, "Error") # Will be used

By default, metrics keep a running average of the metric calculation. You can override this behavior by changing accumulate:

loop
|> Axon.Loop.metric(:true_negatives, "tn", :running_sum)

Accumulation function can be one of the accumulation combinators in Axon.Metrics or an arity-3 function of the form: accumulate(acc, obs, i) :: new_acc.
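The following is one way to write running-average and running-sum accumulators with the accumulate(acc, obs, i) shape, assuming i is the zero-based index of the observation. AccumulateSketch works on plain numbers rather than tensors and is not the Axon.Metrics implementation.

```elixir
defmodule AccumulateSketch do
  # new_avg = acc * i / (i + 1) + obs / (i + 1), with i counting from 0
  def running_average(acc, obs, i), do: acc * i / (i + 1) + obs / (i + 1)

  def running_sum(acc, obs, _i), do: acc + obs

  # Feeds each observation and its index to an arity-3 accumulator.
  def accumulate(observations, fun, init \\ 0) do
    observations
    |> Enum.with_index()
    |> Enum.reduce(init, fn {obs, i}, acc -> fun.(acc, obs, i) end)
  end
end

AccumulateSketch.accumulate([1, 2, 3, 4], &AccumulateSketch.running_average/3)
```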

monitor(loop, metric, fun, name, opts \\ [])

Adds a handler function which monitors the given metric and fires some action when the given metric meets some criteria.

This function is a generalization of handlers such as Axon.Loop.reduce_lr_on_plateau/3 and Axon.Loop.early_stop/3.

You must specify a metric to monitor that is present in the state metrics. This handler will then monitor the value of the metric at the specified intervals and fire the specified function if the criterion is met.

You must also specify a name for the monitor attached to the given metric. This will be used to store metadata associated with the monitor.

The common case of monitor is to track improvement of metrics and take action if metrics haven't improved after a certain number of events. However, you can also set a monitor up to trigger if a metric hits some criteria (such as a threshold) by passing a custom monitoring mode.

Options

  • :event - event to fire handler on. Defaults to :epoch_completed.

  • :filter - event filter to attach to handler. Defaults to :always.

  • :patience - number of given events to wait for improvement. Defaults to 3.

  • :mode - whether given metric is being minimized or maximized. One of :min, :max or an arity-1 function which returns true or false. Defaults to :min.
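The patience logic common to monitor/5, early_stop/3, and reduce_lr_on_plateau/3 can be sketched as follows. The sketch tracks the best value and how many events have passed without improvement, and fires once that count reaches :patience. MonitorSketch is hypothetical, and resetting the counter after firing is an assumption of the sketch.

```elixir
defmodule MonitorSketch do
  # Returns the event indices at which the monitor would fire.
  def firing_events(values, patience, mode) do
    {fired, _best, _since} =
      values
      |> Enum.with_index()
      |> Enum.reduce({[], nil, 0}, fn {value, idx}, {fired, best, since} ->
        cond do
          improved?(best, value, mode) -> {fired, value, 0}
          # ASSUMPTION: the counter resets after firing.
          since + 1 >= patience -> {[idx | fired], best, 0}
          true -> {fired, best, since + 1}
        end
      end)

    Enum.reverse(fired)
  end

  defp improved?(nil, _value, _mode), do: true
  defp improved?(best, value, :min), do: value < best
  defp improved?(best, value, :max), do: value > best
end

# Validation loss stops improving after event 1; with patience 3 it fires at event 4.
MonitorSketch.firing_events([0.9, 0.5, 0.6, 0.7, 0.8, 0.55], 3, :min)
```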

reduce_lr_on_plateau(loop, monitor, opts \\ [])

Adds a handler function which reduces the learning rate by the given factor if the given metric does not improve between events.

By default, this will run after each epoch and track the improvement of a given metric.

You must specify a metric to monitor and the metric must be present in the loop state. Typically, this will be a validation metric:

model
|> Axon.Loop.trainer(loss, optim)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, val_data)
|> Axon.Loop.reduce_lr_on_plateau("validation_accuracy", mode: :max)

Options

  • :event - event to fire handler on. Defaults to :epoch_completed.

  • :filter - event filter to attach to handler. Defaults to :always.

  • :patience - number of given events to wait for improvement. Defaults to 3.

  • :mode - whether given metric is being minimized or maximized. Defaults to :min.

  • :factor - factor to decrease learning rate by. Defaults to 0.1.

run(loop, data, init_state \\ %{}, opts \\ [])

Runs the given loop on data with the given options.

loop must be a valid Axon.Loop struct built from one of the loop factories provided in this module.

data must be an Enumerable or Stream which yields batches of data on each iteration.

Options

  • :epochs - max epochs to run loop for. Must be a non-negative integer. Defaults to 1.

  • :iterations - max iterations to run each epoch. Must be a non-negative integer, or -1 for no limit. Defaults to -1 (no max iterations).

  • :jit_compile? - whether or not to JIT compile initialization and step functions. JIT compilation must be used for gradient computations. Defaults to true.

  • :garbage_collect - whether or not to garbage collect after each loop iteration. This may prevent OOMs, but it will slow down training.

  • :strict? - whether or not to compile step functions strictly. If this flag is set, the loop will raise on any cache miss during the training loop. Defaults to true.

  • :debug - run loop in debug mode to trace loop progress. Defaults to false.

Additional options are forwarded to Nx.Defn.jit as JIT-options. If no JIT options are set, the default options set with Nx.Defn.default_options are used.
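The sketch below shows how :epochs and :iterations bound the work done, assuming -1 means no per-epoch cap as stated above. RunOptionsSketch is hypothetical and only counts batches. It does not compile or run anything.

```elixir
defmodule RunOptionsSketch do
  # epochs: number of passes over `data`; iterations: per-epoch cap, -1 = no cap.
  def batches_seen(data, opts \\ []) do
    epochs = Keyword.get(opts, :epochs, 1)
    iterations = Keyword.get(opts, :iterations, -1)

    Enum.reduce(List.duplicate(:epoch, epochs), 0, fn _epoch, count ->
      epoch_data = if iterations < 0, do: data, else: Stream.take(data, iterations)
      count + Enum.count(epoch_data)
    end)
  end
end

# A stream of 100 {input, target} batches, capped at 5 iterations per epoch.
data = Stream.map(1..100, &{&1, &1 * 2})
RunOptionsSketch.batches_seen(data, epochs: 10, iterations: 5)
```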

serialize_state(state, opts \\ [])

Serializes loop state to a binary for saving and loading loop from previous states.

You can consider the serialized state to be a checkpoint of all state at a given iteration and epoch.

By default, the step state is serialized using Nx.serialize/2; however, this behavior can be changed if step state is an application specific container. For example, if you introduce your own data structure into step_state, Nx.serialize/2 will not be sufficient for serialization - you must pass custom serialization as an option with :serialize_step_state.

Additional opts control serialization options such as compression and are forwarded to :erlang.term_to_binary/2.
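Since these options are forwarded to :erlang.term_to_binary/2, a compressed round trip can be sketched with plain Erlang term serialization. The map below is a hypothetical stand-in for loop state. By default Axon handles tensor-bearing step state with Nx.serialize/2, which this sketch skips.

```elixir
# Hypothetical plain-map stand-in for loop state (no tensors).
state = %{epoch: 3, iteration: 120, metrics: %{"loss" => 0.25}}

# compressed: 9 is an :erlang.term_to_binary/2 option (highest zlib level).
binary = :erlang.term_to_binary(state, compressed: 9)
restored = :erlang.binary_to_term(binary)
```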

train_step(model, loss, optimizer, opts \\ [])

Creates a supervised train step from a model, loss function, and optimizer.

This function is intended for more fine-grained control over the loop creation process. It returns a tuple of {init_fn, step_fn} where init_fn is an initialization function which returns an initial step state and step_fn is a supervised train step constructed from model, loss, and optimizer.

model must be an Axon struct, a valid defn container of Axon structs, or a {init_fn, apply_fn}-tuple where init_fn is an arity-2 function which initializes the model state and apply_fn is an arity-2 function which applies the forward pass of the model. The forward pass of the model must return a map with keys :prediction and :state representing the model's prediction and updated state for layers which aggregate state during training.

loss must be an atom which matches a function in Axon.Losses, a list of {loss, weight} tuples representing a basic weighted loss function for multi-output models, or an arity-2 function representing a custom loss function.

optimizer must be an atom matching the name of a valid optimizer in Polaris.Optimizers, or a {init_fn, update_fn} tuple where init_fn is an arity-1 function which initializes the optimizer state from the model parameters and update_fn is an arity-3 function that receives (gradient, optimizer_state, model_parameters) and scales gradient updates with respect to input parameters, optimizer state, and gradients. The update_fn returns {scaled_updates, optimizer_state}, which can then be applied to the model through model_parameters = Polaris.Updates.apply_updates(model_parameters, scaled_updates). See Polaris.Updates for more information on building optimizers.

Options

  • :seed - seed to use when constructing models. Seed controls random initialization of model parameters. Defaults to no seed which constructs a random seed for you at model build time.

  • :loss_scale - type of loss-scaling to use, if any. Loss-scaling is necessary when doing mixed precision training for numerical stability. Defaults to :identity or no loss-scaling.

  • :gradient_accumulation_steps - number of gradient accumulation steps to take during training. Gradient accumulation decreases the number of updates by accumulating gradients between steps, increasing the effective batch size on smaller devices. Defaults to 1.
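Gradient accumulation can be sketched on plain floats. Gradients from every :gradient_accumulation_steps batches are averaged and applied as a single update, and a trailing partial group is discarded. GradAccumSketch is hypothetical. Averaging and dropping the remainder are assumptions of the sketch, not a description of Axon's train step.

```elixir
defmodule GradAccumSketch do
  # Returns {final_param, number_of_updates_applied}.
  def train(gradients, param, lr, steps) do
    gradients
    # Group gradients into full windows of `steps`; drop an incomplete tail.
    |> Enum.chunk_every(steps, steps, :discard)
    |> Enum.reduce({param, 0}, fn chunk, {param, updates} ->
      mean = Enum.sum(chunk) / steps
      {param - lr * mean, updates + 1}
    end)
  end
end

# Five batch gradients with 2 accumulation steps: only 2 updates are applied.
GradAccumSketch.train([1.0, 3.0, 2.0, 2.0, 5.0], 0.0, 0.5, 2)
```

With steps set to 1, every batch produces its own update, which matches the default of 1.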

trainer(model, loss, optimizer, opts \\ [])

Creates a supervised training loop from a model, loss function, and optimizer.

This function is useful for training models on most standard supervised learning tasks. It assumes data consists of tuples of input-target pairs, e.g. [{x0, y0}, {x1, y1}, ..., {xN, yN}] where x0 and y0 are batched tensors or containers of batched tensors.

It defines an initialization function which first initializes model state using the given model and then initializes optimizer state using the initial model state. The step function uses a differentiable objective function defined with respect to the model parameters, input data, and target data using the given loss function. It then updates model parameters using the given optimizer in order to minimize loss with respect to the model parameters.

model must be an Axon struct, a valid defn container of Axon structs, or a {init_fn, apply_fn}-tuple where init_fn is an arity-2 function which initializes the model state and apply_fn is an arity-2 function which applies the forward pass of the model.

loss must be an atom which matches a function in Axon.Losses, a list of {loss, weight} tuples representing a basic weighted loss function for multi-output models, or an arity-2 function representing a custom loss function.

optimizer must be an atom matching the name of a valid optimizer in Polaris.Optimizers, or a {init_fn, update_fn} tuple where init_fn is an arity-1 function which initializes the optimizer state from attached parameters and update_fn is an arity-3 function which scales gradient updates with respect to input parameters, optimizer state, and gradients. See Polaris.Updates for more information on building optimizers.

This function creates a step function which outputs a map consisting of the following fields for step_state:

%{
  y_pred: tensor() | container(tensor()), # Model predictions for use in metrics
  y_true: tensor() | container(tensor()), # True labels for use in metrics
  loss: tensor(), # Running average of loss over epoch
  model_state: container(tensor()), # Model parameters and state
  optimizer_state: container(tensor()) # Optimizer state associated with each parameter
}

Examples

Basic usage

data = Stream.zip(input, target)

model = Axon.input("input", shape: {nil, 32}) |> Axon.dense(1, activation: :sigmoid)

model
|> Axon.Loop.trainer(:binary_cross_entropy, :adam)
|> Axon.Loop.run(data)

Customizing Optimizer

model
|> Axon.Loop.trainer(:binary_cross_entropy, Polaris.Optimizers.adam(learning_rate: 0.05))
|> Axon.Loop.run(data)

Custom loss

# Hand-written mean squared error as an example custom loss
loss_fn = fn y_true, y_pred -> y_true |> Nx.subtract(y_pred) |> Nx.pow(2) |> Nx.mean() end

model
|> Axon.Loop.trainer(loss_fn, Polaris.Optimizers.rmsprop(learning_rate: 0.01))
|> Axon.Loop.run(data)

Multiple objectives with multi-output model

model = {Axon.input("input_0", shape: {nil, 1}), Axon.input("input_1", shape: {nil, 2})}
loss_weights = [mean_squared_error: 0.5, mean_absolute_error: 0.5]

model
|> Axon.Loop.trainer(loss_weights, :sgd)
|> Axon.Loop.run(data)

Options

  • :log - training loss and metric log interval. Set to 0 to silence training logs. Defaults to 50.

  • :seed - seed to use when constructing models. Seed controls random initialization of model parameters. Defaults to no seed which constructs a random seed for you at model build time.

  • :loss_scale - type of loss-scaling to use, if any. Loss-scaling is necessary when doing mixed precision training for numerical stability. Defaults to :identity or no loss-scaling.

  • :gradient_accumulation_steps - number of gradient accumulation steps to take during training. Gradient accumulation decreases the number of updates by accumulating gradients between steps, increasing the effective batch size on smaller devices. Defaults to 1.

validate(loop, model, validation_data, opts \\ [])

Adds a handler function which tests the performance of model against the given validation set.

This handler assumes the loop state matches the state initialized in a supervised training loop. Typically, you'd call this immediately after creating a supervised training loop:

model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.validate(model, validation_data)

Please note that you must pass the same (or an equivalent) model into this method so it can be used during the validation loop. The metrics which are computed are those which are present BEFORE the validation handler was added to the loop. For the following loop:

model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(:mean_absolute_error)
|> Axon.Loop.validate(model, validation_data)
|> Axon.Loop.metric(:binary_cross_entropy)

only :mean_absolute_error will be computed at validation time.

The returned loop state is altered to contain validation metrics for use in later handlers such as early stopping and model checkpoints. Since the order of execution of event handlers is in the same order they are declared in the training loop, you MUST call this method before any other handler which expects or may use validation metrics.

By default the validation loop runs after every epoch; however, you can customize it by overriding the default event and event filters:

model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(:mean_absolute_error)
|> Axon.Loop.validate(model, validation_data, event: :iteration_completed, filter: [every: 10_000])
|> Axon.Loop.metric(:binary_cross_entropy)