Nasty.Statistics.Neural.Transformers.Loader (Nasty v0.3.0)

View Source

Loads pre-trained transformer models from HuggingFace Hub or local paths.

Supports BERT, RoBERTa, DistilBERT, and XLM-RoBERTa models via Bumblebee.

Summary

Functions

Gets information about a specific model without loading it.

Lists all available pre-trained models.

Loads a pre-trained transformer model by name.

Checks whether a model supports a given language.

Types

model_config()

@type model_config() :: %{
  repo: String.t(),
  params: integer(),
  hidden_size: integer(),
  num_layers: integer(),
  languages: [atom()]
}

model_name()

@type model_name() ::
  :bert_base_cased
  | :bert_base_uncased
  | :roberta_base
  | :xlm_roberta_base
  | :distilbert_base

transformer_model()

@type transformer_model() :: %{
  name: model_name(),
  model_info: map(),
  tokenizer: map(),
  config: model_config(),
  serving: pid() | nil
}

Functions

get_model_info(model_name)

@spec get_model_info(model_name()) :: {:ok, model_config()} | {:error, :unknown_model}

Gets information about a specific model without loading it.

Examples

{:ok, info} = Loader.get_model_info(:bert_base_cased)
# info => %{params: 110_000_000, hidden_size: 768, ...}

list_models()

@spec list_models() :: [model_name()]

Lists all available pre-trained models.

Examples

Loader.list_models()
# => [:bert_base_cased, :bert_base_uncased, :roberta_base, ...]

load_model(model_name, opts \\ [])

@spec load_model(
  model_name(),
  keyword()
) :: {:ok, transformer_model()} | {:error, term()}

Loads a pre-trained transformer model by name.

Options

  • :cache_dir - Directory to cache downloaded models (default: priv/models/transformers)
  • :backend - Nx backend to use (default: EXLA.Backend)
  • :device - Device to use (:cpu or :cuda, default: :cpu)
  • :offline - If true, only use cached models (default: false)

Examples

{:ok, model} = Loader.load_model(:roberta_base)
{:ok, model} = Loader.load_model(:xlm_roberta_base, cache_dir: "/tmp/models")

supports_language?(model_name, language)

@spec supports_language?(model_name(), atom()) :: boolean()

Checks whether a model supports a given language.

Examples

Loader.supports_language?(:xlm_roberta_base, :es)
# => true

Loader.supports_language?(:bert_base_cased, :es)
# => false