Bumblebee.Text (Bumblebee v0.1.2)
High-level tasks related to text processing.
Summary
Types
A single entity label.
Functions
Builds serving for the fill-mask task.
Builds serving for prompt-driven text generation.
Builds serving for text classification.
Builds serving for token classification.
Types
@type fill_mask_input() :: String.t()
@type fill_mask_output() :: %{predictions: [fill_mask_prediction()]}
@type generation_input() :: String.t()
@type generation_output() :: %{results: [generation_result()]}
@type generation_result() :: %{text: String.t()}
@type text_classification_input() :: String.t()
@type text_classification_output() :: %{predictions: [text_classification_prediction()]}
@type token_classification_entity() :: %{
  start: non_neg_integer(),
  end: non_neg_integer(),
  score: float(),
  label: String.t(),
  phrase: String.t()
}
A single entity label.
Note that start and end indices are expressed in terms of UTF-8 bytes.
@type token_classification_input() :: String.t()
@type token_classification_output() :: %{entities: [token_classification_entity()]}
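To illustrate the UTF-8 byte note above: for multibyte text, byte offsets differ from character indices, and Elixir's binary functions operate on bytes. A self-contained sketch using only the standard library (the sample string is illustrative):

```elixir
text = "Łódź is in Poland"

# The first word has 4 characters but 7 bytes, because Ł, ó and ź
# each take two bytes in UTF-8.
String.length("Łódź")
#=> 4

byte_size("Łódź")
#=> 7

# Entity start/end offsets index into the raw bytes, so a phrase can
# be recovered with binary_part/3 (length given in bytes):
binary_part(text, 0, 7)
#=> "Łódź"
```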
Functions
@spec fill_mask(Bumblebee.model_info(), Bumblebee.Tokenizer.t(), keyword()) :: Nx.Serving.t()
Builds serving for the fill-mask task.
The serving accepts fill_mask_input/0 and returns fill_mask_output/0. A list of inputs is also supported.
In the fill-mask task, the objective is to predict a masked word in the text. The serving expects the input to have exactly one such word, denoted as [MASK].
Options
* :top_k - the number of top predictions to include in the output. If the configured value is higher than the number of labels, all labels are returned. Defaults to 5

* :compile - compiles all computations for predefined input shapes during serving initialization. Should be a keyword list with the following keys:

  * :batch_size - the maximum batch size of the input. Inputs are optionally padded to always match this batch size

  * :sequence_length - the maximum input sequence length. Input sequences are always padded/truncated to match that length

  It is advised to set this option in production and also configure a defn compiler using :defn_options to maximally reduce inference time.

* :defn_options - the options for JIT compilation. Defaults to []
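As a sketch of how these options combine in production, assuming the optional EXLA compiler is available (the batch size and sequence length below are illustrative, not prescriptive):

```elixir
{:ok, bert} = Bumblebee.load_model({:hf, "bert-base-uncased"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})

# Compile once for fixed shapes at serving startup, so no further
# JIT compilation happens at inference time.
serving =
  Bumblebee.Text.fill_mask(bert, tokenizer,
    compile: [batch_size: 4, sequence_length: 128],
    defn_options: [compiler: EXLA]
  )
```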
Examples
{:ok, bert} = Bumblebee.load_model({:hf, "bert-base-uncased"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})
serving = Bumblebee.Text.fill_mask(bert, tokenizer)
text = "The capital of [MASK] is Paris."
Nx.Serving.run(serving, text)
#=> %{
#=> predictions: [
#=> %{score: 0.9279842972755432, token: "france"},
#=> %{score: 0.008412551134824753, token: "brittany"},
#=> %{score: 0.007433671969920397, token: "algeria"},
#=> %{score: 0.004957548808306456, token: "department"},
#=> %{score: 0.004369721747934818, token: "reunion"}
#=> ]
#=> }
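Since a list of inputs is also supported, several texts can be sent in one call; the serving returns one result map per input, in order (a sketch reusing the serving from the example above):

```elixir
texts = [
  "The capital of [MASK] is Paris.",
  "The capital of [MASK] is Warsaw."
]

Nx.Serving.run(serving, texts)
# Returns a list of %{predictions: [...]} maps, one per input text.
```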
@spec generation(Bumblebee.model_info(), Bumblebee.Tokenizer.t(), keyword()) :: Nx.Serving.t()
Builds serving for prompt-driven text generation.
The serving accepts generation_input/0 and returns generation_output/0. A list of inputs is also supported.
Note that either :max_new_tokens or :max_length must be specified.
Options
* :max_new_tokens - the maximum number of tokens to be generated, ignoring the number of tokens in the prompt

* :min_new_tokens - the minimum number of tokens to be generated, ignoring the number of tokens in the prompt

* :compile - compiles all computations for predefined input shapes during serving initialization. Should be a keyword list with the following keys:

  * :batch_size - the maximum batch size of the input. Inputs are optionally padded to always match this batch size

  * :sequence_length - the maximum input sequence length. Input sequences are always padded/truncated to match that length

  It is advised to set this option in production and also configure a defn compiler using :defn_options to maximally reduce inference time.

* :defn_options - the options for JIT compilation. Defaults to []
Also accepts all the other options of Bumblebee.Text.Generation.build_generate/3.
Examples
{:ok, gpt2} = Bumblebee.load_model({:hf, "gpt2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
serving = Bumblebee.Text.generation(gpt2, tokenizer, max_new_tokens: 15)
prompt = "Elixir is a functional"
Nx.Serving.run(serving, prompt)
#=> %{
#=> results: [
#=> %{
#=> text: "Elixir is a functional programming language that is designed to be used in a variety of applications. It"
#=> }
#=> ]
#=> }
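In an application, the serving is typically started under a supervision tree so that requests from many processes are batched together. A sketch assuming Nx.Serving's process-based API; the name MyApp.Serving and the batch options are illustrative assumptions:

```elixir
# In the application supervision tree:
children = [
  {Nx.Serving,
   serving: Bumblebee.Text.generation(gpt2, tokenizer, max_new_tokens: 15),
   name: MyApp.Serving,
   batch_timeout: 100}
]

# Anywhere in the application, requests are routed to the named
# serving process and batched across callers:
Nx.Serving.batched_run(MyApp.Serving, "Elixir is a functional")
```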
@spec text_classification(Bumblebee.model_info(), Bumblebee.Tokenizer.t(), keyword()) :: Nx.Serving.t()
Builds serving for text classification.
The serving accepts text_classification_input/0 and returns text_classification_output/0. A list of inputs is also supported.
Options
* :top_k - the number of top predictions to include in the output. If the configured value is higher than the number of labels, all labels are returned. Defaults to 5

* :compile - compiles all computations for predefined input shapes during serving initialization. Should be a keyword list with the following keys:

  * :batch_size - the maximum batch size of the input. Inputs are optionally padded to always match this batch size

  * :sequence_length - the maximum input sequence length. Input sequences are always padded/truncated to match that length

  It is advised to set this option in production and also configure a defn compiler using :defn_options to maximally reduce inference time.

* :defn_options - the options for JIT compilation. Defaults to []
Examples
{:ok, bertweet} = Bumblebee.load_model({:hf, "finiteautomata/bertweet-base-sentiment-analysis"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "vinai/bertweet-base"})
serving = Bumblebee.Text.text_classification(bertweet, tokenizer)
text = "Cats are cute."
Nx.Serving.run(serving, text)
#=> %{
#=> predictions: [
#=> %{label: "POS", score: 0.9876555800437927},
#=> %{label: "NEU", score: 0.010068908333778381},
#=> %{label: "NEG", score: 0.002275536535307765}
#=> ]
#=> }
@spec token_classification(Bumblebee.model_info(), Bumblebee.Tokenizer.t(), keyword()) :: Nx.Serving.t()
Builds serving for token classification.
The serving accepts token_classification_input/0 and returns token_classification_output/0. A list of inputs is also supported.
This function can be used for tasks such as named entity recognition (NER) or part-of-speech (POS) tagging.
The recognized entities can optionally be aggregated into groups based on the given strategy.
Options
* :aggregation - an optional strategy for aggregating adjacent tokens. Token classification models output probabilities for each possible token class. The aggregation strategy takes scores for each token (which possibly represents subwords) and groups tokens into phrases which are readily interpretable as entities of a certain class. Supported aggregation strategies:

  * nil (default) - corresponds to no aggregation and returns the most likely label for each input token

  * :same - groups adjacent tokens with the same label. If the labels use beginning-inside-outside (BIO) tagging, the boundaries are respected and the prefix is omitted in the output labels

* :ignored_labels - the labels to ignore in the final output. The labels should be specified without BIO prefix. Defaults to ["O"]

* :compile - compiles all computations for predefined input shapes during serving initialization. Should be a keyword list with the following keys:

  * :batch_size - the maximum batch size of the input. Inputs are optionally padded to always match this batch size

  * :sequence_length - the maximum input sequence length. Input sequences are always padded/truncated to match that length

  It is advised to set this option in production and also configure a defn compiler using :defn_options to maximally reduce inference time.

* :defn_options - the options for JIT compilation. Defaults to []
Examples
{:ok, bert} = Bumblebee.load_model({:hf, "dslim/bert-base-NER"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-cased"})
serving = Bumblebee.Text.token_classification(bert, tokenizer, aggregation: :same)
text = "Rachel Green works at Ralph Lauren in New York City in the sitcom Friends"
Nx.Serving.run(serving, text)
#=> %{
#=> entities: [
#=> %{end: 12, label: "PER", phrase: "Rachel Green", score: 0.9997024834156036, start: 0},
#=> %{end: 34, label: "ORG", phrase: "Ralph Lauren", score: 0.9968731701374054, start: 22},
#=> %{end: 51, label: "LOC", phrase: "New York City", score: 0.9995547334353129, start: 38},
#=> %{end: 73, label: "MISC", phrase: "Friends", score: 0.6997143030166626, start: 66}
#=> ]
#=> }