Bumblebee.Text.GenerationConfig (Bumblebee v0.6.0)

A set of configuration options controlling text generation.

This struct is expected by Bumblebee.Text.Generation.build_generate/3.

Configuration

Options controlling length

  • :max_new_tokens - the maximum number of tokens to be generated, ignoring the number of tokens in the prompt. Defaults to 20

  • :min_new_tokens - the minimum number of tokens to be generated, ignoring the number of tokens in the prompt

  • :max_length - the maximum length of the sequence to be generated. Note that this length includes the length of the input prompt (including padding). In general, prefer :max_new_tokens, which ignores the number of tokens in the prompt

  • :min_length - the minimum length of the sequence to be generated. Note that this length includes the length of the input prompt (including padding). In general, prefer :min_new_tokens, which ignores the number of tokens in the prompt
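The prompt counts toward :max_length and :min_length but not toward :max_new_tokens and :min_new_tokens, so the two styles convert into each other through the prompt length. The following is a minimal sketch of that arithmetic for the upper bound. It uses a hypothetical LengthSketch module and assumes exactly one of the two options is given. It does not reproduce how Bumblebee itself validates or combines these options.

```elixir
defmodule LengthSketch do
  @moduledoc """
  Hypothetical helper showing how :max_new_tokens and :max_length relate.
  `prompt_length` is the length of the input prompt, including padding.
  """

  # Upper bound on the whole sequence (prompt plus generated tokens).
  def max_total_length(prompt_length, opts) do
    case {Keyword.get(opts, :max_new_tokens), Keyword.get(opts, :max_length)} do
      # :max_new_tokens ignores the prompt, so the prompt length is added on top.
      {max_new_tokens, nil} when is_integer(max_new_tokens) ->
        prompt_length + max_new_tokens

      # :max_length already includes the prompt.
      {nil, max_length} when is_integer(max_length) ->
        max_length

      _ ->
        raise ArgumentError, "this sketch expects exactly one of :max_new_tokens or :max_length"
    end
  end

  # How many tokens may still be generated under that bound.
  def new_token_budget(prompt_length, opts) do
    max(max_total_length(prompt_length, opts) - prompt_length, 0)
  end
end
```

With a 7-token prompt, `max_new_tokens: 20` allows a 27-token sequence, whereas `max_length: 30` leaves room for only 23 new tokens. A longer padded prompt shrinks that budget further.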

Options controlling strategy

  • :strategy - the method deciding how tokens are selected. It has a significant impact on the quality of the generated sequence. It should be a map with a :type key and strategy-specific options.

    • :greedy_search - the most straightforward approach, where in every iteration the most probable token (as given by the model) is taken.

      Example: %{type: :greedy_search}.

    • :contrastive_search - a state-of-the-art decoding method capable of producing high-quality, coherent sequences. The results are deterministic. See this article for more details.

      • :top_k (required) - the number of highest-probability vocabulary tokens considered as a continuation

      • :alpha (required) - the weight of the degeneration penalty, balancing model confidence against the penalty

      Example: %{type: :contrastive_search, top_k: 4, alpha: 0.6}.

    • :multinomial_sampling - this method samples tokens according to the probability distribution given by the model. The results are nondeterministic, unless a seed is specified.

      • :top_k (optional) - when specified, restricts sampling to the top-k most probable candidates

      • :top_p (optional) - when specified, restricts sampling to the most probable tokens whose probabilities add up to top-p

      Example: %{type: :multinomial_sampling, top_p: 0.6}.

    Defaults to %{type: :greedy_search}.
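To make the strategies concrete, here is a plain-Elixir sketch using a hypothetical StrategySketch module that represents the next-token distribution as a map of token id => probability. Bumblebee operates on Nx tensors internally, and none of these functions are its API. The top-p boundary rule (the token that crosses the threshold is kept) is one common convention and an assumption here. For contrastive search, the score follows the usual formulation (1 - alpha) * probability - alpha * maximum similarity to the previous context, with the similarities assumed to be precomputed.

```elixir
defmodule StrategySketch do
  @moduledoc """
  Hypothetical, plain-Elixir illustration of the strategies above.
  `probs` is a map of token id => probability for the next position.
  """

  # :greedy_search - always take the single most probable token.
  def greedy(probs) do
    {token_id, _prob} = Enum.max_by(probs, fn {_id, p} -> p end)
    token_id
  end

  # :top_k - keep only the k most probable candidates.
  def top_k(probs, k) do
    probs
    |> Enum.sort_by(fn {_id, p} -> p end, :desc)
    |> Enum.take(k)
    |> Map.new()
  end

  # :top_p - keep the most probable tokens until their cumulative probability
  # reaches p; the token that crosses the threshold is kept (assumed convention).
  def top_p(probs, p) do
    probs
    |> Enum.sort_by(fn {_id, prob} -> prob end, :desc)
    |> Enum.reduce_while({[], 0.0}, fn {id, prob}, {kept, total} ->
      if total >= p do
        {:halt, {kept, total}}
      else
        {:cont, {[{id, prob} | kept], total + prob}}
      end
    end)
    |> elem(0)
    |> Map.new()
  end

  # :multinomial_sampling - draw one token in proportion to its probability.
  # The draw is scaled by the total, so filtered (unnormalized) maps also work.
  def sample(probs) do
    sorted = Enum.sort_by(probs, fn {_id, p} -> p end, :desc)
    total = Enum.reduce(sorted, 0.0, fn {_id, p}, acc -> acc + p end)
    threshold = :rand.uniform() * total

    Enum.reduce_while(sorted, 0.0, fn {id, p}, acc ->
      acc = acc + p
      if acc >= threshold, do: {:halt, id}, else: {:cont, acc}
    end)
  end

  # :contrastive_search - trade model confidence against the degeneration penalty.
  def contrastive_score(prob, max_similarity, alpha) do
    (1 - alpha) * prob - alpha * max_similarity
  end

  # Candidates are {token_id, prob, max_similarity} tuples for the top-k tokens.
  def contrastive_pick(candidates, alpha) do
    {token_id, _prob, _sim} =
      Enum.max_by(candidates, fn {_id, prob, sim} -> contrastive_score(prob, sim, alpha) end)

    token_id
  end
end
```

With `%{1 => 0.5, 2 => 0.3, 3 => 0.15, 4 => 0.05}`, both `top_k(probs, 2)` and `top_p(probs, 0.7)` leave tokens 1 and 2 as the only candidates. With `alpha` 0.6, `contrastive_pick([{1, 0.6, 0.9}, {2, 0.3, 0.1}], 0.6)` picks token 2 even though greedy search would pick token 1. Calling `:rand.seed/2` first makes `sample/1` reproducible.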

Options controlling generated tokens

  • :decoder_start_token_id - the id of the initial token when generating from scratch with encoder-decoder models

  • :forced_bos_token_id - the id of the token to force as the first generated token

  • :forced_eos_token_id - the id of the token to force as the last generated token when :max_length is reached

  • :forced_token_ids - a list of {index, token_id} pairs forcing token_id to appear at index in the generated sequence. Defaults to []

  • :suppressed_token_ids - a list of token ids to suppress during generation. Defaults to []

  • :no_repeat_ngram_length - when set, n-grams of the given length can occur only once in the generated sequence

  • :temperature - enables exponential scaling of the output probability distribution. The temperature value effectively determines the randomness of the predicted tokens: values smaller than 1.0 decrease the randomness, while larger values increase it. Note that this is only relevant for generation strategies that sample from the output probability distribution, such as :multinomial_sampling
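The options in this group all act on the candidates for the next token. This sketch illustrates three of them with a hypothetical TokenRulesSketch module that represents logits as a map of token id => logit. It assumes that `index` follows the same position convention as :forced_token_ids, that suppressing a token is equivalent to giving it zero probability, and that temperature uses the usual softmax(logits / temperature) formulation. Dividing the logits by the temperature is equivalent to raising each probability to the power 1/temperature and renormalizing, which is the exponential scaling described above. Bumblebee applies these rules to Nx tensors, so this is an illustration rather than its implementation.

```elixir
defmodule TokenRulesSketch do
  @moduledoc """
  Hypothetical, plain-Elixir illustration of several options above for a
  single generation step. `logits` maps token id => logit; none of these
  functions are part of Bumblebee's API.
  """

  # :forced_token_ids / :suppressed_token_ids for the token at `index`.
  def restrict(logits, index, opts) do
    forced = Keyword.get(opts, :forced_token_ids, [])
    suppressed = Keyword.get(opts, :suppressed_token_ids, [])

    case List.keyfind(forced, index, 0) do
      # A forced token at this index becomes the only candidate.
      {^index, token_id} -> Map.take(logits, [token_id])
      # Otherwise suppressed tokens are removed (equivalent to zero probability).
      nil -> Map.drop(logits, suppressed)
    end
  end

  # :no_repeat_ngram_length - tokens that would complete an n-gram already
  # present in `sequence` (the last n - 1 tokens form the prefix to match).
  def banned_tokens(sequence, n) when is_integer(n) and n >= 1 do
    prefix = Enum.take(sequence, -(n - 1))

    sequence
    |> Enum.chunk_every(n, 1, :discard)
    |> Enum.filter(fn ngram -> Enum.take(ngram, n - 1) == prefix end)
    |> Enum.map(&List.last/1)
    |> Enum.uniq()
  end

  # :temperature - softmax over the logits divided by the temperature.
  def probabilities(logits, temperature \\ 1.0) when temperature > 0 do
    scaled = Map.new(logits, fn {id, logit} -> {id, logit / temperature} end)
    peak = scaled |> Map.values() |> Enum.max()
    # Subtracting the maximum keeps :math.exp/1 from overflowing.
    exps = Map.new(scaled, fn {id, s} -> {id, :math.exp(s - peak)} end)
    total = exps |> Map.values() |> Enum.sum()
    Map.new(exps, fn {id, e} -> {id, e / total} end)
  end
end
```

For instance, after generating [5, 6, 7, 5, 6] with an n-gram length of 3, `banned_tokens/2` returns [7], since emitting 7 would repeat the trigram 5, 6, 7. For logits 2.0, 1.0 and 0.0, the top token's probability grows as the temperature drops below 1.0 and shrinks as it rises above it.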

Special tokens used during generation

  • :bos_token_id - the id of the beginning-of-sequence token

  • :eos_token_id - the id of the end-of-sequence token. This can also be a list, in case multiple tokens should be recognized as EOS

  • :pad_token_id - the id of the padding token
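Because :eos_token_id may be either a single id or a list, code that checks for the end of a sequence has to handle both shapes. A minimal sketch, using a hypothetical EosSketch module and assuming the common convention (not stated on this page) that tokens after the first EOS are discarded and the sequence is filled with :pad_token_id up to a fixed length:

```elixir
defmodule EosSketch do
  @moduledoc """
  Hypothetical illustration of :eos_token_id (a single id or a list of ids)
  and :pad_token_id at the end of generation.
  """

  # A token ends the sequence if it matches the EOS id, or any id in the list.
  def eos?(token_id, eos_token_id) when is_list(eos_token_id), do: token_id in eos_token_id
  def eos?(token_id, eos_token_id), do: token_id == eos_token_id

  # Keeps tokens up to and including the first EOS, then pads to `target_length`.
  def finish(tokens, eos_token_id, pad_token_id, target_length) do
    {before_eos, rest} = Enum.split_while(tokens, &(not eos?(&1, eos_token_id)))
    kept = before_eos ++ Enum.take(rest, 1)
    kept ++ List.duplicate(pad_token_id, max(target_length - length(kept), 0))
  end
end
```

For example, `finish([5, 6, 2, 9], 2, 0, 5)` stops at the EOS token 2 and returns [5, 6, 2, 0, 0].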

Types

@type t() :: %Bumblebee.Text.GenerationConfig{
  bos_token_id: term(),
  decoder_start_token_id: term(),
  eos_token_id: term(),
  extra_config: term(),
  forced_bos_token_id: term(),
  forced_eos_token_id: term(),
  forced_token_ids: term(),
  max_length: term(),
  max_new_tokens: term(),
  min_length: term(),
  min_new_tokens: term(),
  no_repeat_ngram_length: term(),
  pad_token_id: term(),
  strategy: term(),
  suppressed_token_ids: term(),
  temperature: term()
}