EXGBoost.Training.Callback (EXGBoost v0.5.1)
Callbacks are a mechanism to hook into the training process and perform custom actions.
Callbacks are structs with the following fields:
- event: the event that triggers the callback
- fun: the function to call when the callback is triggered
- name: the name of the callback
- init_state: the initial state of the callback
The following events are supported:
- :before_training: called before training starts
- :after_training: called after training ends
- :before_iteration: called before each iteration
- :after_iteration: called after each iteration
The callback function is called with a single argument:
- state: the current training state
The callback function should return one of the following:
- {:cont, state}: continue training with the given state
- {:halt, state}: stop training with the given state
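To make the event order and the {:cont, state} / {:halt, state} contract concrete, here is a minimal, self-contained sketch of a loop that dispatches callbacks. It is a hypothetical illustration, not EXGBoost's actual training loop: the module name CallbackLoopSketch is made up, callbacks are modeled as plain maps with :event and :fun keys, and the state is a plain map. It also assumes that a :halt returned by an iteration callback ends the remaining iterations while :after_training callbacks still run.

```elixir
defmodule CallbackLoopSketch do
  # Hypothetical sketch of the callback protocol; not EXGBoost's training loop.
  # Callbacks are plain maps with :event and :fun; each fun returns
  # {:cont, state} or {:halt, state}.

  def run(callbacks, state, num_iterations) do
    state =
      case fire(callbacks, :before_training, state) do
        {:cont, state} -> iterate(callbacks, state, 0, num_iterations)
        {:halt, state} -> state
      end

    # Assumption: :after_training callbacks run even if training halted early.
    {_, state} = fire(callbacks, :after_training, state)
    state
  end

  defp iterate(_callbacks, state, i, n) when i >= n, do: state

  defp iterate(callbacks, state, i, n) do
    state = Map.put(state, :iteration, i)

    # A real trainer would run one boosting round between these two events.
    with {:cont, state} <- fire(callbacks, :before_iteration, state),
         {:cont, state} <- fire(callbacks, :after_iteration, state) do
      iterate(callbacks, state, i + 1, n)
    else
      {:halt, state} -> state
    end
  end

  # Runs every callback registered for `event` in order; the first :halt wins.
  defp fire(callbacks, event, state) do
    callbacks
    |> Enum.filter(&(&1.event == event))
    |> Enum.reduce_while({:cont, state}, fn cb, {:cont, acc} ->
      case cb.fun.(acc) do
        {:cont, new_state} -> {:cont, {:cont, new_state}}
        {:halt, new_state} -> {:halt, {:halt, new_state}}
      end
    end)
  end
end
```

For example, a single :after_iteration callback that increments a :count field and halts once it reaches 3 stops the loop during the third iteration, so the returned state has :count 3 and :iteration 2 even when ten iterations were requested.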
The following callbacks are provided in the EXGBoost.Training.Callback module:
- lr_scheduler: sets the learning rate for each iteration
- early_stop: performs early stopping
- eval_metrics: evaluates metrics on the training and evaluation sets
- monitor_metrics: prints evaluation metrics
Callbacks can be added to the training process by passing them to EXGBoost.Training.train/2.
Example
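```elixir
alias EXGBoost.Training.Callback
# Callback to perform setup before training: seed early-stopping bookkeeping in meta_vars
setup_fn = fn state ->
  updated_state = put_in(state.meta_vars[:early_stop], %{best: 1, since_last_improvement: 0, mode: :max, patience: 5})
  {:cont, updated_state}
end
# new/4 per the @spec below: event, function, name (:setup is illustrative), initial state
setup_callback = Callback.new(:before_training, setup_fn, :setup, %{})
```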
Summary
Functions
- early_stop/1: A callback function that performs early stopping.
- eval_metrics/1: A callback function that evaluates metrics on the training and evaluation sets.
- lr_scheduler/1: A callback that sets the learning rate for each iteration.
- monitor_metrics/1: A callback function that prints evaluation metrics according to a period.
- new/4: Factory for a new callback with an initial state.
Types
@type event() ::
:before_training | :after_training | :before_iteration | :after_iteration
@type fun() :: (EXGBoost.Training.State.t() -> EXGBoost.Training.State.t())
Functions
early_stop/1
A callback function that performs early stopping.
Requires that the following exist in the state that is passed to the callback:
- target: the metric to monitor for early stopping. It must exist in the metrics that the state contains.
- mode: either :min or :max, indicating whether the metric should be minimized or maximized.
- patience: the number of iterations to wait for the metric to improve before stopping.
- since_last_improvement: the number of iterations since the metric last improved.
- best: the best value of the metric seen so far.
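The bookkeeping described by these fields can be sketched as a pure function. This is a hypothetical illustration rather than EXGBoost's implementation: EarlyStopSketch.step/2 is a made-up name, it receives the monitored metric's latest value directly instead of reading it from the state's metrics, and it assumes training halts once since_last_improvement reaches patience.

```elixir
defmodule EarlyStopSketch do
  # Hypothetical helper, not EXGBoost's implementation. `es` is a plain map with
  # :mode, :patience, :since_last_improvement and :best; `value` is the latest
  # value of the monitored (target) metric.
  def step(es, value) do
    improved =
      case es.mode do
        :min -> value < es.best
        :max -> value > es.best
      end

    es =
      if improved do
        %{es | best: value, since_last_improvement: 0}
      else
        %{es | since_last_improvement: es.since_last_improvement + 1}
      end

    # Assumption: stop once `patience` consecutive iterations brought no improvement.
    if es.since_last_improvement >= es.patience, do: {:halt, es}, else: {:cont, es}
  end
end
```

With mode :max, patience 2 and best 0.5, the values 0.6, 0.55 and 0.58 give {:cont, ...}, {:cont, ...} and then {:halt, ...}: the first value sets a new best, and the next two fail to beat it.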
eval_metrics/1
A callback function that evaluates metrics on the training and evaluation sets.
Requires that the following exist in state.meta_vars that is passed to the callback:
- eval_metrics:
  - evals: a list of evaluation sets to evaluate metrics on
  - filter: a function that takes a metric name and value and returns true if the metric should be included in the results
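The filter option can be illustrated with a small helper that applies a name/value predicate to a collection of metrics. MetricFilterSketch is a hypothetical name, and the flat map of metric name to value is an assumption for illustration; the exact shape of the metrics EXGBoost computes is not specified here.

```elixir
defmodule MetricFilterSketch do
  # Hypothetical helper: `metrics` is assumed to be a flat map of
  # metric name => value, and `filter` a 2-arity predicate as described above.
  def apply_filter(metrics, filter) when is_map(metrics) and is_function(filter, 2) do
    metrics
    |> Enum.filter(fn {name, value} -> filter.(name, value) end)
    |> Map.new()
  end
end
```

A filter such as fn name, _value -> name in ["rmse", "logloss"] end keeps only those two metrics; a value-based predicate works the same way, since both the name and the value are passed in.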
lr_scheduler/1
A callback that sets the learning rate for each iteration.
Requires that learning_rates exist in the state that is passed to the callback. learning_rates must be either a list of learning rates or a function that takes the iteration number and returns a learning rate.
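Both accepted forms of learning_rates can be resolved to a single rate for a given iteration as below. LrScheduleSketch is a hypothetical helper, and the 0-based iteration index is an assumption; with a list, Enum.at/2 returns nil once the iteration runs past the end of the list.

```elixir
defmodule LrScheduleSketch do
  # Hypothetical helper, not EXGBoost's code: resolve the learning rate for
  # iteration `i` (assumed 0-based) from either accepted form of learning_rates.
  def rate_for(learning_rates, i) when is_list(learning_rates), do: Enum.at(learning_rates, i)

  def rate_for(learning_rates, i) when is_function(learning_rates, 1), do: learning_rates.(i)
end
```

For example, LrScheduleSketch.rate_for([0.3, 0.2, 0.1], 1) returns 0.2, and passing fn i -> 0.3 * :math.pow(0.9, i) end yields an exponentially decaying rate. The function form avoids having to size a list to the number of iterations in advance.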
monitor_metrics/1
A callback function that prints evaluation metrics according to a period.
Requires that the following exist in state.meta_vars that is passed to the callback:
- monitor_metrics:
  - period: print metrics every period iterations
  - filter: a function that takes a metric name and value and returns true if the metric should be included in the results
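A period-based check like the one sketched here decides which iterations print. It assumes a 0-based iteration counter and that printing happens whenever the iteration is a multiple of period; MonitorSketch and its functions are hypothetical, and EXGBoost's exact output format is not reproduced.

```elixir
defmodule MonitorSketch do
  # Hypothetical helper: assumes a 0-based iteration counter and that metrics
  # are printed whenever the iteration is a multiple of `period`.
  def print?(iteration, period) when is_integer(period) and period > 0,
    do: rem(iteration, period) == 0

  # Prints the (already filtered) metrics map on qualifying iterations.
  def maybe_print(iteration, period, metrics) do
    if print?(iteration, period), do: IO.puts("[#{iteration}] #{inspect(metrics)}")
    :ok
  end
end
```

With period 3, iterations 0, 3, 6 and 9 of the first ten are printed.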
@spec new( event :: event(), fun :: (... -> any()), name :: atom(), init_state :: any() ) :: Callback.t()
Factory for a new callback with an initial state.
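The following sketch shows what a factory of this shape might look like. MiniCallback is a hypothetical stand-in that reuses the four field names listed at the top of this page; it is not the library's struct, and the guard clauses (valid event, one-argument function, atom name) are assumptions about reasonable validation rather than a description of EXGBoost's checks.

```elixir
defmodule MiniCallback do
  # Hypothetical stand-in for a callback struct with the four documented
  # fields; it is not EXGBoost.Training.Callback itself.
  @enforce_keys [:event, :fun]
  defstruct [:event, :fun, :name, :init_state]

  @events [:before_training, :after_training, :before_iteration, :after_iteration]

  # Factory mirroring the shape of the new/4 spec: validates the event and that
  # fun takes one argument (the training state).
  def new(event, fun, name, init_state)
      when event in @events and is_function(fun, 1) and is_atom(name) do
    %__MODULE__{event: event, fun: fun, name: name, init_state: init_state}
  end
end
```

MiniCallback.new(:after_iteration, fn s -> {:cont, s} end, :noop, %{}) builds a callback, while an unknown event such as :during_training matches no clause and raises FunctionClauseError.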