Text.Language.Classifier.Fasttext.HuffmanTree (Text v0.5.0)


The Huffman tree fastText constructs over output labels for hierarchical softmax inference.

fastText uses hierarchical softmax (loss = :hs) when training models like lid.176. At inference the output projection is not a single matrix multiplication. Instead, a binary tree is traversed from root to leaf, and each internal node carries a learned vector that scores a left-versus-right decision. This module reproduces fastText's tree construction so that the Elixir inference path follows the same paths as the C++ implementation.

Mirrors HierarchicalSoftmaxLoss::buildTree in src/loss.cc.

Tree shape

Given n labels:

  • Total nodes: 2n - 1 (numbered 0..2n-2).
  • Leaves: 0..n-1. Leaf i corresponds to label i. The label index matches the dictionary's label entry order (after fastText's count-desc sort).
  • Internal nodes: n..2n-2. Each internal node m has a learned vector stored in row m - n of the output matrix.
  • Root: 2n - 2.
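
The index arithmetic above can be sketched in a few lines. This is only an illustration: `TreeShapeSketch` and its function names are hypothetical and are not part of this module's API.

```elixir
defmodule TreeShapeSketch do
  # Hypothetical helper, written only to illustrate the layout described above.

  # Total number of nodes for n labels: n leaves plus n - 1 internal nodes.
  def total_nodes(n) when n >= 1, do: 2 * n - 1

  # Leaves occupy indices 0..n-1.
  def leaf?(n, index), do: index >= 0 and index < n

  # The root is the last node to be created.
  def root(n) when n >= 1, do: 2 * n - 2

  # Row of the output matrix that holds the learned vector of internal node m.
  def output_row(n, m) when m >= n and m <= 2 * n - 2, do: m - n
end

n = 3
IO.inspect(TreeShapeSketch.total_nodes(n))   # 5
IO.inspect(TreeShapeSketch.root(n))          # 4
IO.inspect(TreeShapeSketch.leaf?(n, 2))      # true
IO.inspect(TreeShapeSketch.output_row(n, 4)) # 1
```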

Build algorithm

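Labels are assumed to be sorted by count in descending order (fastText's Dictionary::threshold enforces this). Two pointers start at the boundary between leaves and internal nodes and move apart: leaf decrements from n-1, so the smallest leaf counts come first, and node increments from n over the newly formed internal nodes, whose counts (the sum of their children) are non-decreasing by construction. Each new internal node twice takes whichever pointer currently has the smaller count, and so consumes the two smallest available counts.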
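
A minimal sketch of the two-pointer construction, written from the description of the reference algorithm rather than copied from this module. The module name `HuffmanBuildSketch`, the plain-map node storage, and the integer placeholder count are assumptions made for illustration; the real module stores nodes in an `:array`.

```elixir
defmodule HuffmanBuildSketch do
  # Illustrative only; names and data layout are assumptions, not the library's code.

  # Placeholder count for internal nodes that have not been filled yet
  # (the reference uses 1e15 for the same purpose).
  @unset 1_000_000_000_000_000

  def build(counts) when is_list(counts) and counts != [] do
    n = length(counts)

    leaves =
      counts
      |> Enum.with_index()
      |> Map.new(fn {count, i} -> {i, new_node(count)} end)

    # Pre-allocate the internal slots n..2n-2 (empty range when n == 1).
    nodes =
      Enum.reduce(n..(2 * n - 2)//1, leaves, fn i, acc ->
        Map.put(acc, i, new_node(@unset))
      end)

    {nodes, _leaf, _node} =
      Enum.reduce(n..(2 * n - 2)//1, {nodes, n - 1, n}, fn i, {nodes, leaf, node} ->
        {left, leaf, node} = pick_smallest(nodes, leaf, node)
        {right, leaf, node} = pick_smallest(nodes, leaf, node)
        count = nodes[left].count + nodes[right].count

        nodes =
          nodes
          |> Map.update!(i, fn e -> %{e | left: left, right: right, count: count} end)
          |> Map.update!(left, fn e -> %{e | parent: i} end)
          # The right child is the one flagged binary: true.
          |> Map.update!(right, fn e -> %{e | parent: i, binary: true} end)

        {nodes, leaf, node}
      end)

    nodes
  end

  defp new_node(count),
    do: %{parent: -1, left: -1, right: -1, count: count, binary: false}

  # Takes the smaller of the next unused leaf and the next unused internal node.
  defp pick_smallest(nodes, leaf, node) do
    if leaf >= 0 and nodes[leaf].count < nodes[node].count do
      {leaf, leaf - 1, node}
    else
      {node, leaf, node + 1}
    end
  end
end

tree = HuffmanBuildSketch.build([5, 3, 1])
IO.inspect({tree[3].left, tree[3].right, tree[3].count}) # {2, 1, 4}
IO.inspect({tree[4].left, tree[4].right, tree[4].count}) # {3, 0, 9}
```

For counts `[5, 3, 1]`, the two smallest leaves (indices 2 and 1) are merged into node 3 first, and node 3 is then merged with leaf 0 to form the root, node 4.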

The right child of every internal node is flagged as binary=true, so a path that branches right at a node contributes log(sigmoid(score)), while a path that branches left contributes log(1 - sigmoid(score)). This matches the reference's tree_[mini[1]].binary = true line.
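
To make the left/right rule concrete, the sketch below walks from a leaf up to the root and sums the per-node log-probabilities. It is a hedged illustration: `PathScoreSketch` is a hypothetical module, the tree is the one that counts `[5, 3, 1]` produce under the build rule above, and the scores are made-up dot products standing in for the real hidden-vector products.

```elixir
defmodule PathScoreSketch do
  # Hypothetical module; not part of the Text library.

  def sigmoid(x), do: 1.0 / (1.0 + :math.exp(-x))

  # Walks from `index` up to the root. `scores` maps an internal node index
  # to the dot product of its learned vector with the hidden vector.
  def log_prob(nodes, scores, index, acc \\ 0.0) do
    case nodes[index] do
      %{parent: -1} ->
        acc

      %{parent: parent, binary: binary} ->
        p = sigmoid(Map.fetch!(scores, parent))
        # Right child (binary: true) uses log(p); left child uses log(1 - p).
        step = if binary, do: :math.log(p), else: :math.log(1.0 - p)
        log_prob(nodes, scores, parent, acc + step)
    end
  end
end

# Parent links and binary flags for counts [5, 3, 1].
nodes = %{
  0 => %{parent: 4, binary: true},
  1 => %{parent: 3, binary: true},
  2 => %{parent: 3, binary: false},
  3 => %{parent: 4, binary: false},
  4 => %{parent: -1, binary: false}
}

# Made-up scores for the two internal nodes.
scores = %{3 => 0.3, 4 => -1.2}

probs = for leaf <- 0..2, do: :math.exp(PathScoreSketch.log_prob(nodes, scores, leaf))
IO.inspect(probs)
IO.inspect(Enum.sum(probs))
```

Because each internal node splits its probability mass between its two children, the leaf probabilities sum to 1 (up to floating-point rounding), whatever the scores are.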

Types

node_index()

@type node_index() :: non_neg_integer()

t()

@type t() :: %Text.Language.Classifier.Fasttext.HuffmanTree{
  nodes: :array.array(tree_node()),
  osz: pos_integer(),
  vectorised: vectorised_paths()
}

tree_node()

@type tree_node() :: %{
  parent: integer(),
  left: integer(),
  right: integer(),
  count: non_neg_integer(),
  binary: boolean()
}

vectorised_paths()

@type vectorised_paths() :: %{
  paths: Nx.Tensor.t(),
  signs: Nx.Tensor.t(),
  mask: Nx.Tensor.t(),
  max_depth: non_neg_integer()
}

Vectorised representation of every leaf's path through the tree.

All three tensors have shape {nlabels, max_depth}:

  • paths: int32. paths[i, j] is the output-matrix row index for the j-th internal node on leaf i's path (i.e. node_id - osz). Padding positions are 0, and the corresponding mask entry is also 0, so their contribution is ignored. Using 0 for padding is safe because 0 is always an in-range row index, so Nx.take reads a defined value there.

  • signs: f32. +1.0 if the leaf branches right at this node (uses log(sigmoid(dot))), -1.0 if it branches left (uses log(1 - sigmoid(dot)) = log(sigmoid(-dot))). Padding positions are +1.0.

  • mask: f32. 1.0 for valid path positions and 0.0 for padding. Multiplied into the per-step log-probabilities to zero out padding contributions.

Materialised once at build/1 time. Used by Text.Language.Classifier.Fasttext.Inference to score all leaves in a single fused EXLA kernel (one take + multiply + sigmoid + log + sum) instead of the original recursive BEAM-side DFS.
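
The padded layout and the masked sum can be sketched with plain lists in place of tensors. Everything here is an assumption made for illustration: `VectorisedPathsSketch` is a hypothetical module, nested lists stand in for the Nx tensors, and the leaf-to-root ordering of each path follows the reference C++ rather than anything this page states about the Elixir module.

```elixir
defmodule VectorisedPathsSketch do
  # Hypothetical module; nested lists stand in for Nx tensors.

  # {output_row, sign} pairs from the leaf up to the root.
  def leaf_path(nodes, osz, index) do
    case nodes[index] do
      %{parent: -1} ->
        []

      %{parent: parent, binary: binary} ->
        sign = if binary, do: 1.0, else: -1.0
        [{parent - osz, sign} | leaf_path(nodes, osz, parent)]
    end
  end

  def vectorise(nodes, osz) do
    raw = for leaf <- 0..(osz - 1), do: leaf_path(nodes, osz, leaf)
    max_depth = raw |> Enum.map(&length/1) |> Enum.max()

    rows =
      for path <- raw do
        pad = max_depth - length(path)

        %{
          paths: Enum.map(path, &elem(&1, 0)) ++ List.duplicate(0, pad),
          signs: Enum.map(path, &elem(&1, 1)) ++ List.duplicate(1.0, pad),
          mask: List.duplicate(1.0, length(path)) ++ List.duplicate(0.0, pad)
        }
      end

    %{
      paths: Enum.map(rows, & &1.paths),
      signs: Enum.map(rows, & &1.signs),
      mask: Enum.map(rows, & &1.mask),
      max_depth: max_depth
    }
  end

  def log_sigmoid(x), do: -:math.log(1.0 + :math.exp(-x))

  # `dots` is a list indexed by output-matrix row.
  # Per leaf: sum over steps of mask * log(sigmoid(sign * dot)).
  def log_probs(%{paths: paths, signs: signs, mask: mask}, dots) do
    [paths, signs, mask]
    |> Enum.zip()
    |> Enum.map(fn {p_row, s_row, m_row} ->
      [p_row, s_row, m_row]
      |> Enum.zip()
      |> Enum.map(fn {row, sign, m} -> m * log_sigmoid(sign * Enum.at(dots, row)) end)
      |> Enum.sum()
    end)
  end
end

# Parent links and binary flags for counts [5, 3, 1] (osz = 3).
nodes = %{
  0 => %{parent: 4, binary: true},
  1 => %{parent: 3, binary: true},
  2 => %{parent: 3, binary: false},
  3 => %{parent: 4, binary: false},
  4 => %{parent: -1, binary: false}
}

v = VectorisedPathsSketch.vectorise(nodes, 3)
IO.inspect(v.paths) # [[1, 0], [0, 1], [0, 1]]
IO.inspect(v.mask)  # [[1.0, 0.0], [1.0, 1.0], [1.0, 1.0]]

# Made-up dot products for output rows 0 and 1.
IO.inspect(VectorisedPathsSketch.log_probs(v, [0.3, -1.2]))
```

Leaf 0 sits one step below the root, so its second position is padding: the row index 0 is still looked up, but the mask entry of 0.0 removes it from the sum.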

Functions

build(counts)

@spec build([non_neg_integer()]) :: t()

Builds a Huffman tree from a list of label counts.

Arguments

  • counts is a non-empty list of non-negative integers, in label order. For fastText models loaded by ModelLoader.load/2, the order matches the dictionary's label-typed entries (already count-descending).

Returns

  • A t/0 struct with all internal nodes populated.

Examples

iex> tree = Text.Language.Classifier.Fasttext.HuffmanTree.build([5, 3, 1])
iex> tree.osz
3
iex> # 3 leaves + 2 internal = 5 nodes total
iex> :array.size(tree.nodes)
5
iex> root = Text.Language.Classifier.Fasttext.HuffmanTree.root(tree)
iex> root
4

leaf?(tree, index)

@spec leaf?(t(), node_index()) :: boolean()

Returns whether the node at index is a leaf.

node_at(huffman_tree, index)

@spec node_at(t(), node_index()) :: tree_node()

Returns the node at index.

root(huffman_tree)

@spec root(t()) :: node_index()

Returns the root index (always 2 * osz - 2).