Axon.MixedPrecision (Axon v0.7.0)

Utilities for creating mixed precision policies.

Mixed precision is useful for increasing model throughput at the possible price of a small dip in accuracy. When creating a mixed precision policy, you define the policy for params, compute, and output.

The params policy dictates the type in which parameters are stored during training. The compute policy dictates the type used for intermediate computations in the model's forward pass. The output policy dictates the type of the model's output.

Here's an example of creating a mixed precision policy and applying it to a model:

model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(64, activation: :relu)
  |> Axon.batch_norm()
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(10, activation: :softmax)

policy = Axon.MixedPrecision.create_policy(
  params: {:f, 32},
  compute: {:f, 16},
  output: {:f, 32}
)

mp_model =
  model
  |> Axon.MixedPrecision.apply_policy(policy, except: [:batch_norm])

The example above applies the mixed precision policy to every layer in the model except Batch Normalization layers. The policy will cast parameters and inputs to {:f, 16} for intermediate computations in the model's forward pass before casting the output back to {:f, 32}.
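To make the sequence of casts concrete, the following sketch applies the same policy by hand to a single matrix product using Axon.MixedPrecision.cast/3. It is an illustration only, not how Axon applies policies internally: the kernel and input tensors are hypothetical stand-ins for one layer, and the Mix.install version requirement is an assumption.

```elixir
# Assumed dependency spec; adjust to the Axon version in use.
Mix.install([{:axon, "~> 0.7"}])

policy =
  Axon.MixedPrecision.create_policy(
    params: {:f, 32},
    compute: {:f, 16},
    output: {:f, 32}
  )

# Hypothetical stand-ins for one layer's kernel and its input.
kernel = Nx.tensor([[1.0, 2.0], [3.0, 4.0]])
input = Nx.tensor([[0.5, 0.25]])

# Parameters are stored in the params type.
stored_kernel = Axon.MixedPrecision.cast(policy, kernel, :params)

# Parameters and inputs are cast to the compute type before the computation.
compute_kernel = Axon.MixedPrecision.cast(policy, stored_kernel, :compute)
compute_input = Axon.MixedPrecision.cast(policy, input, :compute)
hidden = Nx.dot(compute_input, compute_kernel)

# The result is cast to the output type.
output = Axon.MixedPrecision.cast(policy, hidden, :output)

IO.inspect(Nx.type(stored_kernel), label: "stored")
IO.inspect(Nx.type(hidden), label: "compute")
IO.inspect(Nx.type(output), label: "output")
```

With this policy, the stored kernel and the final output should be {:f, 32}, while the intermediate product is computed in {:f, 16}.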

Summary

Functions

cast(policy, tensor_or_container, variable_type)

Casts the given container according to the given policy and type.

create_policy(opts \\ [])

Creates a mixed precision policy with the given options.

Functions

cast(policy, tensor_or_container, variable_type)

Casts the given container according to the given policy and type.

Examples

iex> policy = Axon.MixedPrecision.create_policy(params: {:f, 16})
iex> params = %{"dense" => %{"kernel" => Nx.tensor([1.0, 2.0, 3.0])}}
iex> params = Axon.MixedPrecision.cast(policy, params, :params)
iex> Nx.type(params["dense"]["kernel"])
{:f, 16}

iex> policy = Axon.MixedPrecision.create_policy(compute: {:bf, 16})
iex> value = Nx.tensor([1.0, 2.0, 3.0])
iex> value = Axon.MixedPrecision.cast(policy, value, :compute)
iex> Nx.type(value)
{:bf, 16}

iex> policy = Axon.MixedPrecision.create_policy(output: {:bf, 16})
iex> value = Nx.tensor([1.0, 2.0, 3.0])
iex> value = Axon.MixedPrecision.cast(policy, value, :output)
iex> Nx.type(value)
{:bf, 16}

Note that integers are never promoted to floats:

iex> policy = Axon.MixedPrecision.create_policy(output: {:f, 16})
iex> value = Nx.tensor([1, 2, 3], type: :s64)
iex> value = Axon.MixedPrecision.cast(policy, value, :output)
iex> Nx.type(value)
{:s, 64}
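
The two behaviours above (container casting and integers being left alone) can be seen together in one call. This is a minimal sketch: the "step" counter and the shape of the state map are hypothetical, and the Mix.install version requirement is an assumption.

```elixir
# Assumed dependency spec; adjust to the Axon version in use.
Mix.install([{:axon, "~> 0.7"}])

policy = Axon.MixedPrecision.create_policy(params: {:bf, 16})

# Hypothetical state: float weights alongside an integer step counter.
state = %{
  "dense" => %{"kernel" => Nx.tensor([1.0, 2.0, 3.0])},
  "step" => Nx.tensor(10, type: :s64)
}

casted = Axon.MixedPrecision.cast(policy, state, :params)

IO.inspect(Nx.type(casted["dense"]["kernel"]), label: "kernel")
IO.inspect(Nx.type(casted["step"]), label: "step")
```

Given the rule that integers are never promoted to floats, the kernel should come back as {:bf, 16} while the step counter should keep its {:s, 64} type.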

create_policy(opts \\ [])

Creates a mixed precision policy with the given options.

Each option defaults to nil, which dictates that no casting will be done for that variable type.

Options

  • :params - parameter precision policy. Defaults to nil
  • :compute - compute precision policy. Defaults to nil
  • :output - output precision policy. Defaults to nil

Examples

iex> Axon.MixedPrecision.create_policy(params: {:f, 16}, output: {:f, 16})
#Axon.MixedPrecision.Policy<p=f16 o=f16>

iex> Axon.MixedPrecision.create_policy(compute: {:bf, 16})
#Axon.MixedPrecision.Policy<c=bf16>

iex> Axon.MixedPrecision.create_policy()
#Axon.MixedPrecision.Policy<>
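
Because every option defaults to nil, a policy created with no options is expected to leave tensors untouched. The sketch below checks this with a deliberately non-default tensor type; the choice of {:f, 64} and the Mix.install version requirement are assumptions made for illustration.

```elixir
# Assumed dependency spec; adjust to the Axon version in use.
Mix.install([{:axon, "~> 0.7"}])

# No options given, so params, compute, and output are all nil.
policy = Axon.MixedPrecision.create_policy()

# Use a non-default float type so any cast would be visible.
value = Nx.tensor([1.0, 2.0, 3.0], type: :f64)
unchanged = Axon.MixedPrecision.cast(policy, value, :compute)

IO.inspect(Nx.type(unchanged), label: "type after cast")
```

If the nil default behaves as described, the inspected type should match the type of the original tensor.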