Bumblebee.Text.Generation behaviour (Bumblebee v0.6.0)
An interface for language models supporting sequence generation.
Summary
Callbacks
Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.
Initializes an opaque cache input for iterative inference.
Traverses all batched tensors in the cache.
Functions
Builds a numerical definition that generates sequences of tokens using the given language model.
Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.
Initializes an opaque cache input for iterative inference.
Calls fun for every batched tensor in the cache.
Types
@type cache() :: Nx.Tensor.t() | Nx.Container.t()
Callbacks
@callback extra_config_module(spec :: Bumblebee.ModelSpec.t()) :: module()
Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.
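As an illustration, a model with additional generation attributes might implement this callback as follows (MyModel.ExtraGenerationConfig is a hypothetical module name used only for this sketch):

@impl Bumblebee.Text.Generation
def extra_config_module(_spec), do: MyModel.ExtraGenerationConfig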
@callback init_cache( spec :: Bumblebee.ModelSpec.t(), batch_size :: pos_integer(), max_length :: pos_integer(), inputs :: map() ) :: cache()
Initializes an opaque cache input for iterative inference.
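For illustration, here is a hedged, heavily simplified sketch of what such a cache could look like for a single attention layer: pre-allocated key/value buffers plus the current offset. The spec fields used below (hidden_size, num_attention_heads) are assumptions about the particular model spec; real models typically build their cache through shared decoder helpers.

def init_cache(spec, batch_size, max_length, _inputs) do
  head_size = div(spec.hidden_size, spec.num_attention_heads)

  %{
    # key/value buffers for all positions, filled in as tokens are generated
    key: Nx.broadcast(0.0, {batch_size, max_length, spec.num_attention_heads, head_size}),
    value: Nx.broadcast(0.0, {batch_size, max_length, spec.num_attention_heads, head_size}),
    # number of positions already cached
    offset: Nx.tensor(0)
  }
end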
@callback traverse_cache( spec :: Bumblebee.ModelSpec.t(), cache(), (Nx.Tensor.t() -> Nx.Tensor.t()) ) :: cache()
Traverses all batched tensors in the cache.
This function is used when the cache needs to be inflated or deflated for a different batch size.
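A hedged sketch of one way a model could implement this callback, assuming its cache is a (possibly nested) plain map or tuple of tensors; real models usually delegate to their shared decoder helpers instead:

@impl Bumblebee.Text.Generation
def traverse_cache(_spec, cache, fun) do
  deep_map(cache, fun)
end

# Applies fun to every tensor leaf, preserving the container structure
defp deep_map(%Nx.Tensor{} = tensor, fun), do: fun.(tensor)

defp deep_map(%{} = map, fun) do
  Map.new(map, fn {key, value} -> {key, deep_map(value, fun)} end)
end

defp deep_map(tuple, fun) when is_tuple(tuple) do
  tuple |> Tuple.to_list() |> Enum.map(&deep_map(&1, fun)) |> List.to_tuple()
end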
Functions
@spec build_generate( Axon.t(), Bumblebee.ModelSpec.t(), Bumblebee.Text.GenerationConfig.t(), keyword() ) :: (params :: map(), inputs :: map() -> %{token_ids: Nx.Tensor.t(), length: Nx.Tensor.t()} | (ignored :: Nx.Tensor.t()))
Builds a numerical definition that generates sequences of tokens using the given language model.
The model should be either a decoder or an encoder-decoder. The tokens are generated by iterative inference using the decoder (autoregression), until the termination criteria are met.
In the case of encoder-decoder models, the encoder is run only once and its intermediate state is reused during all iterations.
The generation is controlled by a number of options given as %Bumblebee.Text.GenerationConfig{}; see the corresponding docs for more details.
Returns a defn JIT-compatible anonymous function, which expects the model params as the first argument and the inputs map as the second argument. Note that the inputs map should additionally include a "seed" tensor, with one value per input in the batch.
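For example, a minimal sketch of generating text with a decoder-only model (the "openai-community/gpt2" repository id and the seed value are illustrative assumptions):

{:ok, model_info} = Bumblebee.load_model({:hf, "openai-community/gpt2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai-community/gpt2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai-community/gpt2"})

generate_fun =
  Bumblebee.Text.Generation.build_generate(
    model_info.model,
    model_info.spec,
    generation_config
  )

inputs = Bumblebee.apply_tokenizer(tokenizer, "The capital of France is")
# The inputs map must additionally include a "seed" tensor, one value per batch entry
inputs = Map.put(inputs, "seed", Nx.tensor([0]))

# In practice you would JIT compile generate_fun (e.g. via Nx.Defn.jit/2 with EXLA)
# before calling it with the params and inputs
%{token_ids: token_ids} = generate_fun.(model_info.params, inputs)
Bumblebee.Tokenizer.decode(tokenizer, token_ids)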
Streaming
This function sets up a hook that is invoked after every generated token. The hook receives a map with the following attributes:
:token_id - the newly generated token
:finished? - a boolean indicating if the sequence is finished
:length - the current length of the generated sequence. Once the sequence is finished, the length does not increase
Each of the attributes is a tensor with a leading batch dimension.
When streaming you may not care about the output result, in which case you can enable :ignore_output to reduce the output size.
Options
:logits_processors - a list of numerical functions to modify predicted scores at each generation step. The functions are applied in order, after all default processors (see the sketch after this list)
:ignore_output - if true, returns a dummy tensor that should be ignored. This is useful when you consume the generated tokens in a stream fashion via the hook, so that the full output does not need to be transferred unnecessarily after the computation. Defaults to false
@spec extra_config_module(Bumblebee.ModelSpec.t()) :: module() | nil
Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.
@spec init_cache(Bumblebee.ModelSpec.t(), pos_integer(), pos_integer(), map()) :: cache()
Initializes an opaque cache input for iterative inference.
@spec traverse_cache( Bumblebee.ModelSpec.t(), cache(), (Nx.Tensor.t() -> Nx.Tensor.t()) ) :: cache()
Calls fun for every batched tensor in the cache.