EXGBoost.Training.Callback (EXGBoost v0.5.1)

Callbacks are a mechanism to hook into the training process and perform custom actions.

Callbacks are structs with the following fields:

  • event - the event that triggers the callback
  • fun - the function to call when the callback is triggered
  • name - the name of the callback
  • init_state - the initial state of the callback

The following events are supported:

  • :before_training - called before the training starts
  • :after_training - called after the training ends
  • :before_iteration - called before each iteration
  • :after_iteration - called after each iteration

The callback function is called with a single argument:

  • state - the current training state

The callback function should return one of the following:

  • {:cont, state} - continue training with the given state
  • {:halt, state} - stop training with the given state
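The contract above can be illustrated without EXGBoost. The following is a hypothetical sketch: plain maps stand in for EXGBoost.Training.State, and CallbackContractSketch.run_callbacks/2 is an illustrative driver, not part of the library. It threads the state through each callback in order and stops at the first {:halt, state}.

```elixir
# Hypothetical sketch: plain maps stand in for EXGBoost.Training.State, and
# run_callbacks/2 is an illustrative driver, not part of EXGBoost.
defmodule CallbackContractSketch do
  # Runs callback functions in order, threading the state through them.
  # Stops at the first callback that returns {:halt, state}.
  def run_callbacks(callbacks, state) do
    Enum.reduce_while(callbacks, {:cont, state}, fn fun, {:cont, acc} ->
      case fun.(acc) do
        {:cont, new_state} -> {:cont, {:cont, new_state}}
        {:halt, new_state} -> {:halt, {:halt, new_state}}
      end
    end)
  end
end

# Increments a counter and always continues.
count = fn state -> {:cont, Map.update(state, :count, 1, &(&1 + 1))} end

# Halts once the counter reaches 2.
stop_at_two = fn state ->
  if state.count >= 2, do: {:halt, state}, else: {:cont, state}
end

result = CallbackContractSketch.run_callbacks([count, stop_at_two, count, stop_at_two], %{count: 0})
IO.inspect(result)
# prints {:halt, %{count: 2}}
```

Returning the tagged tuple from the driver lets the caller tell whether training ran to completion or was halted by a callback.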

The following callbacks are provided in the EXGBoost.Training.Callback module:

  • lr_scheduler - sets the learning rate for each iteration
  • early_stop - performs early stopping
  • eval_metrics - evaluates metrics on the training and evaluation sets
  • monitor_metrics - prints evaluation metrics

Callbacks can be added to the training process by passing them to EXGBoost.Training.train/2.

Example

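alias EXGBoost.Training.Callback

# Callback to perform setup before training.
# early_stop/1 also needs a :target entry naming the metric to monitor.
setup_fn = fn state ->
  early_stop = %{best: 1, since_last_improvement: 0, mode: :max, patience: 5}
  {:cont, put_in(state.meta_vars[:early_stop], early_stop)}
end

# new/4 requires a name; init_state defaults to %{}
setup_callback = Callback.new(:before_training, setup_fn, :setup)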


Types

@type event() ::
  :before_training | :after_training | :before_iteration | :after_iteration
@type fun() :: (EXGBoost.Training.State.t() -> EXGBoost.Training.State.t())

Functions

early_stop(state)

A callback function that performs early stopping.

Requires that the following exist in state.meta_vars.early_stop when the callback is called:

  • target is the metric to monitor for early stopping. It must exist in the metrics that the state contains.
  • mode is either :min or :max and indicates whether the metric should be minimized or maximized.
  • patience is the number of iterations to wait for the metric to improve before stopping.
  • since_last_improvement is the number of iterations since the metric last improved.
  • best is the best value of the metric seen so far.
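The early-stopping rule above can be sketched as a pure function over a plain map. This is a hypothetical illustration, not EXGBoost's early_stop/1: the module EarlyStopSketch, the flat metrics map keyed by the target name, and the choice to halt once since_last_improvement reaches patience are all assumptions.

```elixir
# Hypothetical sketch of early-stopping bookkeeping; not EXGBoost's early_stop/1.
# Assumes a flat metrics map keyed by metric name and halting once
# since_last_improvement reaches patience.
defmodule EarlyStopSketch do
  def step(%{target: target, mode: mode, best: best, patience: patience} = es, metrics) do
    value = Map.fetch!(metrics, target)

    # Whether the new value beats the best seen so far, according to the mode.
    improved? =
      case mode do
        :min -> value < best
        :max -> value > best
      end

    # Reset the counter on improvement, otherwise count one more stale iteration.
    es =
      if improved?,
        do: %{es | best: value, since_last_improvement: 0},
        else: %{es | since_last_improvement: es.since_last_improvement + 1}

    if es.since_last_improvement >= patience, do: {:halt, es}, else: {:cont, es}
  end
end

es0 = %{target: "rmse", mode: :min, best: 1.0, patience: 2, since_last_improvement: 0}
{:cont, es1} = EarlyStopSketch.step(es0, %{"rmse" => 0.8})
{:cont, es2} = EarlyStopSketch.step(es1, %{"rmse" => 0.9})
{:halt, es3} = EarlyStopSketch.step(es2, %{"rmse" => 0.85})
IO.inspect(es3.best)
# prints 0.8
```

With patience: 2, the first value improves on best, and the next two do not, so the third step halts while best stays at 0.8.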

eval_metrics(state)

A callback function that evaluates metrics on the training and evaluation sets.

Requires that the following exist in the state.meta_vars that is passed to the callback:

  • eval_metrics:
    • evals: a list of evaluation sets to evaluate metrics on
    • filter: a function that takes a metric name and value and returns true if the metric should be included in the results
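A filter like the one above can be applied to a map of computed metrics as follows. This is a hypothetical sketch: MetricFilterSketch is illustrative, and the 1-arity filter receiving a {name, value} tuple is an assumption about the shape EXGBoost expects, not a documented signature.

```elixir
# Hypothetical sketch: applying a metric filter to computed results.
# The 1-arity {name, value} filter shape is an assumption.
defmodule MetricFilterSketch do
  def apply_filter(metrics, filter) when is_function(filter, 1) do
    metrics
    |> Enum.filter(filter)
    |> Map.new()
  end
end

metrics = %{"rmse" => 0.42, "mae" => 0.31, "logloss" => 0.69}

# Keep only the error metrics.
keep_errors = fn {name, _value} -> name in ["rmse", "mae"] end

MetricFilterSketch.apply_filter(metrics, keep_errors)
|> IO.inspect()
# prints %{"mae" => 0.31, "rmse" => 0.42}
```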

lr_scheduler(state)

A callback that sets the learning rate for each iteration.

Requires that learning_rates exist in the state passed to the callback. learning_rates is either a list of learning rates or a function that takes the iteration number and returns a learning rate.
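The two accepted forms of learning_rates can be resolved with a pair of guarded clauses. This is a hypothetical sketch: LrScheduleSketch is illustrative, and zero-based iterations with a list indexed by iteration number are assumptions about how lr_scheduler/1 reads the value.

```elixir
# Hypothetical sketch: resolving `learning_rates` for one iteration.
# Zero-based iterations and list indexing by iteration are assumptions.
defmodule LrScheduleSketch do
  # A list gives one rate per iteration (nil past the end of the list).
  def rate_for(learning_rates, iteration) when is_list(learning_rates),
    do: Enum.at(learning_rates, iteration)

  # A 1-arity function computes the rate from the iteration number.
  def rate_for(learning_rates, iteration) when is_function(learning_rates, 1),
    do: learning_rates.(iteration)
end

IO.inspect(LrScheduleSketch.rate_for([0.3, 0.2, 0.1], 1))
# prints 0.2

# Exponential decay: 0.3 * 0.9^iteration
decay = fn i -> 0.3 * :math.pow(0.9, i) end
IO.inspect(LrScheduleSketch.rate_for(decay, 2))
```

A function is the more flexible choice when the number of boosting rounds is not known up front, since a list must have at least one entry per iteration.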

monitor_metrics(state)

A callback function that prints evaluation metrics according to a period.

Requires that the following exist in the state.meta_vars that is passed to the callback:

  • monitor_metrics:
    • period: print metrics every period iterations
    • filter: a function that takes a metric name and value and returns true if the metric should be included in the results
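The period and filter settings above combine roughly as follows. This is a hypothetical sketch, not EXGBoost's implementation: MonitorSketch is illustrative, and the zero-based iteration counter, the rem/2 check, and the {name, value} filter shape are assumptions.

```elixir
# Hypothetical sketch of period-based metric printing. The zero-based
# iteration counter, the rem/2 check, and the {name, value} filter shape
# are assumptions; EXGBoost's own callback may differ.
defmodule MonitorSketch do
  def maybe_print(iteration, metrics, period, filter) do
    if period > 0 and rem(iteration, period) == 0 do
      # Print only the metrics the filter keeps.
      shown = metrics |> Enum.filter(filter) |> Map.new()
      IO.puts("[#{iteration}] #{inspect(shown)}")
      :printed
    else
      :skipped
    end
  end
end

metrics = %{"rmse" => 0.42, "mae" => 0.31}
only_rmse = fn {name, _value} -> name == "rmse" end

# prints [0] %{"rmse" => 0.42}
MonitorSketch.maybe_print(0, metrics, 5, only_rmse)
# prints nothing and returns :skipped
MonitorSketch.maybe_print(3, metrics, 5, only_rmse)
```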

new(event, fun, name, init_state \\ %{})

@spec new(
  event :: event(),
  fun :: (... -> any()),
  name :: atom(),
  init_state :: any()
) :: Callback.t()

Factory for a new callback with an initial state.