Bumblebee.Text.Generation behaviour (Bumblebee v0.4.2)
An interface for language models supporting sequence generation.
Summary
Callbacks
Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.
Initializes an opaque cache input for iterative inference.
Traverses all batched tensors in the cache.
Functions
Builds a numerical definition that generates sequences of tokens using the given language model.
Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.
Initializes an opaque cache input for iterative inference.
Calls fun for every batched tensor in the cache.
Types
@type cache() :: Nx.Tensor.t() | Nx.Container.t()
Callbacks
@callback extra_config_module(spec :: Bumblebee.ModelSpec.t()) :: module()
Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.
@callback init_cache(
  spec :: Bumblebee.ModelSpec.t(),
  batch_size :: pos_integer(),
  max_length :: pos_integer(),
  inputs :: map()
) :: cache()
Initializes an opaque cache input for iterative inference.
@callback traverse_cache(
  spec :: Bumblebee.ModelSpec.t(),
  cache(),
  (Nx.Tensor.t() -> Nx.Tensor.t())
) :: cache()
Traverses all batched tensors in the cache.
This function is used when the cache needs to be inflated or deflated for a different batch size.
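For instance, deflating a cache to a smaller batch size could be sketched as follows. This is an illustrative fragment, not a complete program: spec and cache are assumed to come from an earlier init_cache call, and Nx.slice_along_axis/4 is used to keep only the first batch entry.

```elixir
# Sketch: shrink every batched tensor in the cache down to batch size 1.
# `spec` and `cache` are placeholders from a prior init_cache/4 call.
smaller_cache =
  Bumblebee.Text.Generation.traverse_cache(spec, cache, fn tensor ->
    # Batched tensors carry the batch dimension first, so slice along axis 0
    Nx.slice_along_axis(tensor, 0, 1, axis: 0)
  end)
```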
Functions
@spec build_generate(
  Axon.t(),
  Bumblebee.ModelSpec.t(),
  Bumblebee.Text.GenerationConfig.t(),
  keyword()
) :: (params :: map(), inputs :: map() -> Nx.t())
Builds a numerical definition that generates sequences of tokens using the given language model.
The model should be either a decoder or an encoder-decoder. The tokens are generated by iterative inference using the decoder (autoregression), until the termination criteria are met.
In case of encoder-decoder models, the corresponding encoder is run only once and the intermediate state is reused during all iterations.
The generation is controlled by a number of options given as %Bumblebee.Text.GenerationConfig{}; see the corresponding docs for more details.
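As a sketch of typical usage (assuming network access to download the gpt2 checkpoint from Hugging Face; the exact shape of the returned result may differ between Bumblebee versions):

```elixir
{:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})

# Build the numerical generation function for this model
generate =
  Bumblebee.Text.Generation.build_generate(
    model_info.model,
    model_info.spec,
    generation_config
  )

# Tokenize a prompt and run autoregressive generation
inputs = Bumblebee.apply_tokenizer(tokenizer, "Elixir is")
token_ids = generate.(model_info.params, inputs)
```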
Options
:seed - random seed to use when sampling. By default the current timestamp is used.

:logits_processors - a list of numerical functions to modify predicted scores at each generation step. The functions are applied in order, after all default processors.
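A custom logits processor transforms the score tensor at each generation step. As a minimal sketch, a temperature-style scaling processor could look like the following; note that the two-argument (logits, context) shape is an assumption for illustration, not a documented contract:

```elixir
Mix.install([{:nx, "~> 0.6"}])

# Hypothetical processor; the (logits, context) signature is assumed here.
scale_logits = fn logits, _context ->
  # Divide every predicted score by a fixed temperature-like constant
  Nx.divide(logits, 2.0)
end

# It would then be passed to build_generate via the option:
#   logits_processors: [scale_logits]
```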
@spec extra_config_module(Bumblebee.ModelSpec.t()) :: module() | nil
Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.
@spec init_cache(Bumblebee.ModelSpec.t(), pos_integer(), pos_integer(), map()) :: cache()
Initializes an opaque cache input for iterative inference.
@spec traverse_cache(
  Bumblebee.ModelSpec.t(),
  cache(),
  (Nx.Tensor.t() -> Nx.Tensor.t())
) :: cache()
Calls fun for every batched tensor in the cache.
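In practice, rather than calling build_generate directly, text generation is usually run through the higher-level Bumblebee.Text.generation, which wraps the generation function in an Nx.Serving. A sketch, again assuming the gpt2 checkpoint can be downloaded:

```elixir
{:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})

# Build a serving that tokenizes, generates, and decodes in one step
serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)

Nx.Serving.run(serving, "Elixir is")
```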