Configuration for speculative decoding.
Summary
Functions
_maybe_override_draft_max_model_len: Determine the maximum sequence length for the draft model.
_validate_suffix_decoding: Python method SpeculativeConfig._validate_suffix_decoding.
_verify_and_get_draft_tp: Verify and adjust the tensor parallel size for a draft model.
_verify_args: Python method SpeculativeConfig._verify_args.
compute_hash: Provide a hash that identifies the configs that affect the computation graph.
create_draft_parallel_config: Create a parallel config for use by the draft worker.
hf_config_override: Python method SpeculativeConfig.hf_config_override.
new: Constructs SpeculativeConfig.
use_eagle: Python method SpeculativeConfig.use_eagle.
Functions
@spec _maybe_override_draft_max_model_len(SnakeBridge.Ref.t(), term(), integer(), integer(), keyword()) :: {:ok, integer()} | {:error, Snakepit.Error.t()}
Determine the maximum sequence length for the draft model. This is usually draft_max_model_len, but it may be target_max_model_len if that is smaller than draft_max_model_len, or speculative_max_model_len if that is specified.
This ensures that sequences do not exceed the capacity of either the draft model or the target model.
speculative_max_model_len is mainly used for testing that sequences can skip speculation.
Parameters
speculative_max_model_len (term())
draft_max_model_len (integer())
target_max_model_len (integer())
Returns
integer()
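The selection rule above can be sketched in plain Elixir. DraftMaxLenSketch and choose/3 are hypothetical names; the sketch returns a bare integer rather than the {:ok, integer()} tuple the wrapper returns, and it assumes that an explicit speculative_max_model_len larger than either model's limit is rejected, which this page does not state.

```elixir
defmodule DraftMaxLenSketch do
  # Hypothetical sketch of the rule described for
  # _maybe_override_draft_max_model_len; it does not call SnakeBridge.

  # No override: use whichever model has the smaller limit.
  def choose(nil, draft_max, target_max)
      when is_integer(draft_max) and is_integer(target_max),
      do: min(draft_max, target_max)

  # Override given: assumed valid only if it fits both models.
  def choose(spec_max, draft_max, target_max)
      when is_integer(spec_max) and spec_max <= draft_max and spec_max <= target_max,
      do: spec_max

  # Assumption: any other override is rejected.
  def choose(spec_max, _draft_max, _target_max) do
    raise ArgumentError,
          "speculative_max_model_len #{inspect(spec_max)} exceeds a model limit"
  end
end
```

For example, choose(nil, 4096, 2048) falls back to the smaller limit, 2048.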
@spec _validate_suffix_decoding(SnakeBridge.Ref.t(), keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method SpeculativeConfig._validate_suffix_decoding.
Returns
term()
@spec _verify_and_get_draft_tp(SnakeBridge.Ref.t(), term(), term(), term(), keyword()) :: {:ok, integer()} | {:error, Snakepit.Error.t()}
Verifies and adjusts the draft model's tensor parallel size, as specified by speculative_draft_tensor_parallel_size.
Parameters
target_parallel_config (term())
speculative_draft_tensor_parallel_size (term())
draft_hf_config (term())
Returns
integer()
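A minimal sketch of the tensor parallel rule, assuming (this page does not say so) that an unset draft size inherits the target's size and that an explicit size must be either 1 or equal to the target's. DraftTpSketch.resolve/2 is hypothetical, takes plain integers instead of config objects, and ignores draft_hf_config, which the real method also receives.

```elixir
defmodule DraftTpSketch do
  # Hypothetical sketch of _verify_and_get_draft_tp's rule. Assumptions:
  # an unset draft size inherits the target's tensor parallel size, and an
  # explicit size must be either 1 or the target's size.

  def resolve(nil, target_tp) when is_integer(target_tp) and target_tp > 0,
    do: {:ok, target_tp}

  def resolve(draft_tp, target_tp)
      when is_integer(draft_tp) and (draft_tp === 1 or draft_tp === target_tp),
      do: {:ok, draft_tp}

  # Anything else is reported as an error tuple, mirroring the wrapper's
  # {:ok, _} | {:error, _} convention.
  def resolve(draft_tp, target_tp),
    do: {:error, "draft TP size #{inspect(draft_tp)} must be 1 or #{inspect(target_tp)}"}
end
```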
@spec _verify_args(SnakeBridge.Ref.t(), keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method SpeculativeConfig._verify_args.
Returns
term()
@spec code_revision(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec compute_hash(SnakeBridge.Ref.t(), keyword()) :: {:ok, String.t()} | {:error, Snakepit.Error.t()}
Provide a hash that uniquely identifies all the configs that affect the structure of the computation graph, from input IDs/embeddings to the final hidden states, excluding anything before the input IDs/embeddings and after the final hidden states.
WARNING: Whenever a new field is added to this config, ensure that it is included in the factors list if it affects the computation graph.
Returns
String.t()
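The warning describes a common pattern: gather every graph-affecting field into a factors list and hash a deterministic rendering of it. GraphHashSketch is hypothetical, the factor names in the test are illustrative, and SHA-256 over inspect/2 output is an assumption rather than the algorithm SpeculativeConfig uses.

```elixir
defmodule GraphHashSketch do
  # Hypothetical sketch of the compute_hash pattern: render the list of
  # graph-affecting factors deterministically, then hash it. SHA-256 over
  # inspect/2 output is an assumption, not SpeculativeConfig's algorithm.

  def compute_hash(factors) when is_list(factors) do
    # Disable truncation so long lists and strings are rendered in full.
    rendered = inspect(factors, limit: :infinity, printable_limit: :infinity)

    :crypto.hash(:sha256, rendered)
    |> Base.encode16(case: :lower)
  end
end
```

Disabling the limits matters: inspect/2 truncates long collections and strings by default, and a truncated rendering could make two different configs hash to the same value.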
@spec create_draft_parallel_config(SnakeBridge.Ref.t(), term(), integer(), keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Create a parallel config for use by the draft worker.
It is mostly a copy of the target parallel config, except for the tensor parallel size (tp_size).
Parameters
target_parallel_config (term())
speculative_draft_tensor_parallel_size (integer())
Returns
term()
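Copying the target config and overriding a single field maps directly onto Elixir's map update syntax. The sketch below is hypothetical: it works on a plain map, and its keys are illustrative stand-ins for the Python parallel config's fields.

```elixir
defmodule DraftParallelSketch do
  # Hypothetical sketch: the draft worker's parallel config is a copy of the
  # target's with only the tensor parallel size replaced. Keys are illustrative.

  def create_draft_parallel_config(%{} = target, draft_tp)
      when is_integer(draft_tp) and draft_tp > 0 do
    # Update syntax copies every other key unchanged and raises KeyError
    # if :tensor_parallel_size is absent.
    %{target | tensor_parallel_size: draft_tp}
  end
end
```

Because the %{map | key: value} form raises KeyError for a missing key, a target config without a tensor parallel size fails loudly instead of silently gaining a new key.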
@spec disable_by_batch_size(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec disable_padded_drafter_batch(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec draft_model_config(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec draft_parallel_config(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec draft_tensor_parallel_size(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec enforce_eager(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec hf_config_override(SnakeBridge.Ref.t(), term(), keyword()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
Python method SpeculativeConfig.hf_config_override.
Parameters
hf_config(term())
Returns
term()
@spec max_model_len(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec method(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec model(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec new(term(), term(), term(), keyword()) :: {:ok, SnakeBridge.Ref.t()} | {:error, Snakepit.Error.t()}
Constructs SpeculativeConfig.
Parameters
dataclass_self__ (term())
args (term())
kwargs (term())
@spec num_speculative_tokens(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec prompt_lookup_max(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec prompt_lookup_min(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec quantization(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec revision(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec speculative_token_tree(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec suffix_decoding_max_cached_requests(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec suffix_decoding_max_spec_factor(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec suffix_decoding_max_tree_depth(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec suffix_decoding_min_token_prob(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec target_model_config(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec target_parallel_config(SnakeBridge.Ref.t()) :: {:ok, term()} | {:error, Snakepit.Error.t()}
@spec use_eagle(SnakeBridge.Ref.t(), keyword()) :: {:ok, boolean()} | {:error, Snakepit.Error.t()}
Python method SpeculativeConfig.use_eagle.
Returns
boolean()
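use_eagle answers a yes/no question about the configured method. The sketch below assumes, without confirmation from this page, that the EAGLE-style methods are "eagle", "eagle3", and "mtp". UseEagleSketch is hypothetical and returns a bare boolean instead of {:ok, boolean()}.

```elixir
defmodule UseEagleSketch do
  # Hypothetical sketch: the method names in this list are an assumption,
  # not taken from SpeculativeConfig itself.
  @eagle_methods ["eagle", "eagle3", "mtp"]

  # True when the configured method is assumed to be EAGLE-style.
  def use_eagle?(method), do: method in @eagle_methods
end
```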