Bumblebee.Text.Generation behaviour (Bumblebee v0.4.2)

An interface for language models supporting sequence generation.

Summary

Callbacks

Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.

Initializes an opaque cache input for iterative inference.

Traverses all batched tensors in the cache.

Functions

Builds a numerical definition that generates sequences of tokens using the given language model.

Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.

Initializes an opaque cache input for iterative inference.

Calls fun for every batched tensor in the cache.

Types

cache()

An opaque cache structure used during iterative inference (see init_cache/4).

Callbacks

extra_config_module(spec)

(optional)
@callback extra_config_module(spec :: Bumblebee.ModelSpec.t()) :: module()

Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.

init_cache(spec, batch_size, max_length, inputs)

@callback init_cache(
  spec :: Bumblebee.ModelSpec.t(),
  batch_size :: pos_integer(),
  max_length :: pos_integer(),
  inputs :: map()
) :: cache()

Initializes an opaque cache input for iterative inference.

traverse_cache(spec, cache, function)

@callback traverse_cache(
  spec :: Bumblebee.ModelSpec.t(),
  cache(),
  (Nx.Tensor.t() -> Nx.Tensor.t())
) :: cache()

Traverses all batched tensors in the cache.

This function is used when the cache needs to be inflated or deflated for a different batch size.
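As a hedged sketch (assuming the batch dimension is the first axis of every cached tensor, which the callback itself does not guarantee), deflating a cache to a smaller batch size could look like this. The module and function names here are hypothetical:

```elixir
defmodule CacheResize do
  # Hypothetical helper: shrink a cache built for a larger (padded) batch
  # down to `batch_size` by slicing every batched tensor along axis 0.
  def deflate(spec, cache, batch_size) do
    Bumblebee.Text.Generation.traverse_cache(spec, cache, fn tensor ->
      # Assumption: the batch dimension is axis 0 of each cached tensor.
      Nx.slice_along_axis(tensor, 0, batch_size, axis: 0)
    end)
  end
end
```

Inflating for a larger batch would use the same traversal with a padding function instead of a slice.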

Functions

build_generate(model, spec, config, opts \\ [])

@spec build_generate(
  Axon.t(),
  Bumblebee.ModelSpec.t(),
  Bumblebee.Text.GenerationConfig.t(),
  keyword()
) :: (params :: map(), inputs :: map() -> Nx.t())

Builds a numerical definition that generates sequences of tokens using the given language model.

The model should be either a decoder or an encoder-decoder. The tokens are generated by iterative inference using the decoder (autoregression), until the termination criteria are met.

In case of encoder-decoder models, the corresponding encoder is run only once and the intermediate state is reused during all iterations.

The generation is controlled by a number of options given as %Bumblebee.Text.GenerationConfig{}; see the corresponding docs for more details.

Options

  • :seed - random seed to use when sampling. By default the current timestamp is used

  • :logits_processors - a list of numerical functions to modify predicted scores at each generation step. The functions are applied in order, after all default processors
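A minimal usage sketch, assuming a GPT-2 checkpoint from the Hugging Face Hub (the repository name, prompt, and seed are illustrative assumptions, not part of this module's API):

```elixir
# Load the model, tokenizer, and generation config (hypothetical repo name).
{:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})

# Build the numerical generation definition.
generate_fun =
  Bumblebee.Text.Generation.build_generate(
    model_info.model,
    model_info.spec,
    generation_config,
    seed: 0
  )

# The returned function takes params and tokenized inputs and
# returns a tensor of generated token ids.
inputs = Bumblebee.apply_tokenizer(tokenizer, "Hello")
token_ids = generate_fun.(model_info.params, inputs)
```

For most applications the higher-level serving in Bumblebee.Text (text generation) wraps this function; building it directly is useful when you need custom batching or logits processors.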

extra_config_module(spec)

@spec extra_config_module(Bumblebee.ModelSpec.t()) :: module() | nil

Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.

init_cache(spec, batch_size, max_length, inputs)

@spec init_cache(Bumblebee.ModelSpec.t(), pos_integer(), pos_integer(), map()) ::
  cache()

Initializes an opaque cache input for iterative inference.
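As a hedged sketch of how the cache feeds into iterative decoding (the "cache" input key and the sizes below are assumptions for illustration):

```elixir
# Hypothetical sketch: create a cache for batch size 1 and a maximum
# sequence length of 128, given already-tokenized model inputs.
cache =
  Bumblebee.Text.Generation.init_cache(model_info.spec, 1, 128, inputs)

# Assumption: the model accepts the cache under the "cache" input key,
# so each decoding step can reuse state from previous steps.
inputs = Map.put(inputs, "cache", cache)
```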

traverse_cache(spec, cache, fun)

@spec traverse_cache(
  Bumblebee.ModelSpec.t(),
  cache(),
  (Nx.Tensor.t() -> Nx.Tensor.t())
) :: cache()

Calls fun for every batched tensor in the cache.