Bumblebee (Bumblebee v0.1.2)

Pre-trained Axon models for easy inference and boosted training.

Bumblebee provides state-of-the-art, configurable Axon models. On top of that, it streamlines the process of loading pre-trained models by integrating with the Hugging Face Hub and 🤗 Transformers.

Usage

You can load one of the supported models by specifying the model repository:

{:ok, bert} = Bumblebee.load_model({:hf, "bert-base-uncased"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})

Then you are ready to make predictions:

inputs = Bumblebee.apply_tokenizer(tokenizer, "Hello Bumblebee!")
outputs = Axon.predict(bert.model, bert.params, inputs)

For complete examples, see the Examples notebook.

Note

The models are generally large, so make sure to configure an efficient Nx backend, such as EXLA or Torchx.
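As a minimal sketch, assuming the :exla package is already listed in your project's dependencies (the dependency entry and version are not shown), EXLA can be made the default Nx backend through application configuration:

```elixir
# config/config.exs
import Config

# Use EXLA as the default backend for every Nx tensor created in this application
config :nx, default_backend: EXLA.Backend
```

In a script or notebook you can instead call Nx.default_backend(EXLA.Backend), which sets the default backend for the current process only.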

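Summary

Types

  • model_info() - A model together with its state and metadata.

  • repository() - A location to fetch model files from.

Models

  • build_model(spec) - Builds an Axon model according to the given specification.

  • load_model(repository, opts \\ []) - Loads a pre-trained model from a model repository.

  • load_spec(repository, opts \\ []) - Loads the model specification from a model repository.

Featurizers

  • apply_featurizer(featurizer, input) - Featurizes input with the given featurizer.

  • load_featurizer(repository, opts \\ []) - Loads a featurizer from a model repository.

Tokenizers

  • apply_tokenizer(tokenizer, input, opts \\ []) - Tokenizes and encodes input with the given tokenizer.

  • load_tokenizer(repository, opts \\ []) - Loads a tokenizer from a model repository.

Schedulers

  • load_scheduler(repository, opts \\ []) - Loads a scheduler from a model repository.

  • scheduler_init(scheduler, num_steps, sample_shape) - Initializes state for a new scheduler loop.

  • scheduler_step(scheduler, state, sample, noise) - Predicts the sample at the previous timestep using the given scheduler.

Functions

  • configure(config, options \\ []) - Builds or updates a configuration object with the given options.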

Types

@type model_info() :: %{model: Axon.t(), params: map(), spec: Bumblebee.ModelSpec.t()}

A model together with its state and metadata.

@type repository() ::
  {:hf, String.t()} | {:hf, String.t(), keyword()} | {:local, Path.t()}

A location to fetch model files from.

Can be either:

  • {:hf, repository_id} - the repository on Hugging Face. Options may be passed as the third element:

    • :revision - the specific model version to use; it can be any valid git identifier, such as a branch name, a tag name, or a commit hash

    • :cache_dir - the directory to store the downloaded files in. Defaults to the standard cache location for the given operating system. You can also configure it globally by setting the BUMBLEBEE_CACHE_DIR environment variable

    • :auth_token - the token to use as HTTP bearer authorization for remote files

    • :subdir - the directory within the repository where the files are located

  • {:local, directory} - the directory containing model files
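Since every load_* function accepts any of these forms, choosing between them can be done in plain Elixir. The sketch below is a hypothetical helper (ModelSource is not part of Bumblebee) that prefers a local directory when it exists and otherwise falls back to the Hugging Face Hub, forwarding options such as :revision:

```elixir
defmodule ModelSource do
  @moduledoc """
  Hypothetical helper, not part of Bumblebee: builds a repository()
  tuple suitable for Bumblebee.load_model/2 and the other loaders.
  """

  # Prefer a local directory of model files when it exists; otherwise use
  # the Hugging Face repository. hf_opts (e.g. revision: "main") only apply
  # to the {:hf, ...} form, so they are dropped for local directories.
  def resolve(hf_id, local_dir, hf_opts \\ []) do
    cond do
      File.dir?(local_dir) -> {:local, local_dir}
      hf_opts == [] -> {:hf, hf_id}
      true -> {:hf, hf_id, hf_opts}
    end
  end
end
```

For example, Bumblebee.load_model(ModelSource.resolve("bert-base-uncased", "models/bert")) would read from the (hypothetical) models/bert directory when it exists and download from the Hub otherwise.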

Models

build_model(spec)

@spec build_model(Bumblebee.ModelSpec.t()) :: Axon.t()

Builds an Axon model according to the given specification.

Example

spec = Bumblebee.configure(Bumblebee.Vision.ResNet, architecture: :base, embedding_size: 128)
model = Bumblebee.build_model(spec)

load_model(repository, opts \\ [])

@spec load_model(
  repository(),
  keyword()
) :: {:ok, model_info()} | {:error, String.t()}

Loads a pre-trained model from a model repository.

Options

  • :spec - the model specification to use when building the model. By default the specification is loaded using load_spec/2

  • :module - the model specification module. By default it is inferred from the configuration file; if that is not possible, it must be specified explicitly

  • :architecture - the model architecture, must be supported by :module. By default it is inferred from the configuration file

  • :params_filename - the file with the model parameters to be loaded

  • :log_params_diff - whether to log missing, mismatched and unused parameters. Defaults to true
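Like the other loaders, load_model/2 returns either {:ok, value} or {:error, reason}, where reason is a string. A small hypothetical helper (Loaded.unwrap!/2 is not part of Bumblebee) can turn such a result into the value or a descriptive exception:

```elixir
defmodule Loaded do
  @moduledoc """
  Hypothetical helper, not part of Bumblebee: unwraps the
  {:ok, value} | {:error, reason} results returned by Bumblebee loaders.
  """

  # On success, return the loaded value as-is
  def unwrap!({:ok, value}, _what), do: value

  # On failure, raise with the loader's error message
  def unwrap!({:error, reason}, what) when is_binary(reason) do
    raise RuntimeError, "failed to load #{what}: #{reason}"
  end
end
```

For instance, piping Bumblebee.load_model({:hf, "microsoft/resnet-50"}, log_params_diff: false) into Loaded.unwrap!("microsoft/resnet-50") yields the model_info() map, or raises with the loader's message.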

Examples

By default the model type is inferred from configuration, so loading is as simple as:

{:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"})
%{model: model, params: params, spec: spec} = resnet

You can explicitly specify a different architecture, in which case matching parameters are still loaded:

{:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}, architecture: :base)

To further customize the model, you can also pass the specification:

{:ok, spec} = Bumblebee.load_spec({:hf, "microsoft/resnet-50"})
spec = Bumblebee.configure(spec, num_labels: 10)
{:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}, spec: spec)

load_spec(repository, opts \\ [])

@spec load_spec(
  repository(),
  keyword()
) :: {:ok, Bumblebee.ModelSpec.t()} | {:error, String.t()}

Loads the model specification from a model repository.

Options

  • :module - the model specification module. By default it is inferred from the configuration file; if that is not possible, it must be specified explicitly

  • :architecture - the model architecture, must be supported by :module. By default it is inferred from the configuration file

Examples

{:ok, spec} = Bumblebee.load_spec({:hf, "microsoft/resnet-50"})

You can explicitly specify a different architecture:

{:ok, spec} = Bumblebee.load_spec({:hf, "microsoft/resnet-50"}, architecture: :base)

Featurizers

apply_featurizer(featurizer, input)

@spec apply_featurizer(Bumblebee.Featurizer.t(), any()) :: any()

Featurizes input with the given featurizer.

Examples

featurizer = Bumblebee.configure(Bumblebee.Vision.ConvNextFeaturizer)
# path is the location of an image file on disk
{:ok, img} = StbImage.read_file(path)
inputs = Bumblebee.apply_featurizer(featurizer, [img])

load_featurizer(repository, opts \\ [])

@spec load_featurizer(
  repository(),
  keyword()
) :: {:ok, Bumblebee.Featurizer.t()} | {:error, String.t()}

Loads a featurizer from a model repository.

Options

  • :module - the featurizer module. By default it is inferred from the preprocessor configuration file; if that is not possible, it must be specified explicitly

Examples

{:ok, featurizer} = Bumblebee.load_featurizer({:hf, "microsoft/resnet-50"})
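Combining a featurizer with the matching model gives a complete image classification pass. The sketch below assumes StbImage is installed (as in the apply_featurizer/2 example) and that the classification model returns a map with a :logits tensor of shape {batch, num_labels}; ClassifyImage is a hypothetical module name:

```elixir
defmodule ClassifyImage do
  # Hypothetical sketch: returns the index of the top-scoring label
  # for the image stored at `path`, using microsoft/resnet-50.
  def top_label(path) do
    {:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"})
    {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "microsoft/resnet-50"})
    {:ok, img} = StbImage.read_file(path)

    # Resize/normalize the image into model inputs (a batch of one image)
    inputs = Bumblebee.apply_featurizer(featurizer, [img])
    outputs = Axon.predict(resnet.model, resnet.params, inputs)

    # Index of the highest logit for the first (and only) image in the batch
    outputs.logits
    |> Nx.argmax(axis: -1)
    |> Nx.to_flat_list()
    |> hd()
  end
end
```

Mapping the returned index to a human-readable label depends on what the loaded specification exposes, so it is left out of this sketch.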

Tokenizers

apply_tokenizer(tokenizer, input, opts \\ [])

Tokenizes and encodes input with the given tokenizer.

Options

  • :add_special_tokens - whether to add special tokens. Defaults to true

  • :pad_direction - the padding direction, either :right or :left. Defaults to :right

  • :return_attention_mask - whether to return the attention mask for the encoded sequence. Defaults to true

  • :return_token_type_ids - whether to return token type ids for the encoded sequence. Defaults to true

  • :return_special_tokens_mask - whether to return the special tokens mask for the encoded sequence. Defaults to false

  • :return_offsets - whether to return token offsets for the encoded sequence. Defaults to false

  • :length - if set, applies fixed-length padding or truncation to the given input
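Fixed-length batches are convenient when every input must have the same shape. The sketch below combines two of the options above; BatchTokenize is a hypothetical module, and we expect each tensor in the returned map (for example under the "input_ids" key) to have shape {batch_size, length}, though it is worth checking the keys your tokenizer actually returns:

```elixir
defmodule BatchTokenize do
  # Hypothetical helper: encodes a list of sentences to a fixed length.
  # `tokenizer` is a value returned by Bumblebee.load_tokenizer/2.
  def encode(tokenizer, sentences, length \\ 16) when is_list(sentences) do
    Bumblebee.apply_tokenizer(tokenizer, sentences,
      # Pad shorter sequences and truncate longer ones to `length` tokens
      length: length,
      # Also return a mask marking special tokens (e.g. [CLS] and [SEP] for BERT)
      return_special_tokens_mask: true
    )
  end
end
```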

Examples

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})
inputs = Bumblebee.apply_tokenizer(tokenizer, ["The capital of France is [MASK]."])

load_tokenizer(repository, opts \\ [])

@spec load_tokenizer(
  repository(),
  keyword()
) :: {:ok, Bumblebee.Tokenizer.t()} | {:error, String.t()}

Loads a tokenizer from a model repository.

Options

  • :module - the tokenizer module. By default it is inferred from the configuration files; if that is not possible, it must be specified explicitly

Examples

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})

Schedulers

load_scheduler(repository, opts \\ [])

@spec load_scheduler(
  repository(),
  keyword()
) :: {:ok, Bumblebee.Scheduler.t()} | {:error, String.t()}

Loads a scheduler from a model repository.

Options

  • :module - the scheduler module. By default it is inferred from the scheduler configuration file; if that is not possible, it must be specified explicitly

Examples

{:ok, scheduler} =
  Bumblebee.load_scheduler({:hf, "CompVis/stable-diffusion-v1-4", subdir: "scheduler"})

scheduler_init(scheduler, num_steps, sample_shape)

Initializes state for a new scheduler loop.

Returns a pair of {state, timesteps}, where state is an opaque container expected by scheduler_step/4 and timesteps is the sequence of timesteps to run, one model forward pass per element.

Note that the number of timesteps may not match num_steps exactly. num_steps parameterizes the sampling points; however, depending on the method, sampling certain points may require multiple forward passes of the model, and each element in timesteps corresponds to a single forward pass.

scheduler_step(scheduler, state, sample, noise)

Predicts the sample at the previous timestep using the given scheduler.

Takes the current sample and the noise predicted by the model at the current timestep. Returns {state, prev_sample}, where state is the updated scheduler loop state and prev_sample is the predicted sample at the previous timestep.

Note that some schedulers require several forward passes of the model (and several calls to this function) to make an actual prediction for the previous sample.
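scheduler_init/3 and scheduler_step/4 are meant to be driven by a loop with one model forward pass per timestep. A minimal sketch, assuming Nx 0.4 or later (for Nx.Random) and a hypothetical predict_noise function standing in for the diffusion model; it also ignores any scheduler-specific scaling of the initial noise:

```elixir
defmodule Denoise do
  # Hypothetical sketch of a denoising loop. `predict_noise` stands in for
  # the diffusion model: it takes the current sample and a timestep and
  # returns the predicted noise tensor (same shape as the sample).
  def run(scheduler, num_steps, sample_shape, predict_noise) do
    {state, timesteps} = Bumblebee.scheduler_init(scheduler, num_steps, sample_shape)

    # Start from Gaussian noise of the requested shape
    key = Nx.Random.key(0)
    {sample, _key} = Nx.Random.normal(key, shape: sample_shape)

    # One forward pass per element of `timesteps`; each step returns the
    # updated scheduler state and the sample at the previous timestep
    {_state, sample} =
      timesteps
      |> Nx.to_flat_list()
      |> Enum.reduce({state, sample}, fn timestep, {state_acc, sample_acc} ->
        noise = predict_noise.(sample_acc, timestep)
        Bumblebee.scheduler_step(scheduler, state_acc, sample_acc, noise)
      end)

    sample
  end
end
```

With a scheduler from load_scheduler/2, a call might look like Denoise.run(scheduler, 50, {1, 64, 64, 4}, predict_noise), where the shape is only illustrative and must match what your model expects.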

Functions

configure(config, options \\ [])

Builds or updates a configuration object with the given options.

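Expects a configuration struct or a module supporting configuration, such as a model specification module (e.g. Bumblebee.Vision.ResNet) or a featurizer module (e.g. Bumblebee.Vision.ConvNextFeaturizer).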

Examples

To build a new configuration, pass a module:

featurizer = Bumblebee.configure(Bumblebee.Vision.ConvNextFeaturizer)
spec = Bumblebee.configure(Bumblebee.Vision.ResNet, architecture: :for_image_classification)

Similarly, you can update an existing configuration:

featurizer = Bumblebee.configure(featurizer, resize_method: :bilinear)
spec = Bumblebee.configure(spec, embedding_size: 128)