Nasty.Statistics.Neural.Transformers.TokenClassifier (Nasty v0.3.0)


Token classification layer on top of pre-trained transformers.

Supports:

  • Part-of-speech (POS) tagging
  • Named Entity Recognition (NER)
  • Custom token classification tasks

The classifier adds a linear classification head on top of the transformer encoder's outputs and applies a softmax over the labels at each token position, so every token receives its own multi-class prediction.
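The sketch below illustrates the per-token decision in plain Elixir. It applies a softmax to each token's logits and picks the highest-scoring label, producing maps shaped like the prediction() type documented below. The module name `SoftmaxSketch` and the input format (one list of logits per token) are assumptions made for illustration. The library builds its head with Axon and may compute scores differently.

```elixir
defmodule SoftmaxSketch do
  # Hypothetical helper, not part of Nasty's API: turns per-token logits
  # into maps shaped like the documented prediction() type.

  # Numerically stable softmax: subtract the max logit before exponentiating.
  def softmax(logits) do
    max = Enum.max(logits)
    exps = Enum.map(logits, fn x -> :math.exp(x - max) end)
    sum = Enum.sum(exps)
    Enum.map(exps, &(&1 / sum))
  end

  # logits_per_token: one list of num_labels logits per token.
  # label_map: %{label_id => label_name}, as in classifier_config().
  def predict(logits_per_token, label_map) do
    logits_per_token
    |> Enum.with_index()
    |> Enum.map(fn {logits, token_index} ->
      # Argmax over the softmax probabilities picks the label ID.
      {score, label_id} =
        logits
        |> softmax()
        |> Enum.with_index()
        |> Enum.max_by(fn {p, _id} -> p end)

      %{
        token_index: token_index,
        label: Map.fetch!(label_map, label_id),
        label_id: label_id,
        score: score
      }
    end)
  end
end

SoftmaxSketch.predict([[2.0, 0.5], [0.1, 1.9]], %{0 => "NOUN", 1 => "VERB"})
```

Because the softmax is taken per token, each token's scores sum to 1 independently of its neighbours. This matches the description above of a separate multi-class decision at every position.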

Summary

Functions

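  • create(base_model, opts) - Creates a token classifier from a pre-trained transformer model.
  • predict(classifier, tokens, opts \\ []) - Predicts labels for a sequence of tokens.
  • predict_batch(classifier, token_sequences, opts \\ []) - Predicts labels for multiple sequences in batch.
  • tag_tokens(classifier, tokens, opts \\ []) - Updates tokens with predicted labels.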

Types

classifier()

@type classifier() :: %{
  base_model: map(),
  config: classifier_config(),
  classification_head: Axon.t()
}

classifier_config()

@type classifier_config() :: %{
  task: task(),
  num_labels: integer(),
  label_map: %{required(integer()) => String.t()},
  model_name: atom(),
  dropout_rate: float()
}

prediction()

@type prediction() :: %{
  token_index: integer(),
  label: String.t(),
  label_id: integer(),
  score: float()
}

task()

@type task() :: :pos_tagging | :ner | :token_classification

Functions

create(base_model, opts)

@spec create(
  map(),
  keyword()
) :: {:ok, classifier()} | {:error, term()}

Creates a token classifier from a pre-trained transformer model.

Options

  • :task - Classification task (:pos_tagging, :ner, or :token_classification)
  • :num_labels - Number of classification labels
  • :label_map - Map from label IDs to label names
  • :dropout_rate - Dropout rate for classification head (default: 0.1)

Examples

{:ok, base_model} = Loader.load_model(:roberta_base)
{:ok, classifier} = TokenClassifier.create(base_model,
  task: :pos_tagging,
  num_labels: 17,
  label_map: %{0 => "NOUN", 1 => "VERB", ...}
)
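Writing out every entry of `label_map` by hand is error-prone, and `:num_labels` must stay consistent with it. The sketch below derives both options from one ordered list of label names. `LabelMapSketch` is a hypothetical helper, not part of Nasty. The tag list uses the 17 Universal Dependencies UPOS tags in alphabetical order purely as an example, which differs from the ordering in the example above. Whatever order you use has to match the label IDs the classification head was trained with.

```elixir
defmodule LabelMapSketch do
  # Hypothetical helper (not part of Nasty): builds :num_labels and
  # :label_map from an ordered list of label names so they cannot disagree.
  def options_for(task, labels) when is_list(labels) do
    if length(Enum.uniq(labels)) != length(labels) do
      raise ArgumentError, "duplicate label names"
    end

    # Label ID = position in the list.
    label_map =
      labels
      |> Enum.with_index()
      |> Map.new(fn {name, id} -> {id, name} end)

    [task: task, num_labels: map_size(label_map), label_map: label_map]
  end
end

# The 17 Universal Dependencies UPOS tags, listed alphabetically (example only).
upos = ~w(ADJ ADP ADV AUX CCONJ DET INTJ NOUN NUM PART PRON PROPN PUNCT SCONJ SYM VERB X)
opts = LabelMapSketch.options_for(:pos_tagging, upos)
# opts[:num_labels] is 17 and opts[:label_map][7] is "NOUN"
```

`create/2` takes its options as a keyword list, so a list built this way could be passed directly as `TokenClassifier.create(base_model, opts)`. `:dropout_rate` would then fall back to its default.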

predict(classifier, tokens, opts \\ [])

@spec predict(classifier(), [Nasty.AST.Token.t()], keyword()) ::
  {:ok, [prediction()]} | {:error, term()}

Predicts labels for a sequence of tokens.

Returns a list of predictions, each carrying the token's index, the predicted label name and label ID, and a confidence score.

Examples

{:ok, predictions} = TokenClassifier.predict(classifier, tokens)
# => [
#   %{token_index: 0, label: "NOUN", label_id: 0, score: 0.95},
#   %{token_index: 1, label: "VERB", label_id: 1, score: 0.89},
#   ...
# ]
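Each prediction carries a `score`, so callers can post-process the result. For example, they can extract the label sequence or flag tokens whose score falls below a threshold for review. The sketch below works only on maps shaped like the documented prediction() type and reuses the scores from the example above. The module name and the 0.5 default threshold are illustrative assumptions, not library defaults.

```elixir
defmodule PredictionSketch do
  # Hypothetical helpers over maps shaped like prediction():
  # %{token_index: integer, label: String.t, label_id: integer, score: float}

  # Labels in token order; sorting guards against unordered input.
  def labels(predictions) do
    predictions
    |> Enum.sort_by(& &1.token_index)
    |> Enum.map(& &1.label)
  end

  # Token indices whose confidence score is below `threshold`.
  def low_confidence(predictions, threshold \\ 0.5) do
    for %{token_index: i, score: s} <- predictions, s < threshold, do: i
  end
end

predictions = [
  %{token_index: 0, label: "NOUN", label_id: 0, score: 0.95},
  %{token_index: 1, label: "VERB", label_id: 1, score: 0.89}
]

PredictionSketch.labels(predictions)
# => ["NOUN", "VERB"]
PredictionSketch.low_confidence(predictions, 0.9)
# => [1]
```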

predict_batch(classifier, token_sequences, opts \\ [])

@spec predict_batch(classifier(), [[Nasty.AST.Token.t()]], keyword()) ::
  {:ok, [[prediction()]]} | {:error, term()}

Predicts labels for multiple sequences in batch.

More efficient than calling predict/3 separately for each sequence.

Examples

{:ok, batch_predictions} = TokenClassifier.predict_batch(classifier, [tokens1, tokens2])
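Batched inference usually gains its efficiency by running several sequences through the model in a single call, which requires them to share a length. One common approach is to pad each sequence of input IDs to the longest length in the batch and record an attention mask so padded positions can be ignored. The plain-Elixir sketch below shows that approach. It is not necessarily what `predict_batch/3` does internally. The module name and the default pad ID of 0 are assumptions, and the correct pad ID depends on the tokenizer.

```elixir
defmodule PaddingSketch do
  # Hypothetical helper: pads lists of token IDs to a common length and
  # builds a matching attention mask (1 = real token, 0 = padding).
  def pad_batch(sequences, pad_id \\ 0) do
    # Length of the longest sequence (0 for an empty batch).
    max_len = Enum.reduce(sequences, 0, fn ids, acc -> max(length(ids), acc) end)

    Enum.map(sequences, fn ids ->
      pad = max_len - length(ids)
      padded = ids ++ List.duplicate(pad_id, pad)
      mask = List.duplicate(1, length(ids)) ++ List.duplicate(0, pad)
      {padded, mask}
    end)
  end
end

PaddingSketch.pad_batch([[5, 8, 13], [21, 34]])
# => [{[5, 8, 13], [1, 1, 1]}, {[21, 34, 0], [1, 1, 0]}]
```

Padding to the longest sequence in the batch, rather than to a fixed maximum, keeps wasted computation low when the sequences have similar lengths.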

tag_tokens(classifier, tokens, opts \\ [])

@spec tag_tokens(classifier(), [Nasty.AST.Token.t()], keyword()) ::
  {:ok, [Nasty.AST.Token.t()]} | {:error, term()}

Updates tokens with predicted labels.

Returns copies of the token structs with the predicted POS tags or entity labels filled in. Elixir data is immutable, so the input list itself is left unchanged.

Examples

{:ok, tagged_tokens} = TokenClassifier.tag_tokens(classifier, tokens, task: :pos_tagging)
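Conceptually, tagging pairs each prediction with the token at its `token_index` and returns updated copies of the tokens. The sketch below shows that merge step for plain maps. `TagSketch` is hypothetical. The `:pos_tag`, `:entity`, and `:label` field names and the task-to-field mapping are assumptions and may not match the actual fields of `Nasty.AST.Token`.

```elixir
defmodule TagSketch do
  # Hypothetical field chosen per task; real Nasty.AST.Token fields may differ.
  defp field_for(:pos_tagging), do: :pos_tag
  defp field_for(:ner), do: :entity
  defp field_for(_other), do: :label

  # Returns new token maps; the input list is not mutated.
  def tag_tokens(tokens, predictions, task) do
    field = field_for(task)
    by_index = Map.new(predictions, &{&1.token_index, &1.label})

    tokens
    |> Enum.with_index()
    |> Enum.map(fn {token, i} ->
      case Map.fetch(by_index, i) do
        {:ok, label} -> Map.put(token, field, label)
        # Tokens without a prediction are returned unchanged.
        :error -> token
      end
    end)
  end
end

tokens = [%{text: "Dogs"}, %{text: "bark"}]

predictions = [
  %{token_index: 0, label: "NOUN", label_id: 0, score: 0.95},
  %{token_index: 1, label: "VERB", label_id: 1, score: 0.89}
]

TagSketch.tag_tokens(tokens, predictions, :pos_tagging)
# => [%{pos_tag: "NOUN", text: "Dogs"}, %{pos_tag: "VERB", text: "bark"}]
```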