Axon.Optimizers (Axon v0.5.1)

Implementations of common gradient-based optimization algorithms.

All of the functions in this module are implemented in terms of the update functions defined in Axon.Updates. Axon treats optimizers as the tuple:

{init_fn, update_fn}

where init_fn returns the initial optimizer state and update_fn scales input gradients. init_fn accepts a model's parameters and attaches state to each parameter. update_fn accepts gradients, the optimizer state, and the current model parameters, and returns the scaled updates together with the new optimizer state.

Custom optimizers are often created via the Axon.Updates API.
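
For example, a minimal sketch of such a composition (the specific transformations and values here are illustrative, not a prescribed recipe):

  {init_fn, update_fn} =
    Axon.Updates.scale_by_adam()
    |> Axon.Updates.add_decayed_weights(decay: 0.01)
    |> Axon.Updates.scale(-0.005)

The resulting tuple can be used anywhere this module's optimizers are accepted, such as with Axon.Loop.trainer/3.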

Example

Consider the following usage of the Adam optimizer in a basic update function (assuming objective and the dataset are defined elsewhere):

defmodule Learning do

  import Nx.Defn

  defn init(params, init_fn) do
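    # Build the initial optimizer state from the model parameters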
    init_fn.(params)
  end

  defn update(params, optimizer_state, inputs, targets, update_fn) do
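    # Differentiate the objective, scale the gradients with the optimizer, and apply the updates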
    {loss, gradient} = value_and_grad(params, &objective(&1, inputs, targets))
    {scaled_updates, new_optimizer_state} = update_fn.(gradient, optimizer_state, params)
    {Axon.Updates.apply_updates(params, scaled_updates), new_optimizer_state, loss}
  end
end

key = Nx.Random.key(42)
{model_params, _key} = Nx.Random.uniform(key, shape: {784, 10})
{init_fn, update_fn} = Axon.Optimizers.adam(0.005)

optimizer_state =
  Learning.init(model_params, init_fn)

{new_params, new_optimizer_state, loss} =
  Learning.update(model_params, optimizer_state, inputs, targets, update_fn)

For a simpler approach, you can also use optimizers with the training API:

  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.005))
  |> Axon.Loop.run(data, epochs: 10, compiler: EXLA)

Functions

adabelief(learning_rate \\ 0.001, opts \\ [])

Adabelief optimizer.

Options

  • :b1 - first moment decay. Defaults to 0.9
  • :b2 - second moment decay. Defaults to 0.999
  • :eps - numerical stability term. Defaults to 0.0
  • :eps_root - numerical stability term. Defaults to 1.0e-16

adagrad(learning_rate \\ 0.001, opts \\ [])

Adagrad optimizer.

Options

  • :eps - numerical stability term. Defaults to 1.0e-7

adam(learning_rate \\ 0.001, opts \\ [])

Adam optimizer.

Options

  • :b1 - first moment decay. Defaults to 0.9
  • :b2 - second moment decay. Defaults to 0.999
  • :eps - numerical stability term. Defaults to 1.0e-8
  • :eps_root - numerical stability term. Defaults to 1.0e-15
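
A minimal usage sketch, overriding the default :eps (the value chosen here is illustrative):

  {init_fn, update_fn} = Axon.Optimizers.adam(0.001, eps: 1.0e-7)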

adamw(learning_rate \\ 0.001, opts \\ [])

Adam with weight decay optimizer.

Options

  • :b1 - first moment decay. Defaults to 0.9
  • :b2 - second moment decay. Defaults to 0.999
  • :eps - numerical stability term. Defaults to 1.0e-8
  • :eps_root - numerical stability term. Defaults to 0.0
  • :decay - weight decay. Defaults to 0.0
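
A minimal usage sketch with a non-zero weight decay (the value is illustrative):

  {init_fn, update_fn} = Axon.Optimizers.adamw(0.001, decay: 0.01)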

lamb(learning_rate \\ 0.01, opts \\ [])

Lamb optimizer.

Options

  • :b1 - first moment decay. Defaults to 0.9
  • :b2 - second moment decay. Defaults to 0.999
  • :eps - numerical stability term. Defaults to 1.0e-8
  • :eps_root - numerical stability term. Defaults to 0.0
  • :decay - weight decay. Defaults to 0.0
  • :min_norm - minimum norm value. Defaults to 0.0

noisy_sgd(learning_rate \\ 0.01, opts \\ [])

Noisy SGD optimizer.

Options

  • :eta - used to compute variance of noise distribution. Defaults to 0.1
  • :gamma - used to compute variance of noise distribution. Defaults to 0.55

radam(learning_rate \\ 0.001, opts \\ [])

Rectified Adam optimizer.

Options

  • :b1 - first moment decay. Defaults to 0.9
  • :b2 - second moment decay. Defaults to 0.999
  • :eps - numerical stability term. Defaults to 1.0e-8
  • :eps_root - numerical stability term. Defaults to 0.0
  • :threshold - threshold term. Defaults to 5.0

rmsprop(learning_rate \\ 0.01, opts \\ [])

RMSProp optimizer.

Options

  • :centered - whether to scale by centered root of EMA of squares. Defaults to false
  • :momentum - momentum term. If set, uses momentum with the decay set to this value.
  • :nesterov - whether or not to use nesterov momentum. Defaults to false
  • :initial_scale - initial value of EMA. Defaults to 0.0
  • :decay - EMA decay rate. Defaults to 0.9
  • :eps - numerical stability term. Defaults to 1.0e-8
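
A minimal usage sketch of centered RMSProp with momentum (the values are illustrative):

  {init_fn, update_fn} = Axon.Optimizers.rmsprop(0.01, centered: true, momentum: 0.9)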

sgd(learning_rate \\ 0.01, opts \\ [])

SGD optimizer.

Options

  • :momentum - momentum term. If set, uses SGD with momentum, where the momentum decay is set to this value.
  • :nesterov - whether or not to use nesterov momentum. Defaults to false
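
A minimal usage sketch of SGD with Nesterov momentum (the values are illustrative):

  {init_fn, update_fn} = Axon.Optimizers.sgd(0.01, momentum: 0.9, nesterov: true)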

yogi(learning_rate \\ 0.01, opts \\ [])

Yogi optimizer.

Options

  • :initial_accumulator_value - initial value for first and second moment. Defaults to 0.0
  • :b1 - first moment decay. Defaults to 0.9
  • :b2 - second moment decay. Defaults to 0.999
  • :eps - numerical stability term. Defaults to 1.0e-8
  • :eps_root - numerical stability term. Defaults to 0.0
