Nasty.Statistics.Neural.Transformers.TokenClassifier (Nasty v0.3.0)
Token classification layer on top of pre-trained transformers.
Supports:
- Part-of-speech (POS) tagging
- Named Entity Recognition (NER)
- Custom token classification tasks
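The classifier adds a linear layer on top of the transformer encoder outputs. The layer projects each token's hidden state to num_labels scores (logits), and softmax turns those scores into a probability distribution over labels for each token.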
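The per-token math of such a head can be sketched in plain Elixir. This is a minimal illustration, not Nasty's implementation. Plain lists stand in for the Nx/Axon tensors a real model uses, and the module name TokenHeadSketch is hypothetical. The sketch assumes one hidden vector per token, so subword alignment is ignored. Dropout (the :dropout_rate option) is left out because it is normally active only during training. Each hidden vector is projected to one logit per label, softmax converts the logits to probabilities, and the most probable label becomes the prediction.

```elixir
defmodule TokenHeadSketch do
  # Hypothetical sketch of a per-token classification head (not Nasty's code).

  # Dot product of two equal-length vectors.
  defp dot(a, b) do
    a |> Enum.zip(b) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()
  end

  # Linear layer: `weights` has one row per label (each row as long as the
  # hidden vector) and `bias` has one entry per label, giving one logit per label.
  def linear(hidden, weights, bias) do
    weights
    |> Enum.zip(bias)
    |> Enum.map(fn {row, b} -> dot(row, hidden) + b end)
  end

  # Numerically stable softmax: subtract the largest logit before exponentiating.
  def softmax(logits) do
    max_logit = Enum.max(logits)
    exps = Enum.map(logits, fn z -> :math.exp(z - max_logit) end)
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end

  # Classifies each token independently and returns maps shaped like the
  # predictions documented for predict/3.
  def classify(hidden_states, weights, bias, label_map) do
    hidden_states
    |> Enum.with_index()
    |> Enum.map(fn {hidden, idx} ->
      probs = hidden |> linear(weights, bias) |> softmax()
      {score, label_id} = probs |> Enum.with_index() |> Enum.max_by(fn {p, _id} -> p end)
      %{token_index: idx, label: Map.fetch!(label_map, label_id), label_id: label_id, score: score}
    end)
  end
end

# Two tokens with hidden size 2 and a two-label head (toy values).
hidden_states = [[1.0, 0.0], [0.0, 1.0]]
weights = [[2.0, 0.0], [0.0, 2.0]]
bias = [0.0, 0.0]
label_map = %{0 => "NOUN", 1 => "VERB"}
preds = TokenHeadSketch.classify(hidden_states, weights, bias, label_map)
```

With these toy weights, each token's logits are [2.0, 0.0] or [0.0, 2.0], so both tokens get their label with probability 1 / (1 + e^-2), about 0.88.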
Summary
Functions
- create/2: Creates a token classifier from a pre-trained transformer model.
- predict/3: Predicts labels for a sequence of tokens.
- predict_batch/3: Predicts labels for multiple sequences in batch.
- tag_tokens/3: Updates tokens with predicted labels.
Types
@type classifier() :: %{
  base_model: map(),
  config: classifier_config(),
  classification_head: Axon.t()
}
@type task() :: :pos_tagging | :ner | :token_classification
Functions
@spec create(map(), keyword()) :: {:ok, classifier()} | {:error, term()}
Creates a token classifier from a pre-trained transformer model.
Options
- :task - Classification task (:pos_tagging, :ner, or :token_classification)
- :num_labels - Number of classification labels
- :label_map - Map from label IDs to label names
- :dropout_rate - Dropout rate for the classification head (default: 0.1)
Examples
{:ok, base_model} = Loader.load_model(:roberta_base)

{:ok, classifier} =
  TokenClassifier.create(base_model,
    task: :pos_tagging,
    num_labels: 17,
    label_map: %{
      0 => "NOUN",
      1 => "VERB"
      # ... remaining label IDs
    }
  )
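The :num_labels and :label_map options must agree with each other. One way to keep them consistent is to derive both from a single ordered list of labels. The LabelMapSketch helper below is hypothetical and not part of Nasty. The BIO-style NER labels are illustrative only; the real label set must match what the classification head is trained on.

```elixir
defmodule LabelMapSketch do
  # Hypothetical helper (not part of Nasty).

  # Builds a label ID => label name map from an ordered list of labels.
  # Raises on duplicates, since each ID should map to a distinct name.
  def build(labels) when is_list(labels) do
    if length(Enum.uniq(labels)) != length(labels) do
      raise ArgumentError, "duplicate labels in #{inspect(labels)}"
    end

    labels
    |> Enum.with_index()
    |> Map.new(fn {label, id} -> {id, label} end)
  end

  # Keyword options for create/2, with :num_labels derived from :label_map.
  def classifier_opts(task, labels) do
    label_map = build(labels)
    [task: task, num_labels: map_size(label_map), label_map: label_map]
  end
end

opts = LabelMapSketch.classifier_opts(:ner, ["O", "B-PER", "I-PER", "B-LOC", "I-LOC"])
```

The resulting keyword list can then be passed as the second argument to TokenClassifier.create/2.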
@spec predict(classifier(), [Nasty.AST.Token.t()], keyword()) :: {:ok, [prediction()]} | {:error, term()}
Predicts labels for a sequence of tokens.
Returns one prediction per token, each with the token index, label name, label ID, and confidence score.
Examples
{:ok, predictions} = TokenClassifier.predict(classifier, tokens)
# => [
# %{token_index: 0, label: "NOUN", label_id: 0, score: 0.95},
# %{token_index: 1, label: "VERB", label_id: 1, score: 0.89},
# ...
# ]
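Because every prediction carries a score, callers can act on confidence. The sketch below splits predict/3-style output into confident and uncertain predictions, then counts the confident labels. The PredictionSketch module and the 0.5 cutoff are hypothetical. The prediction maps follow the documented shape, and their values (including the third, low-scoring entry) are made up for illustration.

```elixir
defmodule PredictionSketch do
  # Hypothetical post-processing helpers for predict/3 results.

  # Splits predictions into those scoring at least `min_score` and the rest.
  def split_by_confidence(predictions, min_score) do
    Enum.split_with(predictions, fn pred -> pred.score >= min_score end)
  end

  # Counts how many tokens received each label.
  def label_counts(predictions) do
    Enum.frequencies_by(predictions, fn pred -> pred.label end)
  end
end

# Maps shaped like the documented predict/3 output; values are illustrative.
predictions = [
  %{token_index: 0, label: "NOUN", label_id: 0, score: 0.95},
  %{token_index: 1, label: "VERB", label_id: 1, score: 0.89},
  %{token_index: 2, label: "NOUN", label_id: 0, score: 0.41}
]

{confident, uncertain} = PredictionSketch.split_by_confidence(predictions, 0.5)
counts = PredictionSketch.label_counts(confident)
```

Here token 2 falls below the cutoff and ends up in `uncertain`, for example for manual review.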
@spec predict_batch(classifier(), [[Nasty.AST.Token.t()]], keyword()) :: {:ok, [[prediction()]]} | {:error, term()}
Predicts labels for multiple sequences in batch.
More efficient than calling predict/3 once per sequence.
Examples
{:ok, batch_predictions} = TokenClassifier.predict_batch(classifier, [tokens1, tokens2])
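The docs do not say how predict_batch/3 achieves its speedup. A common technique for batching variable-length sequences through a transformer is to pad them to a shared length. An attention mask then marks the real tokens, so the model runs one forward pass for the whole batch instead of one per sequence. The sketch below shows only that padding bookkeeping, under stated assumptions. BatchPaddingSketch is hypothetical, and the token IDs are made up. The pad ID is a parameter because it depends on the tokenizer.

```elixir
defmodule BatchPaddingSketch do
  # Hypothetical sketch of batch padding; not taken from Nasty.

  # Pads each sequence of token IDs to the longest length in the batch and
  # pairs it with an attention mask (1 = real token, 0 = padding).
  def pad(sequences, pad_id) do
    max_len = sequences |> Enum.map(&length/1) |> Enum.reduce(0, &max/2)

    Enum.map(sequences, fn ids ->
      pad_count = max_len - length(ids)
      padded = ids ++ List.duplicate(pad_id, pad_count)
      mask = List.duplicate(1, length(ids)) ++ List.duplicate(0, pad_count)
      {padded, mask}
    end)
  end

  # Drops per-position results at padded positions so each result list
  # lines up with its original, unpadded sequence again.
  def unpad(per_position_results, masks) do
    per_position_results
    |> Enum.zip(masks)
    |> Enum.map(fn {results, mask} ->
      results
      |> Enum.zip(mask)
      |> Enum.filter(fn {_result, m} -> m == 1 end)
      |> Enum.map(fn {result, _m} -> result end)
    end)
  end
end

# Two hypothetical token-ID sequences of different lengths, padded with ID 0.
padded = BatchPaddingSketch.pad([[5, 7, 8, 6], [5, 9, 6]], 0)
masks = Enum.map(padded, fn {_ids, mask} -> mask end)
```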
@spec tag_tokens(classifier(), [Nasty.AST.Token.t()], keyword()) :: {:ok, [Nasty.AST.Token.t()]} | {:error, term()}
Updates tokens with predicted labels.
Returns copies of the token structs updated with the predicted POS tags or entity labels.
Examples
{:ok, tagged_tokens} = TokenClassifier.tag_tokens(classifier, tokens, task: :pos_tagging)
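Conceptually, tag_tokens/3 combines predict/3 with a merge step. Below is a minimal sketch of that merge. It assumes tokens are plain maps and uses hypothetical field names (:pos_tag, :entity_label, :label) in place of the real Nasty.AST.Token fields, which are not documented here. TagTokensSketch is likewise hypothetical.

```elixir
defmodule TagTokensSketch do
  # Hypothetical task => field mapping; real Nasty.AST.Token fields may differ.
  @label_fields %{pos_tagging: :pos_tag, ner: :entity_label, token_classification: :label}

  # Pairs each token with the prediction that has the same token_index and
  # returns updated copies; tokens without a prediction are left unchanged.
  def apply_predictions(tokens, predictions, task) do
    field = Map.fetch!(@label_fields, task)
    by_index = Map.new(predictions, fn pred -> {pred.token_index, pred.label} end)

    tokens
    |> Enum.with_index()
    |> Enum.map(fn {token, idx} ->
      case Map.fetch(by_index, idx) do
        {:ok, label} -> Map.put(token, field, label)
        :error -> token
      end
    end)
  end
end

tokens = [%{text: "Dogs"}, %{text: "bark"}]

predictions = [
  %{token_index: 0, label: "NOUN", label_id: 0, score: 0.95},
  %{token_index: 1, label: "VERB", label_id: 1, score: 0.89}
]

tagged = TagTokensSketch.apply_predictions(tokens, predictions, :pos_tagging)
```

Because Elixir data is immutable, `tokens` is unchanged after the call, and `tagged` holds the updated copies.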