Bumblebee.Text.Generation behaviour (Bumblebee v0.6.0)

An interface for language models supporting sequence generation.

Summary

Callbacks

extra_config_module(spec)
Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.

init_cache(spec, batch_size, max_length, inputs)
Initializes an opaque cache input for iterative inference.

traverse_cache(spec, cache, function)
Traverses all batched tensors in the cache.

Functions

build_generate(model, spec, config, opts \\ [])
Builds a numerical definition that generates sequences of tokens using the given language model.

extra_config_module(spec)
Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.

init_cache(spec, batch_size, max_length, inputs)
Initializes an opaque cache input for iterative inference.

traverse_cache(spec, cache, fun)
Calls fun for every batched tensor in the cache.

Types

cache()

The opaque cache container used for iterative inference; see init_cache/4.

Callbacks

extra_config_module(spec) (optional)

@callback extra_config_module(spec :: Bumblebee.ModelSpec.t()) :: module()

Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.

init_cache(spec, batch_size, max_length, inputs)

@callback init_cache(
  spec :: Bumblebee.ModelSpec.t(),
  batch_size :: pos_integer(),
  max_length :: pos_integer(),
  inputs :: map()
) :: cache()

Initializes an opaque cache input for iterative inference.

traverse_cache(spec, cache, function)

@callback traverse_cache(
  spec :: Bumblebee.ModelSpec.t(),
  cache(),
  (Nx.Tensor.t() -> Nx.Tensor.t())
) :: cache()

Traverses all batched tensors in the cache.

This function is used when the cache needs to be inflated or deflated for a different batch size.

Functions

build_generate(model, spec, config, opts \\ [])

@spec build_generate(
  Axon.t(),
  Bumblebee.ModelSpec.t(),
  Bumblebee.Text.GenerationConfig.t(),
  keyword()
) :: (params :: map(), inputs :: map() ->
        %{token_ids: Nx.Tensor.t(), length: Nx.Tensor.t()}
        | (ignored :: Nx.Tensor.t()))

Builds a numerical definition that generates sequences of tokens using the given language model.

The model should be either a decoder or an encoder-decoder. The tokens are generated by iterative inference using the decoder (autoregression), until the termination criteria are met.

For encoder-decoder models, the encoder is run only once and its intermediate state is reused across all iterations.

Generation is controlled by a number of options given as %Bumblebee.Text.GenerationConfig{}; see the corresponding docs for more details.

Returns a defn JIT-compatible anonymous function, which expects the model params as the first argument and the inputs map as the second argument. Note that the inputs map should additionally include a "seed" tensor with one value per input in the batch.
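
As a rough, untested sketch of wiring this up by hand (normally the higher-level Bumblebee.Text.generation/4 serving takes care of these steps), assuming a decoder-only checkpoint such as "gpt2":

    {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model({:hf, "gpt2"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
    {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})

    # Build the generation defn, then JIT-compile and invoke it.
    generate_fun = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)

    inputs =
      tokenizer
      |> Bumblebee.apply_tokenizer(["Elixir is"])
      # One seed value per input in the batch.
      |> Map.put("seed", Nx.tensor([0]))

    %{token_ids: token_ids} = Nx.Defn.jit(generate_fun).(params, inputs)

    Bumblebee.Tokenizer.decode(tokenizer, token_ids)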

Streaming

This function sets up a hook that is invoked after every generated token. The hook receives a map with the following attributes:

  • :token_id - the newly generated token

  • :finished? - a boolean indicating if the sequence is finished

  • :length - the current length of the generated sequence. Once the sequence is finished, the length does not increase

Each of the attributes is a tensor with a leading batch dimension.

When streaming you may not care about the output result, in which case you can enable :ignore_output to reduce the output size (see the sketch after the options below).

Options

  • :logits_processors - a list of numerical functions to modify predicted scores at each generation step. The functions are applied in order, after all default processors

  • :ignore_output - if true, returns a dummy tensor that should be ignored. This is useful when you consume the generated tokens in a stream fashion via the hook, so that the full output does not need to be transferred unnecessarily after the computation. Defaults to false
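
As a sketch of consuming the stream, reusing the variables from the sketch above and assuming the hook is exposed under the :token name (the hook name is not stated here, so treat it as an assumption and check the module source). The handler is registered via the :hooks compiler option:

    # Print each newly generated token as it is produced, instead of
    # collecting the full output (hence ignore_output: true).
    generate_fun =
      Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
        ignore_output: true
      )

    hooks = %{
      # Assumed hook name; the map carries :token_id, :finished? and :length,
      # each with a leading batch dimension.
      token: fn %{token_id: token_id, finished?: finished?} ->
        IO.inspect({Nx.to_flat_list(token_id), Nx.to_flat_list(finished?)})
      end
    }

    Nx.Defn.jit(generate_fun, hooks: hooks).(params, inputs)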

extra_config_module(spec)

@spec extra_config_module(Bumblebee.ModelSpec.t()) :: module() | nil

Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig, or nil if the model does not define one.

init_cache(spec, batch_size, max_length, inputs)

@spec init_cache(Bumblebee.ModelSpec.t(), pos_integer(), pos_integer(), map()) ::
  cache()

Initializes an opaque cache input for iterative inference.
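
For example (a sketch; the batch size and maximum length must match the ones used for generation, and inputs is the tokenizer output):

    cache = Bumblebee.Text.Generation.init_cache(spec, 1, 100, inputs)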

traverse_cache(spec, cache, fun)

@spec traverse_cache(
  Bumblebee.ModelSpec.t(),
  cache(),
  (Nx.Tensor.t() -> Nx.Tensor.t())
) :: cache()

Calls fun for every batched tensor in the cache.
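
For example, a sketch of deflating a cache to a smaller batch, assuming the batch is the leading axis of every batched tensor:

    smaller_cache =
      Bumblebee.Text.Generation.traverse_cache(spec, cache, fn tensor ->
        # Keep only the first entry of the batch.
        Nx.slice_along_axis(tensor, 0, 1, axis: 0)
      end)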