Axon.LossScale (Axon v0.3.1)

Implementations of loss-scalers for use in mixed precision training.

Loss scaling is used to prevent underflow when using mixed precision during the model training process. Each loss-scale implementation here returns a 3-tuple of functions:

    {init_fn, scale_fn, unscale_fn} = Axon.LossScale.static(Nx.power(2, 15))

You can use these to scale/unscale loss and gradients as well as adjust the loss scale state.
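For illustration, here is a minimal sketch of threading these functions through a training step. The arities of the returned closures are not documented in this section, so the calls below are assumptions modeled on the description above (init_fn builds the state, scale_fn scales a loss given the state, unscale_fn reverses the scaling and returns adjusted state); check the Axon.Loop.train_step/3 implementation for the exact shapes.

    # Sketch only: the closure arities below are assumptions, not documented API.
    {init_fn, scale_fn, unscale_fn} = Axon.LossScale.static(Nx.power(2, 15))

    # Assumed: init_fn builds the loss-scale state.
    state = init_fn.()

    # Scale a small loss up before backpropagating in f16 so it does not
    # underflow to zero.
    loss = Nx.tensor(1.0e-6)
    scaled_loss = scale_fn.(loss, state)

    # Assumed: unscale_fn divides the scale back out (e.g. from gradients) and
    # returns the possibly adjusted loss-scale state.
    {_unscaled, _new_state} = unscale_fn.(scaled_loss, state)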

Axon.Loop.trainer/3 builds loss-scaling in by default. You can reference the Axon.Loop.train_step/3 implementation to see how loss-scaling is applied in practice.
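For example, a mixed precision training loop might look like the sketch below. The :loss_scale option passed to Axon.Loop.trainer is an assumption based on later Axon releases; confirm the option name and accepted values against the Axon.Loop docs for your version.

    # Sketch: a mixed-precision model trained with the default loop.
    # Assumption: trainer accepts a :loss_scale option (e.g. :dynamic).
    model =
      Axon.input("features", shape: {nil, 32})
      |> Axon.dense(64, activation: :relu)
      |> Axon.dense(1)

    policy =
      Axon.MixedPrecision.create_policy(
        params: {:f, 32},
        compute: {:f, 16},
        output: {:f, 32}
      )

    mp_model = Axon.MixedPrecision.apply_policy(model, policy)

    # `data` is assumed to be an Enumerable of {input, target} batches.
    mp_model
    |> Axon.Loop.trainer(:mean_squared_error, :adam, loss_scale: :dynamic)
    |> Axon.Loop.run(data, %{}, epochs: 5)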

Summary

Functions

dynamic(loss_scale \\ 32768, opts \\ [])
Implements dynamic loss-scale.

identity
Implements identity loss-scale.

static(loss_scale \\ 32768)
Implements static loss-scale.

Functions

dynamic(loss_scale \\ 32768, opts \\ [])

Implements dynamic loss-scale.
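A dynamic scaler starts from the given loss scale (defaulting to 2^15 = 32768) and adjusts it over the course of training, unlike the fixed value used by static/1; dynamic loss scaling typically backs the scale off when gradients overflow and grows it again after a stretch of stable steps. A minimal construction, assuming the 3-tuple return shape described in the module docs:

    # Start from the default scale; the returned state is expected to carry the
    # current (adjustable) scale. Treat the init_fn arity as an assumption.
    {init_fn, scale_fn, unscale_fn} = Axon.LossScale.dynamic()
    state = init_fn.()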

identity

Implements identity loss-scale.
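The identity scaler is effectively a no-op: the loss and gradients pass through unchanged, which is the appropriate choice when training in full precision and no scaling is required. Assuming its arguments are optional like the other constructors:

    # No-op scaling; loss and gradients are expected to pass through unchanged.
    {init_fn, scale_fn, unscale_fn} = Axon.LossScale.identity()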

static(loss_scale \\ 32768)

Implements static loss-scale.
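A static scaler multiplies the loss by a fixed constant for the whole run and never adjusts it. If the default of 2^15 proves too aggressive (for example, scaled gradients overflow in f16), a smaller fixed scale can be passed instead; the value below is only illustrative:

    # Fixed scale of 2^12 instead of the default 2^15; purely illustrative.
    {init_fn, scale_fn, unscale_fn} = Axon.LossScale.static(Nx.power(2, 12))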