EXGBoost.Training.Callback (EXGBoost v0.5.0)
Callbacks are a mechanism to hook into the training process and perform custom actions.
Callbacks are structs with the following fields:
- event - the event that triggers the callback
- fun - the function to call when the callback is triggered
- name - the name of the callback
- init_state - the initial state of the callback
The following events are supported:
- :before_training - called before the training starts
- :after_training - called after the training ends
- :before_iteration - called before each iteration
- :after_iteration - called after each iteration
The callback function is called with a single argument:
- state - the current training state
The callback function should return one of the following:
- {:cont, state} - continue training with the given state
- {:halt, state} - stop training with the given state
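To make the {:cont, state} / {:halt, state} contract concrete, here is a minimal sketch of how a training loop could dispatch the callbacks registered for one event. It is a hypothetical illustration, not EXGBoost's internal code: the module name CallbackRunnerSketch is invented, and callbacks are modeled as plain maps with :event and :fun keys rather than Callback structs.

```elixir
defmodule CallbackRunnerSketch do
  # Hypothetical dispatcher: runs every callback registered for `event`,
  # threading the state through each one in order and stopping at the
  # first {:halt, state}.
  def run(callbacks, event, state) do
    callbacks
    |> Enum.filter(&(&1.event == event))
    |> Enum.reduce_while({:cont, state}, fn callback, {:cont, acc} ->
      case callback.fun.(acc) do
        {:cont, new_state} -> {:cont, {:cont, new_state}}
        {:halt, new_state} -> {:halt, {:halt, new_state}}
      end
    end)
  end
end

# Hypothetical callbacks: one bumps a counter, one stops training.
count = fn state -> {:cont, Map.update!(state, :count, &(&1 + 1))} end
stop = fn state -> {:halt, state} end

callbacks = [
  %{event: :before_training, fun: count},
  %{event: :after_iteration, fun: count},
  %{event: :after_iteration, fun: stop},
  %{event: :after_iteration, fun: count}
]

CallbackRunnerSketch.run(callbacks, :after_iteration, %{count: 0})
# => {:halt, %{count: 1}}
```

The third :after_iteration callback never runs because the second one halts; callbacks for other events are ignored.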
The following callbacks are provided in the EXGBoost.Training.Callback module:
- lr_scheduler - sets the learning rate for each iteration
- early_stop - performs early stopping
- eval_metrics - evaluates metrics on the training and evaluation sets
- monitor_metrics - prints evaluation metrics
Callbacks can be added to the training process by passing them to EXGBoost.Training.train/2.
Example
```elixir
alias EXGBoost.Training.Callback

# Callback to perform setup before training
setup_fn = fn state ->
  early_stop = %{best: 1, since_last_improvement: 0, mode: :max, patience: 5}
  # put_in/2 with a field path also works when state is a struct
  {:cont, put_in(state.meta_vars[:early_stop], early_stop)}
end

setup_callback = Callback.new(:before_training, setup_fn, :setup, %{})
```
Summary
Functions
- early_stop - A callback function that performs early stopping.
- eval_metrics - A callback function that evaluates metrics on the training and evaluation sets.
- lr_scheduler - A callback that sets the learning rate for each iteration.
- monitor_metrics - A callback function that prints evaluation metrics according to a period.
- new - Factory for a new callback with an initial state.
Types
@type event() ::
:before_training | :after_training | :before_iteration | :after_iteration
@type fun() :: (EXGBoost.Training.State.t() -> EXGBoost.Training.State.t())
Functions
early_stop
A callback function that performs early stopping.
Requires that the following exist in the state that is passed to the callback:
- target is the metric to monitor for early stopping. It must exist in the metrics that the state contains.
- mode is either :min or :max and indicates whether the metric should be minimized or maximized.
- patience is the number of iterations to wait for the metric to improve before stopping.
- since_last_improvement is the number of iterations since the metric last improved.
- best is the best value of the metric seen so far.
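The early-stopping rule described above can be sketched in plain Elixir as follows. This is an illustration under several assumptions, not the library's implementation: EarlyStopSketch is a hypothetical module, the bookkeeping lives in a plain map instead of EXGBoost.Training.State, the metric name "eval-rmse" is illustrative, a nil best counts as an improvement, and the stopping boundary (halting once since_last_improvement reaches patience) may differ from EXGBoost's.

```elixir
defmodule EarlyStopSketch do
  # Hypothetical early-stopping step: read the target metric, update best and
  # since_last_improvement, and halt once patience is used up.
  def step(es, metrics) do
    value = Map.fetch!(metrics, es.target)

    if improved?(es.mode, value, es.best) do
      {:cont, %{es | best: value, since_last_improvement: 0}}
    else
      es = %{es | since_last_improvement: es.since_last_improvement + 1}
      if es.since_last_improvement >= es.patience, do: {:halt, es}, else: {:cont, es}
    end
  end

  # A nil best means no value has been seen yet, so anything counts as an improvement.
  defp improved?(_mode, _value, nil), do: true
  defp improved?(:min, value, best), do: value < best
  defp improved?(:max, value, best), do: value > best
end

es = %{target: "eval-rmse", mode: :min, patience: 2, best: nil, since_last_improvement: 0}

# Feed a sequence of metric values until the rule halts.
result =
  Enum.reduce_while([0.5, 0.4, 0.45, 0.42, 0.30], es, fn value, acc ->
    case EarlyStopSketch.step(acc, %{"eval-rmse" => value}) do
      {:cont, next} -> {:cont, next}
      {:halt, next} -> {:halt, next}
    end
  end)

# Halts after 0.42 (two iterations without improving on 0.4), so 0.30 is never seen.
```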
eval_metrics
A callback function that evaluates metrics on the training and evaluation sets.
Requires that the following exist in the state.meta_vars of the state passed to the callback:
- eval_metrics:
  - evals: a list of evaluation sets to evaluate metrics on
  - filter: a function that takes a metric name and value and returns true if the metric should be included in the results
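A filter in the shape described above might look like the following sketch. It assumes the filter is a two-argument function (name, value) and that metrics arrive as a name-to-value map; MetricFilterSketch and the metric names are hypothetical.

```elixir
defmodule MetricFilterSketch do
  # Hypothetical helper: keeps only the metrics for which filter.(name, value) is true.
  def apply_filter(metrics, filter) when is_function(filter, 2) do
    metrics
    |> Enum.filter(fn {name, value} -> filter.(name, value) end)
    |> Map.new()
  end
end

# Example filter: keep only RMSE metrics (metric names are illustrative).
only_rmse = fn name, _value -> String.ends_with?(name, "rmse") end

metrics = %{"train-rmse" => 0.31, "eval-rmse" => 0.42, "eval-mae" => 0.20}
MetricFilterSketch.apply_filter(metrics, only_rmse)
# => %{"eval-rmse" => 0.42, "train-rmse" => 0.31}
```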
lr_scheduler
A callback that sets the learning rate for each iteration.
Requires that learning_rates be either a list of learning rates or a function that takes the
iteration number and returns a learning rate. The learning_rates value must exist in the state
that is passed to the callback.
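The two accepted forms of learning_rates can be resolved per iteration as in this sketch. LrScheduleSketch is a hypothetical module, and it assumes iterations are counted from 0; with the list form it returns nil past the end of the list.

```elixir
defmodule LrScheduleSketch do
  # List form: the rate at index `iteration` (nil past the end of the list).
  def rate_for(learning_rates, iteration) when is_list(learning_rates),
    do: Enum.at(learning_rates, iteration)

  # Function form: call it with the iteration number.
  def rate_for(learning_rates, iteration) when is_function(learning_rates, 1),
    do: learning_rates.(iteration)
end

LrScheduleSketch.rate_for([0.3, 0.2, 0.1], 1)
# => 0.2

# Exponential decay from 0.3, shrinking by 10% per iteration.
decay = fn i -> 0.3 * :math.pow(0.9, i) end
LrScheduleSketch.rate_for(decay, 0)
# => 0.3
```

The list form suits a short, hand-tuned schedule; the function form avoids having to know the number of boosting rounds in advance.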
monitor_metrics
A callback function that prints evaluation metrics according to a period.
Requires that the following exist in the state.meta_vars of the state passed to the callback:
- monitor_metrics:
  - period: print metrics every period iterations
  - filter: a function that takes a metric name and value and returns true if the metric should be included in the results
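The period and filter settings could combine as in this sketch, which prints only on every period-th iteration. MonitorSketch is a hypothetical module; 0-based iterations, an arity-2 filter, and the tab-separated output format are all assumptions rather than EXGBoost's behavior.

```elixir
defmodule MonitorSketch do
  # Hypothetical monitor: prints the filtered metrics on every `period`-th
  # iteration and reports whether it printed (:printed) or skipped (:skipped).
  def maybe_print(iteration, metrics, period, filter) do
    if rem(iteration, period) == 0 do
      line =
        metrics
        |> Enum.filter(fn {name, value} -> filter.(name, value) end)
        |> Enum.sort()
        |> Enum.map_join("\t", fn {name, value} -> "#{name}:#{value}" end)

      IO.puts("[#{iteration}]\t#{line}")
      :printed
    else
      :skipped
    end
  end
end

keep_all = fn _name, _value -> true end
MonitorSketch.maybe_print(10, %{"eval-rmse" => 0.42}, 5, keep_all)
# returns :printed after printing the iteration and metrics
MonitorSketch.maybe_print(7, %{"eval-rmse" => 0.42}, 5, keep_all)
# returns :skipped
```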
@spec new( event :: event(), fun :: (... -> any()), name :: atom(), init_state :: any() ) :: Callback.t()
Factory for a new callback with an initial state.
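The factory's contract can be illustrated with a hypothetical stand-in struct that carries the four documented fields (event, fun, name, init_state) and rejects unknown events. CallbackStructSketch is invented for illustration; the %{} default for init_state and the guard checks are assumptions, not a description of how EXGBoost.Training.Callback.new/4 is implemented.

```elixir
defmodule CallbackStructSketch do
  @valid_events [:before_training, :after_training, :before_iteration, :after_iteration]

  @enforce_keys [:event, :fun, :name]
  defstruct [:event, :fun, :name, init_state: %{}]

  # Hypothetical factory: a bad event, a fun that is not unary, or a non-atom
  # name matches no clause and raises FunctionClauseError.
  def new(event, fun, name, init_state \\ %{})
      when event in @valid_events and is_function(fun, 1) and is_atom(name) do
    %__MODULE__{event: event, fun: fun, name: name, init_state: init_state}
  end
end

noop = CallbackStructSketch.new(:after_iteration, fn state -> {:cont, state} end, :noop)
noop.init_state
# => %{}
```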