Axon.MixedPrecision (Axon v0.3.1)

Utilities for creating mixed precision policies.

Mixed precision can increase model throughput, at the possible cost of a small drop in accuracy. When creating a mixed precision policy, you define the precision used for params, compute, and output.

The params policy dictates the type in which parameters are stored during training. The compute policy dictates the type used for intermediate computations in the model's forward pass. The output policy dictates the type of the model's output.

Here's an example of creating a mixed precision policy and applying it to a model:

model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(64, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(10, activation: :softmax)

policy = Axon.MixedPrecision.create_policy(
  params: {:f, 32},
  compute: {:f, 16},
  output: {:f, 32}
)

mp_model =
  model
  |> Axon.MixedPrecision.apply_policy(policy, except: [:batch_norm])

The example above applies the mixed precision policy to every layer in the model except Batch Normalization layers. The policy will cast parameters and inputs to {:f, 16} for intermediate computations in the model's forward pass before casting the output back to {:f, 32}.
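To make the three roles concrete, here is a minimal sketch of the casting steps described above for a single dense layer, written directly against Nx rather than Axon. The `PolicySketch` module, its `dense_forward/4` function, and the plain map standing in for a policy are hypothetical illustrations, not Axon's internal implementation. The sketch assumes the `nx` package can be fetched with `Mix.install/1`.

```elixir
Mix.install([{:nx, "~> 0.4"}])

defmodule PolicySketch do
  # Hypothetical helper (not part of Axon): mimics, for one dense layer,
  # the casts a mixed precision policy describes.
  def dense_forward(input, kernel, bias, policy) do
    compute_type = policy.compute

    # Cast the input and the stored parameters to the compute type.
    x = Nx.as_type(input, compute_type)
    w = Nx.as_type(kernel, compute_type)
    b = Nx.as_type(bias, compute_type)

    # The intermediate computation runs entirely in the compute type.
    y = x |> Nx.dot(w) |> Nx.add(b)

    # Cast the result to the output type.
    Nx.as_type(y, policy.output)
  end
end

# Plain map standing in for the policy created in the example above.
policy = %{params: {:f, 32}, compute: {:f, 16}, output: {:f, 32}}

# Parameters are stored in the params type.
kernel = Nx.iota({4, 2}, type: policy.params)
bias = Nx.broadcast(Nx.tensor(0.5, type: policy.params), {2})
input = Nx.iota({1, 4}, type: {:f, 32})

out = PolicySketch.dense_forward(input, kernel, bias, policy)

IO.inspect(Nx.type(out))
# prints {:f, 32}
IO.inspect(Nx.to_list(out))
# prints [[28.5, 34.5]]
```

Keeping the stored parameters in {:f, 32} while computing in {:f, 16} is what lets the forward pass run at lower precision without permanently discarding parameter precision.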

Summary

Functions

create_policy(opts \\ [])

Creates a mixed precision policy with the given options.

Functions

create_policy(opts \\ [])

Creates a mixed precision policy with the given options.

Options

  • params - parameter precision policy. Defaults to {:f, 32}.
  • compute - compute precision policy. Defaults to {:f, 32}.
  • output - output precision policy. Defaults to {:f, 32}.

Examples

iex> Axon.MixedPrecision.create_policy(params: {:f, 16}, output: {:f, 16})
%Policy{params: {:f, 16}, compute: {:f, 32}, output: {:f, 16}}

iex> Axon.MixedPrecision.create_policy(compute: {:bf, 16})
%Policy{params: {:f, 32}, compute: {:bf, 16}, output: {:f, 32}}
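Because every option defaults to {:f, 32}, calling create_policy/1 with no options yields a policy that keeps params, compute, and output at single precision. The following check is a small sketch that assumes Axon 0.3.1 can be fetched with `Mix.install/1`.

```elixir
Mix.install([{:axon, "~> 0.3.1"}])

# With no options, each field falls back to its documented default of {:f, 32}.
default_policy = Axon.MixedPrecision.create_policy()

IO.inspect(default_policy)
```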