Axon.Updates (Axon v0.3.1)

Parameter update methods.

Update methods transform the input tensor, usually by scaling or shifting it with respect to some input state. Update methods are composed to create more advanced optimization methods such as AdaGrad or Adam. Each update method returns a tuple:

{init_fn, update_fn}

These represent a state initialization function and a state update function, respectively. While each method in the Updates API is a regular Elixir function, the two functions it returns are implemented with defn, so they can be accelerated using any Nx backend or compiler.

Update methods are just combinators that can be arbitrarily composed to create complex optimizers. For example, the Adam optimizer in Axon.Optimizers is implemented as:

def adam(learning_rate, opts \\ []) do
  Updates.scale_by_adam(opts)
  |> Updates.scale(-learning_rate)
end
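The combinator pattern can be modeled without Nx at all. The following plain-Elixir sketch is not Axon's implementation: it represents updates as maps of floats, omits the params argument that Axon passes to update functions, and the module `CombinatorSketch` and its functions are hypothetical. Each combinator wraps the `{init_fn, update_fn}` pair it receives, runs that pair first, and then applies its own transformation, which is why calls chain naturally with `|>`.

```elixir
defmodule CombinatorSketch do
  # Identity: no state, updates pass through unchanged.
  def identity do
    {fn _params -> {} end, fn updates, state -> {updates, state} end}
  end

  # Stateless combinator: reuses the parent's state and multiplies every
  # update by a fixed step size after the parent has run.
  def scale({parent_init, parent_update}, step_size) do
    update_fn = fn updates, state ->
      {updates, state} = parent_update.(updates, state)
      {Map.new(updates, fn {k, u} -> {k, u * step_size} end), state}
    end

    {parent_init, update_fn}
  end

  # Stateful combinator: keeps one running trace per parameter, stored next
  # to the parent's state as {own_state, parent_state}.
  def trace({parent_init, parent_update}, decay) do
    init_fn = fn params ->
      {Map.new(params, fn {k, _v} -> {k, 0.0} end), parent_init.(params)}
    end

    update_fn = fn updates, {trace, parent_state} ->
      {updates, parent_state} = parent_update.(updates, parent_state)
      new_trace = Map.new(updates, fn {k, u} -> {k, decay * Map.fetch!(trace, k) + u} end)
      {new_trace, {new_trace, parent_state}}
    end

    {init_fn, update_fn}
  end
end

# Same shape as `scale_by_adam(opts) |> scale(-learning_rate)`.
{init_fn, update_fn} =
  CombinatorSketch.identity()
  |> CombinatorSketch.trace(0.9)
  |> CombinatorSketch.scale(-0.1)

state = init_fn.(%{w: 1.0})
{updates, _state} = update_fn.(%{w: 2.0}, state)
# updates is %{w: -0.2}
```

Keeping each combinator's state in its own slot of a nested tuple is what lets combinators be stacked in any order without knowing about each other.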

Updates are maps of updates, usually keyed by the names of the parameters they apply to. Axon.Updates.apply_updates/3 merges updates and parameters by adding each update to the parameter with the same name, while preserving any given model state.
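As a rough model of that merge, the sketch below adds each update to the parameter with the same key, recursing into nested maps. It is a hypothetical plain-Elixir illustration (`ApplyUpdatesSketch` is not part of Axon) that uses floats instead of tensors and ignores the optional state argument.

```elixir
defmodule ApplyUpdatesSketch do
  # Adds each update to the parameter stored under the same key.
  # Nested maps (e.g. per-layer parameters) are handled recursively.
  def apply_updates(params, updates) do
    Map.new(params, fn {key, value} ->
      update = Map.fetch!(updates, key)

      if is_map(value),
        do: {key, apply_updates(value, update)},
        else: {key, value + update}
    end)
  end
end

params = %{dense: %{kernel: 1.0, bias: 0.5}}
updates = %{dense: %{kernel: -0.25, bias: 0.25}}
ApplyUpdatesSketch.apply_updates(params, updates)
# returns %{dense: %{bias: 0.75, kernel: 0.75}}
```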

Custom combinators

You can create your own combinators using the stateless/2 and stateful/3 primitives. Every update method in this module is implemented in terms of one of these two primitives.

stateless/2 represents a stateless update:

def scale(combinator \\ Axon.Updates.identity(), step_size) do
  stateless(combinator, &apply_scale(&1, &2, step_size))
end

defnp apply_scale(x, _params, step) do
  transform(
    {x, step},
    fn {updates, step} ->
      deep_new(updates, fn x -> Nx.multiply(x, step) end)
    end
  )
end

Notice how the function given to stateless/2 is defined within defn. This is what allows the anonymous functions returned by Axon.Updates to be used inside defn.

stateful/3 represents a stateful update and follows the same pattern:

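def my_stateful_update(combinator \\ Axon.Updates.identity()) do
  Axon.Updates.stateful(combinator, &init_my_update/1, &apply_my_update/3)
end

# zeros_like/1, deep_new/2 and deep_merge/3 are helpers from Axon.Shared,
# assumed to be imported alongside Nx.Defn.
defnp init_my_update(params) do
  # One zero tensor per parameter, nested under :state.
  state = zeros_like(params)
  %{state: state}
end

# Stateful apply functions receive the updates, this update's own state,
# and the model parameters.
defnp apply_my_update(updates, %{state: state}, _params) do
  new_state = transform(state, fn state -> deep_new(state, fn v -> Nx.add(v, 0.01) end) end)

  updates =
    transform({updates, new_state}, fn {updates, state} ->
      deep_merge(updates, state, fn g, z -> Nx.multiply(g, z) end)
    end)

  {updates, %{state: new_state}}
end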

State associated with individual parameters should have keys that match the keys of the parameters. For example, if you have parameters %{kernel: kernel} with associated states mu and nu representing the first and second moments, your state should look something like:

%{
  mu: %{kernel: kernel_mu},
  nu: %{kernel: kernel_nu}
}

Summary

Functions

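add_decayed_weights(combinator_or_opts \\ [])
Adds decayed weights to updates.

add_noise(combinator_or_opts \\ [])
Adds random Gaussian noise to the input.

apply_updates(params, updates, state \\ nil)
Applies updates to params and updates state parameters with the given state map.

centralize(combinator_or_opts \\ [])
Centralizes input by shifting updates by their mean.

clip(combinator_or_opts \\ [])
Clips input between -delta and delta.

clip_by_global_norm(combinator_or_opts \\ [])
Clips input using its global norm.

compose(arg1, arg2)
Composes two updates.

identity()
Returns the identity update.

scale(combinator \\ identity(), step_size)
Scales input by a fixed step size.

scale_by_adam(combinator_or_opts \\ [])
Scales input according to the Adam algorithm.

scale_by_belief(combinator_or_opts \\ [])
Scales input according to the AdaBelief algorithm.

scale_by_radam(combinator_or_opts \\ [])
Scales input according to the Rectified Adam algorithm.

scale_by_rms(combinator_or_opts \\ [])
Scales input by the root of the EMA of squared inputs.

scale_by_rss(combinator_or_opts \\ [])
Scales input by the root of the sum of all prior squared inputs.

scale_by_schedule(combinator \\ identity(), schedule_fn)
Scales input using the given schedule function.

scale_by_state(combinator_or_step)
Scales input by a tunable learning rate which can be manipulated by external APIs such as Axon's Loop API.

scale_by_stddev(combinator_or_opts \\ [])
Scales input by the root of the centered EMA of squared inputs.

scale_by_trust_ratio(combinator_or_opts \\ [])
Scales input by a trust ratio.

scale_by_yogi(combinator_or_opts \\ [])
Scales input according to the Yogi algorithm.

stateful(arg \\ identity(), init_fn, apply_fn)
Represents a stateful update.

stateless(arg \\ identity(), apply_fn)
Represents a stateless update.

trace(combinator_or_opts \\ [])
Traces inputs with past inputs.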

Functions

add_decayed_weights(combinator_or_opts \\ [])

Adds decayed weights to updates.

Commonly used as a regularization strategy.

Options

* `:decay` - Rate of decay. Defaults to `0.0`.
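A minimal sketch of the decay step, assuming the usual decoupled weight-decay formulation in which each update is shifted by `decay` times the parameter it belongs to. `DecayedWeightsSketch` is hypothetical and works on maps of floats rather than Nx tensors.

```elixir
defmodule DecayedWeightsSketch do
  # Shifts each update by decay * (the parameter with the same key).
  def add_decayed_weights(updates, params, decay \\ 0.0) do
    Map.new(updates, fn {key, u} -> {key, u + decay * Map.fetch!(params, key)} end)
  end
end
```

With the default `0.0` the updates pass through unchanged. When this step sits before a `scale(-learning_rate)` step, a positive `decay` pulls parameters toward zero.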
add_decayed_weights(combinator, opts)

add_noise(combinator_or_opts \\ [])

Adds random Gaussian noise to the input.

Options

* `:eta` - Controls the amount of noise to add.
  Defaults to `0.01`.

* `:gamma` - Controls how quickly the noise decays over
  update steps. Defaults to `0.55`.
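The sketch below assumes the annealed schedule commonly used for gradient noise, in which the noise variance at update step t is eta / (1 + t)^gamma. It is a hypothetical plain-Elixir illustration (`NoiseSketch` is not part of Axon), and the exact form Axon uses internally may differ. Noise is drawn with Erlang's `:rand.normal/0`.

```elixir
defmodule NoiseSketch do
  # Variance of the noise at a given update count; it shrinks as count grows.
  def noise_variance(count, eta \\ 0.01, gamma \\ 0.55) do
    eta / :math.pow(1 + count, gamma)
  end

  # Adds zero-mean Gaussian noise with the scheduled variance to each update.
  def add_noise(updates, count, eta \\ 0.01, gamma \\ 0.55) do
    std = :math.sqrt(noise_variance(count, eta, gamma))
    Map.new(updates, fn {key, u} -> {key, u + std * :rand.normal()} end)
  end
end
```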
add_noise(combinator, opts)

apply_updates(params, updates, state \\ nil)

Applies updates to params and updates state parameters with the given state map.

centralize(combinator_or_opts \\ [])

Centralizes input by shifting updates by their mean.
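Gradient centralization subtracts a mean from each multi-dimensional update. The sketch below assumes the mean is taken over every axis except the first and that 1-D inputs such as biases are left unchanged; treat both choices as assumptions about Axon's behavior. `CentralizeSketch` is hypothetical and represents a 2-D update as a list of rows.

```elixir
defmodule CentralizeSketch do
  # 2-D input (list of rows): shift each row by its own mean so it sums to zero.
  def centralize([row | _] = rows) when is_list(row) do
    Enum.map(rows, fn row ->
      mean = Enum.sum(row) / length(row)
      Enum.map(row, &(&1 - mean))
    end)
  end

  # 1-D input (or empty list): returned unchanged.
  def centralize(vector) when is_list(vector), do: vector
end
```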

centralize(combinator, opts)

clip(combinator_or_opts \\ [])

Clips input between -delta and delta.

Options

* `:delta` - maximum absolute value of the input. Defaults to `2.0`

clip_by_global_norm(combinator_or_opts \\ [])

Clips input using its global norm.

Options

* `:max_norm` - maximum norm value of input. Defaults to `1.0`
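Unlike elementwise clipping, global-norm clipping looks at all updates together: it computes one L2 norm over every value and, if that norm exceeds `:max_norm`, rescales all updates by the same factor, preserving their direction. This is a hypothetical plain-Elixir illustration (`GlobalNormSketch` is not part of Axon) with each update stored as a list of floats.

```elixir
defmodule GlobalNormSketch do
  # The global norm is the L2 norm of all update values taken together.
  # If it exceeds max_norm, every value is multiplied by max_norm / global_norm.
  def clip_by_global_norm(updates, max_norm \\ 1.0) do
    global_norm =
      updates
      |> Map.values()
      |> List.flatten()
      |> Enum.reduce(0.0, fn x, acc -> acc + x * x end)
      |> :math.sqrt()

    if global_norm > max_norm do
      Map.new(updates, fn {key, values} ->
        {key, Enum.map(values, &(&1 * max_norm / global_norm))}
      end)
    else
      updates
    end
  end
end
```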
clip_by_global_norm(combinator, opts)

compose(arg1, arg2)

Composes two updates. This is useful for extending optimizers without having to reimplement them. For example, you can implement gradient centralization:

Axon.Updates.compose(Axon.Updates.centralize(), Axon.Optimizers.rmsprop(1.0e-3))

With RMSProp's default options, this is roughly equivalent to:

Axon.Updates.centralize()
|> Axon.Updates.scale_by_rms()
|> Axon.Updates.scale(-1.0e-3)

identity()

Returns the identity update.

This is used as the default initial update by many functions in this module.

scale(combinator \\ identity(), step_size)

Scales input by a fixed step size.

$$f(x_i) = \alpha x_i$$

where $\alpha$ is step_size.

scale_by_adam(combinator_or_opts \\ [])

Scales input according to the Adam algorithm.

Options

* `:b1` - first moment decay. Defaults to `0.9`

* `:b2` - second moment decay. Defaults to `0.999`

* `:eps` - numerical stability term. Defaults to `1.0e-8`

* `:eps_root` - numerical stability term. Defaults to `1.0e-15`
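One Adam step on a single scalar can be sketched as follows. This is a hypothetical plain-Elixir illustration (`AdamSketch` is not part of Axon); placing `:eps` outside the square root and `:eps_root` inside it is an assumption based on a common formulation. The result is only the scaled direction; an optimizer such as `adam/2` above then multiplies it by `-learning_rate`.

```elixir
defmodule AdamSketch do
  # `state` holds the first moment mu, the second moment nu, and the step
  # count (starting at 0). Returns {update, new_state}.
  def step(g, %{mu: mu, nu: nu, count: count}, opts \\ []) do
    b1 = Keyword.get(opts, :b1, 0.9)
    b2 = Keyword.get(opts, :b2, 0.999)
    eps = Keyword.get(opts, :eps, 1.0e-8)
    eps_root = Keyword.get(opts, :eps_root, 1.0e-15)

    mu = b1 * mu + (1 - b1) * g
    nu = b2 * nu + (1 - b2) * g * g
    t = count + 1

    # Bias correction compensates for initializing both moments at zero.
    mu_hat = mu / (1 - :math.pow(b1, t))
    nu_hat = nu / (1 - :math.pow(b2, t))

    update = mu_hat / (:math.sqrt(nu_hat + eps_root) + eps)
    {update, %{mu: mu, nu: nu, count: t}}
  end
end
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so a gradient of `0.5` yields an update of roughly `1.0`.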

scale_by_adam(combinator, opts)

scale_by_belief(combinator_or_opts \\ [])

Scales input according to the AdaBelief algorithm.

Options

* `:b1` - first moment decay. Defaults to `0.9`.

* `:b2` - second moment decay. Defaults to `0.999`.

* `:eps` - numerical stability term. Defaults to `0.0`.

* `:eps_root` - numerical stability term. Defaults to `1.0e-16`.
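AdaBelief differs from Adam in its second statistic: instead of an EMA of squared gradients, it tracks an EMA of the squared deviation of each gradient from its running mean. The sketch below follows the AdaBelief paper in using the freshly updated mean for that deviation; some implementations use the previous mean, and Axon's exact choice and epsilon placement are assumptions here. `BeliefSketch` is hypothetical.

```elixir
defmodule BeliefSketch do
  # Single-scalar AdaBelief step. s is the EMA of (g - mu)^2, the squared
  # deviation of the gradient from its running mean.
  def step(g, %{mu: mu, s: s, count: count}, opts \\ []) do
    b1 = Keyword.get(opts, :b1, 0.9)
    b2 = Keyword.get(opts, :b2, 0.999)
    eps = Keyword.get(opts, :eps, 0.0)
    eps_root = Keyword.get(opts, :eps_root, 1.0e-16)

    mu = b1 * mu + (1 - b1) * g
    s = b2 * s + (1 - b2) * (g - mu) * (g - mu)
    t = count + 1

    mu_hat = mu / (1 - :math.pow(b1, t))
    s_hat = s / (1 - :math.pow(b2, t))

    update = mu_hat / (:math.sqrt(s_hat + eps_root) + eps)
    {update, %{mu: mu, s: s, count: t}}
  end
end
```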

scale_by_belief(combinator, opts)

scale_by_radam(combinator_or_opts \\ [])

Scales input according to the Rectified Adam algorithm.

Options

* `:b1` - first moment decay. Defaults to `0.9`

* `:b2` - second moment decay. Defaults to `0.999`

* `:eps` - numerical stability term. Defaults to `1.0e-8`

* `:eps_root` - numerical stability term. Defaults to `0.0`

* `:threshold` - threshold for variance. Defaults to `5.0`
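Rectified Adam computes rho_t, the length of the approximated simple moving average, and only applies the variance-rectified adaptive step once rho_t exceeds `:threshold`; before that it falls back to the bias-corrected momentum. The sketch is a hypothetical single-scalar illustration (`RAdamSketch` is not part of Axon) that assumes this rectification rule.

```elixir
defmodule RAdamSketch do
  # Single-scalar Rectified Adam step. Returns {update, new_state}.
  def step(g, %{mu: mu, nu: nu, count: count}, opts \\ []) do
    b1 = Keyword.get(opts, :b1, 0.9)
    b2 = Keyword.get(opts, :b2, 0.999)
    eps = Keyword.get(opts, :eps, 1.0e-8)
    eps_root = Keyword.get(opts, :eps_root, 0.0)
    threshold = Keyword.get(opts, :threshold, 5.0)

    mu = b1 * mu + (1 - b1) * g
    nu = b2 * nu + (1 - b2) * g * g
    t = count + 1
    b2_t = :math.pow(b2, t)

    # rho_inf is the maximum SMA length; rho_t is its value at step t.
    rho_inf = 2 / (1 - b2) - 1
    rho_t = rho_inf - 2 * t * b2_t / (1 - b2_t)

    mu_hat = mu / (1 - :math.pow(b1, t))

    update =
      if rho_t > threshold do
        nu_hat = nu / (1 - b2_t)
        numerator = (rho_t - 4) * (rho_t - 2) * rho_inf
        denominator = (rho_inf - 4) * (rho_inf - 2) * rho_t
        r = :math.sqrt(numerator / denominator)
        r * mu_hat / (:math.sqrt(nu_hat + eps_root) + eps)
      else
        # Early steps: the variance estimate is unreliable, use momentum only.
        mu_hat
      end

    {update, %{mu: mu, nu: nu, count: t}}
  end
end
```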

scale_by_radam(combinator, opts)

scale_by_rms(combinator_or_opts \\ [])

Scales input by the root of the EMA of squared inputs.

Options

* `:decay` - EMA decay rate. Defaults to `0.9`.

* `:eps` - numerical stability term. Defaults to `1.0e-8`.
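A single-scalar sketch of RMS scaling, assuming `:eps` is added inside the square root (implementations differ on this detail). `RMSSketch` is hypothetical.

```elixir
defmodule RMSSketch do
  # nu is the EMA of squared inputs; the input is divided by its root.
  # Returns {update, new_nu}.
  def step(g, nu, opts \\ []) do
    decay = Keyword.get(opts, :decay, 0.9)
    eps = Keyword.get(opts, :eps, 1.0e-8)

    nu = decay * nu + (1 - decay) * g * g
    {g / :math.sqrt(nu + eps), nu}
  end
end
```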

scale_by_rms(combinator, opts)

scale_by_rss(combinator_or_opts \\ [])

Scales input by the root of the sum of all prior squared inputs.

Options

* `:eps` - numerical stability term. Defaults to `1.0e-7`
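This is an AdaGrad-style accumulator: the sum of squares never decays, so repeated large inputs shrink the effective step over time. The sketch assumes the accumulator starts at zero and that `:eps` is added inside the square root; `RSSSketch` is hypothetical.

```elixir
defmodule RSSSketch do
  # sum_sq is the running sum of every squared input seen so far.
  # Returns {update, new_sum_sq}.
  def step(g, sum_sq, eps \\ 1.0e-7) do
    sum_sq = sum_sq + g * g
    {g / :math.sqrt(sum_sq + eps), sum_sq}
  end
end
```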
scale_by_rss(combinator, opts)

scale_by_schedule(combinator \\ identity(), schedule_fn)

Scales input using the given schedule function.

This can be useful for implementing learning rate schedules. The number of update iterations is tracked by an internal counter that advances once per update (that is, once per batch), so you may need to express a per-epoch schedule in terms of batches.
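The idea can be sketched as multiplying each input by `schedule_fn.(count)` and then incrementing the counter. `ScheduleSketch` and the decay schedule below are hypothetical and work on plain floats.

```elixir
defmodule ScheduleSketch do
  # Multiplies the input by schedule_fn.(count); returns {update, count + 1}.
  def step(g, count, schedule_fn) do
    {g * schedule_fn.(count), count + 1}
  end
end

# Hypothetical exponential decay that halves the scale every 1_000 updates.
halve_every_1000 = fn count -> 0.1 * :math.pow(0.5, count / 1000) end
{_update, _count} = ScheduleSketch.step(1.0, 0, halve_every_1000)
```

To reuse an epoch-based schedule, multiply its epoch boundaries by the number of batches per epoch.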

scale_by_state(combinator_or_step)

Scales input by a tunable learning rate which can be manipulated by external APIs such as Axon's Loop API.

$$f(x_i) = \alpha x_i$$

scale_by_state(combinator, step)

scale_by_stddev(combinator_or_opts \\ [])

Scales input by the root of the centered EMA of squared inputs.

Options

* `:decay` - EMA decay rate. Defaults to `0.9`.

* `:eps` - numerical stability term. Defaults to `1.0e-8`.
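Centered scaling divides by an estimate of the standard deviation rather than the raw RMS: it tracks EMAs of both the inputs and their squares and uses nu - mu^2 as the variance estimate. The single-scalar sketch below assumes `:eps` is added inside the square root; `StddevSketch` is hypothetical.

```elixir
defmodule StddevSketch do
  # mu: EMA of inputs, nu: EMA of squared inputs. Returns {update, new_state}.
  def step(g, %{mu: mu, nu: nu}, opts \\ []) do
    decay = Keyword.get(opts, :decay, 0.9)
    eps = Keyword.get(opts, :eps, 1.0e-8)

    mu = decay * mu + (1 - decay) * g
    nu = decay * nu + (1 - decay) * g * g
    {g / :math.sqrt(nu - mu * mu + eps), %{mu: mu, nu: nu}}
  end
end
```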

scale_by_stddev(combinator, opts)

scale_by_trust_ratio(combinator_or_opts \\ [])

Scales input by a trust ratio.

Options

* `:min_norm` - Min norm to clip. Defaults to
  `0.0`.

* `:trust_coefficient` - Trust coefficient. Defaults
  to `1.0`.

* `:eps` - Numerical stability term. Defaults to `0.0`.
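Trust-ratio scaling, as used by LARS- and LAMB-style optimizers, rescales each parameter's update by the ratio of the parameter norm to the update norm, so every layer takes a step proportional to its own scale. The sketch assumes `:min_norm` acts as a lower bound on both norms and that a ratio of 1 is used when either norm is zero; `TrustRatioSketch` is hypothetical and handles one parameter stored as a list of floats.

```elixir
defmodule TrustRatioSketch do
  # Scales one parameter's update (list of floats) by the trust ratio.
  def scale(update, param, opts \\ []) do
    min_norm = Keyword.get(opts, :min_norm, 0.0)
    coeff = Keyword.get(opts, :trust_coefficient, 1.0)
    eps = Keyword.get(opts, :eps, 0.0)

    param_norm = max(norm(param), min_norm)
    update_norm = max(norm(update), min_norm)

    # Fall back to a ratio of 1 when either norm is zero, to avoid 0/0.
    ratio =
      if param_norm == 0.0 or update_norm == 0.0,
        do: 1.0,
        else: coeff * param_norm / (update_norm + eps)

    Enum.map(update, &(&1 * ratio))
  end

  # L2 norm of a list of floats.
  defp norm(values), do: :math.sqrt(Enum.reduce(values, 0.0, fn x, acc -> acc + x * x end))
end
```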
scale_by_trust_ratio(combinator, opts)

scale_by_yogi(combinator_or_opts \\ [])

Scales input according to the Yogi algorithm.

Options

* `:initial_accumulator_value` - Initial state accumulator value.

* `:b1` - first moment decay. Defaults to `0.9`

* `:b2` - second moment decay. Defaults to `0.999`

* `:eps` - numerical stability term. Defaults to `1.0e-8`

* `:eps_root` - numerical stability term. Defaults to `0.0`
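Yogi keeps Adam's first moment but changes how the second moment moves: each step changes nu by exactly (1 - b2) * g^2 in the direction of g^2, regardless of how far nu is from g^2, which controls how quickly the effective learning rate can change. The single-scalar sketch below takes the initial state as an argument (in Axon it presumably comes from `:initial_accumulator_value`) and includes Adam-style bias correction as an assumption; `YogiSketch` is hypothetical.

```elixir
defmodule YogiSketch do
  # Single-scalar Yogi step. Returns {update, new_state}.
  def step(g, %{mu: mu, nu: nu, count: count}, opts \\ []) do
    b1 = Keyword.get(opts, :b1, 0.9)
    b2 = Keyword.get(opts, :b2, 0.999)
    eps = Keyword.get(opts, :eps, 1.0e-8)
    eps_root = Keyword.get(opts, :eps_root, 0.0)

    g2 = g * g
    mu = b1 * mu + (1 - b1) * g
    # Additive, sign-controlled update of the second moment.
    nu = nu - (1 - b2) * sign(nu - g2) * g2
    t = count + 1

    mu_hat = mu / (1 - :math.pow(b1, t))
    nu_hat = nu / (1 - :math.pow(b2, t))

    update = mu_hat / (:math.sqrt(nu_hat + eps_root) + eps)
    {update, %{mu: mu, nu: nu, count: t}}
  end

  defp sign(x) when x > 0, do: 1.0
  defp sign(x) when x < 0, do: -1.0
  defp sign(_x), do: 0.0
end
```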

References

* [Adaptive Methods for Nonconvex Optimization](https://proceedings.neurips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf)

scale_by_yogi(combinator, opts)

stateful(arg \\ identity(), init_fn, apply_fn)

Represents a stateful update.

Stateful updates require some update state, such as the momentum or RMS of previous updates. Therefore you must implement an initialization function as well as an update function.

stateless(arg \\ identity(), apply_fn)

Represents a stateless update.

Stateless updates do not depend on an update state and thus only require an implementation of an update function.

trace(combinator_or_opts \\ [])

Traces inputs with past inputs.

Options

* `:decay` - decay rate for tracing past updates. Defaults to `0.9`

* `:nesterov` - whether to use Nesterov momentum. Defaults to `false`
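Tracing keeps a decayed running sum of past inputs, which is classical momentum; with `:nesterov` the returned update also adds a look-ahead term along the new trace. The single-scalar sketch below assumes those standard formulas; `TraceSketch` is hypothetical.

```elixir
defmodule TraceSketch do
  # Returns {update, new_trace}.
  def step(g, trace, opts \\ []) do
    decay = Keyword.get(opts, :decay, 0.9)
    nesterov = Keyword.get(opts, :nesterov, false)

    new_trace = decay * trace + g
    update = if nesterov, do: g + decay * new_trace, else: new_trace
    {update, new_trace}
  end
end
```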