Axon.LossScale (Axon v0.7.0)

Implementations of loss-scalers for use in mixed precision training.

Loss scaling is used to prevent small gradient values from underflowing when a model is trained with mixed precision. Each loss-scale implementation here returns a tuple of functions:

    {init_fn, scale_fn, unscale_fn, adjust_fn} = Axon.LossScale.static(Nx.pow(2, 15))

Use these to scale the loss, unscale the gradients, and adjust the loss-scale state.
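The following is a minimal plain-Elixir sketch of a static and an identity loss scale, showing the tuple-of-functions shape. It is not Axon's implementation. These details are assumptions made for illustration:

- the module name `LossScaleSketch`,
- the state map with a `:loss_scale` key,
- the use of plain lists of floats instead of Nx tensors.

The default of 32768.0 matches the 2^15 used in the example above.

```elixir
defmodule LossScaleSketch do
  # Hypothetical plain-Elixir sketch of the {init_fn, scale_fn, unscale_fn, adjust_fn}
  # shape; not Axon's implementation. Gradients are lists of floats, not Nx tensors.

  # Static loss scale: the factor is fixed at init and never changes.
  def static(init_scale \\ 32768.0) do
    init_fn = fn -> %{loss_scale: init_scale} end

    # Multiply the loss so that small gradients stay representable in low precision.
    scale_fn = fn loss, %{loss_scale: s} -> loss * s end

    # Divide the scaled gradients by the same factor to recover the true values.
    unscale_fn = fn grads, %{loss_scale: s} -> Enum.map(grads, &(&1 / s)) end

    # A static scaler never updates its state.
    adjust_fn = fn _grads, state -> state end

    {init_fn, scale_fn, unscale_fn, adjust_fn}
  end

  # Identity loss scale: every function is a no-op.
  def identity do
    {fn -> %{} end, fn loss, _state -> loss end, fn grads, _state -> grads end,
     fn _grads, state -> state end}
  end
end

{init_fn, scale_fn, unscale_fn, _adjust_fn} = LossScaleSketch.static()
state = init_fn.()
scale_fn.(0.5, state)
# => 16384.0
unscale_fn.([32768.0, 65536.0], state)
# => [1.0, 2.0]
```

Because the scale is a power of two, multiplying and dividing by it only shifts the floating-point exponent. The round trip is therefore exact unless a value overflows or underflows.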

Axon.Loop.trainer/3 builds in loss scaling by default. For how loss scaling is applied in practice, see the Axon.Loop.train_step/3 implementation.
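The order of operations inside a training step can be sketched on a toy one-parameter model. This is a hypothetical illustration, not the Axon.Loop.train_step/3 code. These details are assumptions:

- `TrainStepSketch.step/3` is an invented name,
- the loss is taken to be `w * w`,
- its gradient is written out by hand instead of coming from automatic differentiation.

```elixir
defmodule TrainStepSketch do
  # Hypothetical sketch of one gradient step with a fixed loss scale;
  # not Axon.Loop.train_step/3. Toy loss: loss(w) = w * w.

  def step(w, scale, learning_rate) do
    loss = w * w

    # 1. Scale the loss (this is the quantity that would be differentiated).
    scaled_loss = loss * scale

    # 2. The gradient of scale * w * w is 2 * scale * w, written out by hand here.
    scaled_grad = 2.0 * scale * w

    # 3. Unscale before the optimizer sees the gradient.
    grad = scaled_grad / scale

    # 4. Plain SGD update using the true (unscaled) gradient.
    {w - learning_rate * grad, loss, scaled_loss}
  end
end

TrainStepSketch.step(3.0, 1024.0, 0.25)
# => {1.5, 9.0, 9216.0}
```

Differentiation is linear, so scaling the loss by `s` scales every gradient by `s`. Dividing by `s` afterwards recovers the same update as training without scaling.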

Summary

Functions

dynamic: implements dynamic loss scaling.

identity: implements identity loss scaling.

static: implements static loss scaling.
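A dynamic loss scale adapts the factor during training. The sketch below follows a commonly used scheme:

- divide the scale by a factor when a step produces non-finite gradients,
- multiply it by that factor after a run of finite steps.

The module name, the `:factor`, `:period`, and `:min_loss_scale` option names, and their defaults (2.0, 2000, 1.0) are assumptions for illustration. They do not describe Axon's exact behavior. BEAM floats cannot hold infinity or NaN, so the caller passes a boolean saying whether the step's gradients were all finite. With Nx tensors, that check could be built from `Nx.is_nan/1` and `Nx.is_infinity/1`.

```elixir
defmodule DynamicLossScaleSketch do
  # Hypothetical plain-Elixir sketch of a dynamic loss scale; not Axon's code.
  # Option names and defaults (factor 2.0, period 2000, min scale 1.0) are assumptions.

  def init(scale \\ 32768.0), do: %{loss_scale: scale, counter: 0}

  # `all_finite?` says whether every gradient of the current step was finite.
  def adjust(%{loss_scale: s, counter: c}, all_finite?, opts \\ []) do
    factor = Keyword.get(opts, :factor, 2.0)
    period = Keyword.get(opts, :period, 2000)
    min_scale = Keyword.get(opts, :min_loss_scale, 1.0)

    cond do
      # Overflow: shrink the scale (never below min_scale) and restart the count.
      not all_finite? -> %{loss_scale: max(s / factor, min_scale), counter: 0}
      # `period` finite steps in a row: try a larger scale.
      c + 1 >= period -> %{loss_scale: s * factor, counter: 0}
      # Otherwise keep the scale and count this finite step.
      true -> %{loss_scale: s, counter: c + 1}
    end
  end
end

state = DynamicLossScaleSketch.init(1024.0)
DynamicLossScaleSketch.adjust(state, false)
# => %{loss_scale: 512.0, counter: 0}
```

Many implementations also skip the parameter update for a step whose gradients were non-finite. That decision sits in the training step and is outside this sketch.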
