Bumblebee.Text.Generation behaviour (Bumblebee v0.7.0)

An interface for language models supporting sequence generation.

Summary

Callbacks

  • extra_config_module(spec) - Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.

  • init_cache(spec, batch_size, max_length, inputs) - Initializes an opaque cache input for iterative inference.

  • traverse_cache(spec, cache, function) - Traverses all batched tensors in the cache.

Functions

  • build_generate(model, spec, config, opts \\ []) - Builds a numerical definition that generates sequences of tokens using the given language model.

  • extra_config_module(spec) - Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.

  • init_cache(spec, batch_size, max_length, inputs) - Initializes an opaque cache input for iterative inference.

  • traverse_cache(spec, cache, fun) - Calls fun for every batched tensor in the cache.

Types

cache()

@type cache() :: Nx.Tensor.t() | Nx.Container.t()

Callbacks

extra_config_module(spec)

(optional)
@callback extra_config_module(spec :: Bumblebee.ModelSpec.t()) :: module()

Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig.

init_cache(spec, batch_size, max_length, inputs)

@callback init_cache(
  spec :: Bumblebee.ModelSpec.t(),
  batch_size :: pos_integer(),
  max_length :: pos_integer(),
  inputs :: map()
) :: cache()

Initializes an opaque cache input for iterative inference.
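To make "opaque cache" more concrete, here is a minimal sketch of what an implementation of this callback might return for a toy decoder. Everything in it is hypothetical: the module name MyToyDecoder, the fixed hyperparameters, and the cache layout (one preallocated key/value buffer per block plus a scalar offset). Real Bumblebee models derive their caches from the spec and may use a different structure. The sketch depends only on Nx and leaves out @behaviour Bumblebee.Text.Generation so that it runs without Bumblebee.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule MyToyDecoder do
  # Hypothetical model module. A real model would also declare
  # `@behaviour Bumblebee.Text.Generation`; it is omitted here so the
  # sketch only needs Nx.

  # Hypothetical hyperparameters standing in for fields of a model spec.
  @num_blocks 2
  @num_heads 4
  @head_size 8

  def init_cache(_spec, batch_size, max_length, _inputs) do
    # One key/value buffer per decoder block, preallocated for max_length
    # positions so the shapes stay fixed across generation steps.
    shape = {batch_size, max_length, @num_heads, @head_size}

    blocks =
      1..@num_blocks
      |> Enum.map(fn _ ->
        %{key: Nx.broadcast(0.0, shape), value: Nx.broadcast(0.0, shape)}
      end)
      |> List.to_tuple()

    # A scalar offset marking the next free position in the buffers.
    %{blocks: blocks, offset: Nx.tensor(0)}
  end
end

cache = MyToyDecoder.init_cache(nil, 2, 16, %{})
```

Preallocating buffers for max_length keeps every tensor shape static, which is what lets the whole generation loop be compiled once.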

traverse_cache(spec, cache, function)

@callback traverse_cache(
  spec :: Bumblebee.ModelSpec.t(),
  cache(),
  (Nx.Tensor.t() -> Nx.Tensor.t())
) :: cache()

Traverses all batched tensors in the cache.

This function is used when the cache needs to be inflated or deflated for a different batch size.
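A minimal sketch of this callback for a hypothetical cache layout (a tuple of per-block key/value maps plus a scalar offset), showing how inflation and deflation could look. The module name MyToyCache and the layout are assumptions for illustration only. The batched key/value tensors are passed to fun, while the scalar offset, which has no batch axis, is kept as-is.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule MyToyCache do
  # Hypothetical traverse_cache for a cache shaped like
  # %{blocks: {%{key: tensor, value: tensor}, ...}, offset: scalar_tensor}.
  def traverse_cache(_spec, cache, fun) do
    blocks =
      cache.blocks
      |> Tuple.to_list()
      |> Enum.map(fn %{key: key, value: value} ->
        %{key: fun.(key), value: fun.(value)}
      end)
      |> List.to_tuple()

    # The offset has no batch axis, so it is not passed to fun.
    %{cache | blocks: blocks}
  end
end

# A one-block cache for a batch of 1 (shape: batch, max_length, heads, head_size).
shape = {1, 16, 4, 8}
block = %{key: Nx.iota(shape, type: :f32), value: Nx.iota(shape, type: :f32)}
cache = %{blocks: {block}, offset: Nx.tensor(3)}

# Inflate: batch size 1 -> 3, by concatenating copies along the batch axis.
inflated =
  MyToyCache.traverse_cache(nil, cache, fn t ->
    Nx.concatenate(List.duplicate(t, 3))
  end)

# Deflate: batch size 3 -> 1, by keeping only the first batch entry.
deflated =
  MyToyCache.traverse_cache(nil, inflated, fn t ->
    Nx.slice_along_axis(t, 0, 1, axis: 0)
  end)
```

Concatenating copies along axis 0 tiles the whole batch; for a starting batch size of 1, this is the same as repeating each entry.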

Functions

build_generate(model, spec, config, opts \\ [])

@spec build_generate(
  Axon.t(),
  Bumblebee.ModelSpec.t(),
  Bumblebee.Text.GenerationConfig.t(),
  keyword()
) :: (params :: map(), inputs :: map() ->
        %{token_ids: Nx.Tensor.t(), length: Nx.Tensor.t()}
        | (ignored :: Nx.Tensor.t()))

Builds a numerical definition that generates sequences of tokens using the given language model.

The model should be either a decoder or an encoder-decoder. Tokens are generated by iterative inference with the decoder (autoregression) until the termination criteria are met.

For encoder-decoder models, the corresponding encoder is run only once and its intermediate state is reused across all iterations.

Generation is controlled by a number of options given as %Bumblebee.Text.GenerationConfig{}; see the corresponding docs for more details.

Returns a defn JIT-compatible anonymous function, which expects the model params as the first argument and the inputs map as the second argument. Note that the inputs map should additionally include a "seed" tensor with one value per input in the batch.
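A minimal end-to-end sketch, assuming the openai-community/gpt2 checkpoint from the Hugging Face Hub and EXLA as the compiler (both are choices made for this example, not requirements of the API). It loads the model, limits generation with Bumblebee.configure, adds the "seed" tensor to the tokenizer output, and JIT-compiles the returned function. Treating token_ids as holding only the newly generated tokens is an assumption noted in the comments.

```elixir
Mix.install([
  {:bumblebee, "~> 0.7.0"},
  {:exla, ">= 0.0.0"}
])

repo = {:hf, "openai-community/gpt2"}

{:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

# Generate at most 8 new tokens per input.
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 8)

generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)

inputs =
  tokenizer
  |> Bumblebee.apply_tokenizer(["Elixir is a"])
  # One seed value per input in the batch (here the batch has one input).
  |> Map.put("seed", Nx.tensor([42]))

%{token_ids: token_ids, length: lengths} =
  Nx.Defn.jit(generate, compiler: EXLA).(params, inputs)

# Keep the first `n` ids of the single sequence (assumed to be the newly
# generated tokens) and decode them back to text.
[n] = Nx.to_flat_list(lengths)
ids = token_ids |> Nx.to_flat_list() |> Enum.take(n)
IO.puts(Bumblebee.Tokenizer.decode(tokenizer, ids))
```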

Streaming

This function sets up a hook that is invoked after every generated token. The hook receives a map with the following attributes:

  • :token_id - the newly generated token

  • :finished? - a boolean indicating if the sequence is finished

  • :length - the current length of the generated sequence. Once the sequence is finished, the length does not increase

Each of the attributes is a tensor with a leading batch dimension.

When streaming, you may not care about the final output, in which case you can enable :ignore_output to reduce the output size.
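A minimal streaming sketch under the same assumptions (the openai-community/gpt2 checkpoint and EXLA). It enables :ignore_output and attaches a handler through the :hooks option of Nx.Defn.jit. The hook name :token is an assumption based on Bumblebee's own streaming serving; if no tokens arrive, check the generation source for your version. Hooks may run outside the calling process, so the handler sends messages to a captured pid.

```elixir
Mix.install([
  {:bumblebee, "~> 0.7.0"},
  {:exla, ">= 0.0.0"}
])

repo = {:hf, "openai-community/gpt2"}

{:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 8)

# Tokens are consumed from the hook, so the final output is not needed.
generate =
  Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
    ignore_output: true
  )

inputs =
  tokenizer
  |> Bumblebee.apply_tokenizer(["Elixir is a"])
  |> Map.put("seed", Nx.tensor([42]))

parent = self()

streaming_generate =
  Nx.Defn.jit(generate,
    compiler: EXLA,
    hooks: %{
      # Assumed hook name; the payload keys are the documented attributes.
      token: fn %{token_id: token_id, finished?: finished?} ->
        send(parent, {:token, Nx.to_flat_list(token_id), Nx.to_flat_list(finished?)})
      end
    }
  )

# The return value is a dummy tensor and is ignored.
_ = streaming_generate.(params, inputs)

# Collect the streamed {token_ids, finished_flags} pairs (one entry per batch element).
tokens =
  Stream.repeatedly(fn ->
    receive do
      {:token, ids, finished} -> {ids, finished}
    after
      1_000 -> :done
    end
  end)
  |> Enum.take_while(&(&1 != :done))
```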

Options

  • :logits_processors - a list of numerical functions to modify predicted scores at each generation step. The functions are applied in order, after all default processors

  • :ignore_output - if true, returns a dummy tensor that should be ignored. This is useful when you consume the generated tokens in a streaming fashion via the hook, so that the full output is not transferred needlessly after the computation. Defaults to false

extra_config_module(spec)

@spec extra_config_module(Bumblebee.ModelSpec.t()) :: module() | nil

Returns a configuration module for extra model-specific generation attributes to extend the base Bumblebee.Text.GenerationConfig, or nil if the model does not define one.

init_cache(spec, batch_size, max_length, inputs)

@spec init_cache(Bumblebee.ModelSpec.t(), pos_integer(), pos_integer(), map()) ::
  cache()

Initializes an opaque cache input for iterative inference.

traverse_cache(spec, cache, fun)

@spec traverse_cache(
  Bumblebee.ModelSpec.t(),
  cache(),
  (Nx.Tensor.t() -> Nx.Tensor.t())
) :: cache()

Calls fun for every batched tensor in the cache.
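A minimal sketch of calling these functions directly, assuming the openai-community/gpt2 spec loaded with Bumblebee.load_spec. It allocates a cache for a single sequence and inflates it to a batch of four. Passing an empty inputs map is assumed to be enough for a decoder-only model whose cache does not depend on encoder outputs; encoder-decoder models may need entries from the inputs map. The cache itself is treated as opaque throughout.

```elixir
Mix.install([{:bumblebee, "~> 0.7.0"}])

# GPT-2 is used only as an example decoder-only model.
{:ok, spec} = Bumblebee.load_spec({:hf, "openai-community/gpt2"})

# Cache for a batch of 1 sequence, up to 32 tokens long. The empty inputs
# map is an assumption for a decoder-only model.
cache = Bumblebee.Text.Generation.init_cache(spec, 1, 32, %{})

# Inflate to batch size 4 by concatenating copies of every batched tensor
# along the batch axis; the cache structure itself is never inspected.
inflated =
  Bumblebee.Text.Generation.traverse_cache(spec, cache, fn tensor ->
    Nx.concatenate(List.duplicate(tensor, 4))
  end)
```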