# Axon.Loop (Axon v0.3.1)
Abstraction for modeling a reduction of a dataset with an accumulated state for a number of epochs.
Inspired heavily by PyTorch Ignite.
The main abstraction is the `%Loop{}` struct, which controls a nested reduction of the form:
Enum.reduce(1..max_epochs, state, fn epoch, state ->
  Enum.reduce(data, state, &batch_step/2)
end)
`data` is assumed to be an `Enumerable` or `Stream` of input data which is handled by a processing function, `batch_step`. The purpose of the loop abstraction is to take away much of the boilerplate used in solving machine learning tasks. Tasks such as normalizing a dataset, hyperparameter optimization, or training machine learning models boil down to writing one function:
defn batch_step(batch, state) do
  # ...do something with batch...
  updated_state
end
For tasks such as training a neural network, `state` will encapsulate things such as model and optimizer state. For supervised learning tasks, `batch_step` might look something like:
defn batch_step({inputs, targets}, state) do
  %{parameters: params, optimizer_state: optim_state} = state
  gradients = grad(params, &objective_fn.(&1, inputs, targets))
  {updates, new_optim_state} = optimizer.(optim_state, params, gradients)
  new_params = apply_updates(params, updates)
  %{parameters: new_params, optimizer_state: new_optim_state}
end
`batch_step` takes a batch of `{input, target}` pairs and the current state, and updates the model parameters based on the gradients received from some arbitrary objective function. This function will run in a nested loop, iterating over the entire dataset for `N` epochs before finally returning the trained model state. By defining one function, we've created a training loop that works for most machine learning models.
In actuality, the loop abstraction accumulates a struct, `Axon.Loop.State`, which looks like (assuming `container` is a generic Elixir container of tensors, e.g. a map, tuple, etc.):
%State{
  epoch: integer(),
  max_epoch: integer(),
  iteration: integer(),
  max_iteration: integer(),
  metrics: map(string(), container()),
  times: map(integer(), integer()),
  step_state: container()
}
`batch_step` takes in the batch and the step state field and returns a `step_state`, which is a generic container of state accumulated at each iteration. The rest of the fields in the state struct are updated automatically behind the scenes.
The loop must start from some initial step state; thus, most tasks must also provide an additional initialization function to provide some starting point for the step state. For machine learning tasks, the initialization function will return things like initial model parameters and optimizer state.
Typically, the final output of the loop is the accumulated final state; however, you may optionally apply an output transform to extract specific values at the end of the loop. For example, `Axon.Loop.trainer/4` by default extracts trained model state:
output_transform = fn state ->
  state.step_state[:model_state]
end
## Initialize and Step
The core of an Axon loop is its init and step functions. The initialization is an arity-0 function which provides an initial step state:
init = fn ->
  %{params: Axon.init(model)}
end
The step function is the `batch_step` function mentioned earlier:
step = fn data, state ->
  new_state = # ...do something...
  new_state
end
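
These two functions plug directly into `Axon.Loop.loop/3`. A minimal sketch, assuming `data` is defined and that this version of the loop accepts the arity-0 `init` shown above:

loop = Axon.Loop.loop(step, init)
final_state = Axon.Loop.run(loop, data, epochs: 1)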
## Metrics
Oftentimes you want to compute metrics associated with your training iterations. To accomplish this, you can attach metrics to each `Axon.Loop`. Assuming a `batch_step` function which looks like:
defn batch_step({inputs, targets}, state) do
  %{parameters: params, optimizer_state: optim_state} = state
  gradients = grad(params, &objective_fn.(&1, inputs, targets))
  {updates, new_optim_state} = optimizer.(optim_state, params, gradients)
  new_params = apply_updates(params, updates)
  # Shown for simplicity; you can optimize this by calculating preds
  # along with the gradient calculation
  preds = model_fn.(params, inputs)
  %{
    y_true: targets,
    y_pred: preds,
    parameters: new_params,
    optimizer_state: new_optim_state
  }
end
You can attach metrics to this by using `Axon.Loop.metric/4`:
Axon.Loop.loop(&batch_step/2)
|> Axon.Loop.metric(:accuracy, "Accuracy", :running_average, fn %{y_true: y_, y_pred: y} -> [y_, y] end)
|> Axon.Loop.run(data)
Because metrics work directly on `step_state`, you typically need to provide an output transform to indicate which values should be passed to your metric function. By default, Axon assumes a supervised training task with the fields `:y_true` and `:y_pred` present in the step state. See `Axon.Loop.metric/4` for more information.
Metrics will be tracked in the loop state using the user-provided key. Metrics integrate seamlessly with the supervised metrics defined in `Axon.Metrics`. You can also use metrics to keep running averages of some values in the original dataset.
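
For example, a hedged sketch of a metric that keeps a running average of the mean target value in each batch, using a custom metric function over the default `[:y_true, :y_pred]` fields (the `"mean_target"` name is illustrative):

loop
|> Axon.Loop.metric(fn y_true, _y_pred -> Nx.mean(y_true) end, "mean_target")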
## Events and Handlers
You can instrument several points in the loop using event handlers. By default, several events are fired when running a loop:
events = [
:started, # After loop state initialization
:epoch_started, # On epoch start
:iteration_started, # On iteration start
:iteration_completed, # On iteration complete
:epoch_completed, # On epoch complete
:epoch_halted, # On epoch halt, if early halted
:halted, # On loop halt, if early halted
:completed # On loop completion
]
You can attach event handlers to events using `Axon.Loop.handle/4`:
loop
|> Axon.Loop.handle(:iteration_completed, &log_metrics/1, every: 100)
|> Axon.Loop.run(data)
The above will trigger `log_metrics/1` every 100 times the `:iteration_completed` event is fired. Event handlers must return a tuple `{status, state}`, where `status` is an atom with one of the following values:
:continue # Continue epoch, continue looping
:halt_epoch # Halt the epoch, continue looping
:halt_loop # Halt looping
and `state` is an updated `Axon.Loop.State` struct. Handler functions take the current loop state as input.
It's important to note that event handlers are triggered in the order they are attached to the loop. If you have two handlers on the same event, they will trigger in order:
loop
|> Axon.Loop.handle(:epoch_completed, &normalize_state/1) # Runs first
|> Axon.Loop.handle(:epoch_completed, &log_state/1) # Runs second
You may provide filters to control when event handlers trigger. See `Axon.Loop.handle/4` for more details on valid filters.
## Factories
Axon loops are typically created from one of the factory functions provided in this module:
* `Axon.Loop.loop/3` - Creates a loop from a step function and optional initialization and output transform functions.
* `Axon.Loop.trainer/3` - Creates a supervised training loop from model, loss, and optimizer.
* `Axon.Loop.evaluator/1` - Creates a supervised evaluator loop from model.
## Running loops
In order to execute a loop, you should use `Axon.Loop.run/3`:
loop
|> Axon.Loop.run(data, epochs: 10)
## Resuming loops
At times you may want to resume a loop from some previous state. You can accomplish this with `Axon.Loop.from_state/2`:
loop
|> Axon.Loop.from_state(state)
|> Axon.Loop.run(data)
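
Combined with `serialize_state/2` and `deserialize_state/2`, this supports a simple save-and-resume workflow. A sketch, assuming the loop keeps the identity output transform so `run` returns the full `%Axon.Loop.State{}` (the file path is illustrative):

# Save after a first run
state = Axon.Loop.run(loop, data, epochs: 5)
File.write!("loop_state.ckpt", Axon.Loop.serialize_state(state))

# Later, restore and continue
resumed_state =
  "loop_state.ckpt"
  |> File.read!()
  |> Axon.Loop.deserialize_state()

loop
|> Axon.Loop.from_state(resumed_state)
|> Axon.Loop.run(data, epochs: 10)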
## Functions
### checkpoint

Adds a handler function which saves loop checkpoints on a given event, optionally with metric-based criteria.
By default, loop checkpoints will be saved at the end of every epoch in the current working directory under the `checkpoint/` path. Checkpoints are serialized representations of loop state obtained from `Axon.Loop.serialize_state/2`. Serialization options will be forwarded to `Axon.Loop.serialize_state/2`.
You can customize checkpoint events by passing the `:event` and `:filter` options:
loop
|> Axon.Loop.checkpoint(event: :iteration_completed, filter: [every: 50])
Checkpoints are saved under the `checkpoint/` directory with a pattern of `checkpoint_{epoch}.ckpt`. You can customize the path and pattern with the `:path` and `:file_pattern` options:
my_file_pattern =
  fn %Axon.Loop.State{epoch: epoch, iteration: iter} ->
    "checkpoint_#{epoch}_#{iter}"
  end

loop
|> Axon.Loop.checkpoint(path: "my_checkpoints", file_pattern: my_file_pattern)
If you'd like to only save checkpoints based on some metric criteria, you can specify the `:criteria` option. `:criteria` must be a valid key in metrics:
loop
|> Axon.Loop.checkpoint(criteria: "validation_loss")
The default criteria mode is `:min`, meaning the minimum metric score will be considered "best" when deciding whether to save on a given event. Valid modes are `:min` and `:max`:
loop
|> Axon.Loop.checkpoint(criteria: "validation_accuracy", mode: :max)
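
These options compose. For example, a sketch that checkpoints into a custom directory only when validation loss improves, assuming a "validation_loss" metric is tracked in the loop state:

loop
|> Axon.Loop.checkpoint(
  event: :epoch_completed,
  criteria: "validation_loss",
  mode: :min,
  path: "checkpoints"
)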
### deserialize_state

Deserializes loop state from a binary.
It is the opposite of `Axon.Loop.serialize_state/2`.

By default, the step state is deserialized using `Nx.deserialize/2`; however, this behavior can be changed if the step state is an application-specific container. For example, if you introduce your own data structure into the step state and customize the serialization logic, `Nx.deserialize/2` will not be sufficient for deserialization; you must pass custom logic with `:deserialize_step_state`.
### early_stop

Adds a handler function which halts a loop if the given metric does not improve between events.
By default, this will run after each epoch and track the improvement of a given metric.
You must specify a metric to monitor and the metric must be present in the loop state. Typically, this will be a validation metric:
model
|> Axon.Loop.trainer(loss, optim)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, val_data)
|> Axon.Loop.early_stop("validation_accuracy")
It's important to remember that handlers are executed in the order they are added to the loop. For example, if you'd like to checkpoint a loop after every epoch and use early stopping, most likely you want to add the checkpoint handler before the early stopping handler:
model
|> Axon.Loop.trainer(loss, optim)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.checkpoint()
|> Axon.Loop.early_stop("accuracy")
This ensures the checkpoint handler always fires, even if the loop exits early.
### eval_step

Creates a supervised evaluation step from a model and model state.
This function is intended for more fine-grained control over the loop creation process. It returns a tuple of `{init_fn, step_fn}` where `init_fn` returns an initial step state and `step_fn` performs a single evaluation step.
### evaluator

Creates a supervised evaluator from a model and model state.
An evaluator can be used for things such as testing and validation of models after or during training. It assumes `model` is an Axon struct, a container of structs, or a tuple of `init`/`apply` functions. `model_state` must be a container usable from within `model`.
The evaluator returns a step state of the form:
%{
  y_true: labels,
  y_pred: predictions
}
Such that you can attach any number of supervised metrics to the evaluation loop:
model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy, "Accuracy")
The evaluator applies an output transform which returns the map of metrics accumulated over the given loop.
### from_state

Attaches `state` to the given loop in order to resume looping from a previous state.
It's important to note that a loop's attached state takes precedence over defined initialization functions. Given the initialization function:
defn init_state(), do: %{foo: 1, bar: 2}
And an attached state:
state = %State{step_state: %{foo: 2, bar: 3}}
`init_state/0` will never execute; instead, the initial step state of `%{foo: 2, bar: 3}` will be used.
### handle

Adds a handler function to the loop which will be triggered on `event` with an optional filter.
Events take place at different points during loop execution. The default events are:
events = [
:started, # After loop state initialization
:epoch_started, # On epoch start
:iteration_started, # On iteration start
:iteration_completed, # On iteration complete
:epoch_completed, # On epoch complete
:epoch_halted, # On epoch halt, if early halted
:halted, # On loop halt, if early halted
:completed # On loop completion
]
Generally, event handlers are side-effecting operations which provide some sort of inspection into the loop's progress. It's important to note that if you define multiple handlers to be triggered on the same event, they will execute in order from when they were attached to the training loop:
loop
|> Axon.Loop.handle(:epoch_started, &normalize_step_state/1) # executes first
|> Axon.Loop.handle(:epoch_started, &log_step_state/1) # executes second
Thus, if you have separate handlers which alter or depend on loop state, you need to ensure they are ordered correctly, or combined into a single event handler for maximum control over execution.
`event` must be an atom representing the event to trigger `handler`, or a list of atoms indicating `handler` should be triggered on multiple events. `event` may be `:all`, which indicates the handler should be triggered on every event during loop processing.
`handler` must be an arity-1 function which takes the current loop state as input and returns `{status, state}`, where `status` is an atom with one of the following values:
:continue # Continue epoch, continue looping
:halt_epoch # Halt the epoch, continue looping
:halt_loop # Halt looping
`filter` is an atom representing a valid filter predicate, a keyword list of predicate-value pairs, or a function which takes the loop state and returns `true`, indicating the handler should run, or `false`, indicating the handler should not run. Valid predicates are:
:always # Always trigger event
:once # Trigger on first event firing
Valid predicate-value pairs are:
every: N # Trigger every `N` events
only: N  # Trigger only on event `N`
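
As a sketch of a complete handler, the function below halts the loop once a tracked metric falls below a threshold. The "loss" metric key and the threshold are illustrative assumptions:

halt_on_converged = fn %Axon.Loop.State{metrics: metrics} = state ->
  # "loss" is assumed to be a tracked metric key in this loop
  if Nx.to_number(metrics["loss"]) < 1.0e-4 do
    {:halt_loop, state}
  else
    {:continue, state}
  end
end

loop
|> Axon.Loop.handle(:epoch_completed, halt_on_converged)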
### log

Adds a handler function which logs the given message produced by `message_fn` to the given IO device on every `event` satisfying `filter`.
In most cases, this is useful for inspecting the contents of the loop state at intermediate stages. For example, the default `trainer` loop factory attaches IO logging of epoch, batch, loss, and metrics.
It's also possible to log loop state to files by changing the given IO device. By default, the IO device is `:stdio`.
`message_fn` should take the loop state and return a binary representing the message to be written to the IO device.
### loop(step_fn, init_fn \\ &default_init/2, output_transform \\ & &1)

Creates a loop from `step_fn`, an optional `init_fn`, and an optional `output_transform`.
`step_fn` is an arity-2 function which takes a batch and state and returns an updated step state:
defn batch_step(batch, step_state) do
  step_state + 1
end
`init_fn` by default is an identity function which forwards its initial arguments as the model state. You should define a custom initialization function if you require different behavior:
defn init_step_state(state) do
  Map.merge(%{foo: 1}, state)
end
You may use `state` in conjunction with initialization functions in `init_fn`. For example, `train_step/3` uses the initial state as initial model parameters to allow initializing models from partial parameterizations.
`batch_step/2` and `init_step_state/1` are typically called from within `Nx.Defn.jit/3`. While JIT compilation will work with anonymous functions, `def`, and `defn`, it is recommended that you use the stricter `defn` to define both functions in order to avoid bugs or cryptic errors.
`output_transform/1` applies a transformation on the final accumulated loop state. This is useful for extracting specific fields from a loop and piping them into additional functions.
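
Putting the pieces together, a hedged end-to-end sketch of a loop that merely counts batches and extracts the count via the output transform (names are illustrative; the `init_fn` arity follows the `init_step_state/1` example above):

init_fn = fn _state -> %{count: Nx.tensor(0)} end
step_fn = fn _batch, %{count: n} -> %{count: Nx.add(n, 1)} end
to_count = fn %Axon.Loop.State{step_state: %{count: n}} -> n end

batches_seen =
  step_fn
  |> Axon.Loop.loop(init_fn, to_count)
  |> Axon.Loop.run(data, epochs: 1)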
### metric(loop, metric, name \\ nil, accumulate \\ :running_average, transform_or_fields \\ [:y_true, :y_pred])

Adds a metric of the given name to the loop.
A metric is a function which tracks or measures some value with respect to values in the step state. For example, when training classification models, it's common to track the model's accuracy during training:
loop
|> Axon.Loop.metric(:accuracy, "Accuracy")
By default, metrics assume a supervised learning task and extract the fields `[:y_true, :y_pred]` from the step state. If you wish to work on a different value, you can use an output transform. An output transform is a list of keys to extract from the output state, or a function which returns a flattened list of values to pass to the given metric function. Values received from output transforms are passed to the given metric using:
value = output_transform.(step_state)
apply(metric, value)
Thus, even if you want your metric to work on a container, your output transform must return a list.
`metric` must be an atom which matches the name of a metric in `Axon.Metrics`, or an arbitrary function which returns a tensor or container.
`name` must be a string or atom used to store the computed metric in the loop state. If names conflict, the last attached metric will take precedence:
loop
|> Axon.Loop.metric(:mean_squared_error, "Error") # Will be overwritten
|> Axon.Loop.metric(:mean_absolute_error, "Error") # Will be used
By default, metrics keep a running average of the metric calculation. You can override this behavior by changing `accumulate`:
loop
|> Axon.Loop.metric(:true_negatives, "tn", :running_sum)
The accumulation function can be one of the accumulation combinators in `Axon.Metrics` or an arity-3 function of the form `accumulate(acc, obs, i) :: new_acc`.
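
For example, a hedged sketch of a custom accumulator that keeps the maximum observed value of a metric rather than a running average:

keep_max = fn acc, obs, _i -> Nx.max(acc, obs) end

loop
|> Axon.Loop.metric(:accuracy, "best_accuracy", keep_max)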
### reduce_lr_on_plateau

Adds a handler function which reduces the learning rate by the given factor if the given metric does not improve between events.
By default, this will run after each epoch and track the improvement of a given metric.
You must specify a metric to monitor and the metric must be present in the loop state. Typically, this will be a validation metric:
model
|> Axon.Loop.trainer(loss, optim)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, val_data)
|> Axon.Loop.reduce_lr_on_plateau("accuracy", mode: :max)
#### Options

  * `:event` - event to fire the handler on. Defaults to `:epoch_completed`.
  * `:filter` - event filter to attach to the handler. Defaults to `:always`.
  * `:patience` - number of given events to wait for improvement. Defaults to `3`.
  * `:mode` - whether the given metric is being minimized or maximized. Defaults to `:min`.
  * `:factor` - factor to decrease the learning rate by. Defaults to `0.1`.
### run

Runs the given loop on data with the given options.
`loop` must be a valid `Axon.Loop` struct built from one of the loop factories provided in this module. `data` must be an `Enumerable` or `Stream` which yields batches of data on each iteration.
#### Options

  * `:epochs` - max epochs to run the loop for. Must be a non-negative integer. Defaults to `1`.
  * `:iterations` - max iterations to run each epoch. Must be a non-negative integer. Defaults to `-1`, i.e. no max iterations.
  * `:jit_compile?` - whether or not to JIT compile initialization and step functions. JIT compilation must be used for gradient computations. Defaults to `true`.
  * `:debug` - run the loop in debug mode to trace loop progress. Defaults to `false`.
Additional options are forwarded to `Nx.Defn.jit` as JIT options. If no JIT options are set, the default options set with `Nx.Defn.default_options` are used.
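
For example, a sketch that runs five epochs, caps each epoch at 1000 iterations, and forwards a compiler option to `Nx.Defn.jit` (assuming the EXLA compiler is available in your project):

loop
|> Axon.Loop.run(data, epochs: 5, iterations: 1000, compiler: EXLA)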
### serialize_state

Serializes loop state to a binary for saving and resuming a loop from a previous state.
You can consider the serialized state to be a checkpoint of all state at a given iteration and epoch.
By default, the step state is serialized using `Nx.serialize/2`; however, this behavior can be changed if the step state is an application-specific container. For example, if you introduce your own data structure into the step state, `Nx.serialize/2` will not be sufficient for serialization; you must pass custom serialization as an option with `:serialize_step_state`.
Additional `opts` controls serialization options such as compression. It is forwarded to `:erlang.term_to_binary/2`.
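
For example, passing `:erlang.term_to_binary/2` compression through `opts` and writing the result to disk (a sketch; the path is illustrative):

binary = Axon.Loop.serialize_state(loop_state, compressed: 9)
File.write!("loop_state.ckpt", binary)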
### train_step

Creates a supervised train step from a model, loss function, and optimizer.
This function is intended for more fine-grained control over the loop creation process. It returns a tuple of `{init_fn, step_fn}` where `init_fn` is an initialization function which returns an initial step state and `step_fn` is a supervised train step constructed from `model`, `loss`, and `optimizer`.
`model` must be an Axon struct, a valid defn container of Axon structs, or an `{init_fn, apply_fn}` tuple where `init_fn` is an arity-2 function which initializes the model state and `apply_fn` is an arity-2 function which applies the forward pass of the model. The forward pass of the model must return a map with keys `:prediction` and `:state` representing the model's prediction and updated state for layers which aggregate state during training.
`loss` must be an atom which matches a function in `Axon.Losses`, a list of `{loss, weight}` tuples representing a basic weighted loss function for multi-output models, or an arity-2 function representing a custom loss function.
`optimizer` must be an atom matching the name of a valid optimizer in `Axon.Optimizers`, or an `{init_fn, update_fn}` tuple where `init_fn` is an arity-1 function which initializes the optimizer state from attached parameters and `update_fn` is an arity-3 function which scales gradient updates with respect to input parameters, optimizer state, and gradients. See `Axon.Updates` for more information on building optimizers.
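
A hedged sketch of using `train_step/3` to assemble a loop by hand instead of reaching for `trainer/3` (with the identity output transform, `run` returns the full `%Axon.Loop.State{}`):

{init_fn, step_fn} = Axon.Loop.train_step(model, :mean_squared_error, :adam)

final_state =
  step_fn
  |> Axon.Loop.loop(init_fn)
  |> Axon.Loop.run(data, epochs: 10)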
### trainer(model, loss, optimizer, loss_scale \\ :identity, opts \\ [])

Creates a supervised training loop from a model, loss function, and optimizer.
This function is useful for training models on most standard supervised learning tasks. It assumes data consists of tuples of input-target pairs, e.g. `[{x0, y0}, {x1, y1}, ..., {xN, yN}]` where `x0` and `y0` are batched tensors or containers of batched tensors.
It defines an initialization function which first initializes model state using the given model and then initializes optimizer state using the initial model state. The step function uses a differentiable objective function defined with respect to the model parameters, input data, and target data using the given loss function. It then updates model parameters using the given optimizer in order to minimize loss with respect to the model parameters.
`model` must be an Axon struct, a valid defn container of Axon structs, or an `{init_fn, apply_fn}` tuple where `init_fn` is an arity-2 function which initializes the model state and `apply_fn` is an arity-2 function which applies the forward pass of the model.
`loss` must be an atom which matches a function in `Axon.Losses`, a list of `{loss, weight}` tuples representing a basic weighted loss function for multi-output models, or an arity-2 function representing a custom loss function.
`optimizer` must be an atom matching the name of a valid optimizer in `Axon.Optimizers`, or an `{init_fn, update_fn}` tuple where `init_fn` is an arity-1 function which initializes the optimizer state from attached parameters and `update_fn` is an arity-3 function which scales gradient updates with respect to input parameters, optimizer state, and gradients. See `Axon.Updates` for more information on building optimizers.
This function creates a step function which outputs a map consisting of the following fields for `step_state`:
%{
  y_pred: tensor() | container(tensor()), # Model predictions for use in metrics
  y_true: tensor() | container(tensor()), # True labels for use in metrics
  loss: tensor(), # Running average of loss over epoch
  model_state: container(tensor()), # Model parameters and state
  optimizer_state: container(tensor()) # Optimizer state associated with each parameter
}
#### Examples

##### Basic usage
data = Stream.zip(input, target)
model = Axon.input("input", shape: {nil, 32}) |> Axon.dense(1, activation: :sigmoid)
model
|> Axon.Loop.trainer(:binary_cross_entropy, :adam)
|> Axon.Loop.run(data)
##### Customizing Optimizer
model
|> Axon.Loop.trainer(:binary_cross_entropy, Axon.Optimizers.adam(0.05))
|> Axon.Loop.run(data)
##### Custom loss
loss_fn = fn y_true, y_pred -> Nx.mean(Nx.abs(Nx.subtract(y_true, y_pred))) end
model
|> Axon.Loop.trainer(loss_fn, Axon.Optimizers.rmsprop(0.01))
|> Axon.Loop.run(data)
##### Multiple objectives with multi-output model
model = {Axon.input("input_0", shape: {nil, 1}), Axon.input("input_1", shape: {nil, 2})}
loss_weights = [mean_squared_error: 0.5, mean_absolute_error: 0.5]
model
|> Axon.Loop.trainer(loss_weights, :sgd)
|> Axon.Loop.run(data)
#### Options

  * `:log` - training loss and metric log interval. Set to `0` to silence training logs. Defaults to `50`.
### validate(loop, model, validation_data, event \\ :epoch_completed, filter \\ :always)

Adds a handler function which tests the performance of `model` against the given validation set.
This handler assumes the loop state matches the state initialized in a supervised training loop. Typically, you'd call this immediately after creating a supervised training loop:
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.validate(model, validation_data)
Please note that you must pass the same (or an equivalent) model into this function so it can be used during the validation loop. The metrics which are computed are those which are present BEFORE the validation handler was added to the loop. For the following loop:
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(:mean_absolute_error)
|> Axon.Loop.validate(model, validation_data)
|> Axon.Loop.metric(:binary_cross_entropy)
only `:mean_absolute_error` will be computed at validation time.
The returned loop state is altered to contain validation metrics for use in later handlers such as early stopping and model checkpoints. Since the order of execution of event handlers is in the same order they are declared in the training loop, you MUST call this method before any other handler which expects or may use validation metrics.
By default the validation loop runs after every epoch; however, you can customize it by overriding the default event and event filters:
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(:mean_absolute_error)
|> Axon.Loop.validate(model, validation_data, :iteration_completed, every: 10_000)
|> Axon.Loop.metric(:binary_cross_entropy)