Stephen.Plaid (Stephen v1.0.0)


PLAID-style indexing for efficient ColBERT retrieval.

PLAID (Performance-optimized Late Interaction Driver) uses centroid-based candidate generation for faster retrieval:

  1. Cluster all document token embeddings into K centroids
  2. Build inverted lists: centroid -> [doc_ids with tokens near that centroid]
  3. At query time, find nearest centroids for query tokens
  4. Retrieve candidate docs from inverted lists
  5. Rerank candidates with full MaxSim scoring
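Steps 3 and 4 can be sketched in plain Elixir as follows. This is a minimal illustration, not Stephen.Plaid's implementation. The module name `PlaidCandidatesSketch` is hypothetical. Embeddings are plain lists of floats instead of Nx tensors. Similarity is assumed to be a dot product, which equals cosine similarity when the embeddings are L2-normalized. The inverted index uses the same shape as the `inverted_index` field of `t()`, a map from centroid index to a `MapSet` of doc IDs.

```elixir
defmodule PlaidCandidatesSketch do
  # Hypothetical helper; lists of floats stand in for Nx tensors.

  # Dot-product similarity (cosine similarity for L2-normalized vectors).
  def dot(a, b), do: Enum.zip_with(a, b, &(&1 * &2)) |> Enum.sum()

  # Indices of the `nprobe` centroids most similar to one query token (step 3).
  def nearest_centroids(token, centroids, nprobe) do
    centroids
    |> Enum.with_index()
    |> Enum.sort_by(fn {c, _idx} -> dot(token, c) end, :desc)
    |> Enum.take(nprobe)
    |> Enum.map(fn {_c, idx} -> idx end)
  end

  # Step 4: union the doc IDs from the inverted lists of every probed centroid.
  def candidates(query_tokens, centroids, inverted_index, nprobe) do
    query_tokens
    |> Enum.flat_map(&nearest_centroids(&1, centroids, nprobe))
    |> Enum.uniq()
    |> Enum.reduce(MapSet.new(), fn cid, acc ->
      MapSet.union(acc, Map.get(inverted_index, cid, MapSet.new()))
    end)
  end
end

centroids = [[1.0, 0.0], [0.0, 1.0]]
inverted_index = %{0 => MapSet.new([:doc_a]), 1 => MapSet.new([:doc_b, :doc_c])}

# Probing only the single closest centroid yields just :doc_a as a candidate.
PlaidCandidatesSketch.candidates([[0.9, 0.1]], centroids, inverted_index, 1)
```

Only the candidates returned here go on to full MaxSim scoring in step 5. Raising `nprobe` widens the candidate set at the cost of more scoring work.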

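Because only the documents found in the probed inverted lists are scored with full MaxSim, search avoids scanning every document. This keeps query time sub-linear in the collection size for large collections.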

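Summary

Functions

  • add_document(plaid, doc_id, embeddings) - Adds a single document to the index.
  • delete(plaid, doc_id) - Removes a document from the index.
  • delete_all(plaid, doc_ids) - Removes multiple documents from the index.
  • doc_ids(plaid) - Returns all document IDs in the index.
  • get_embeddings(plaid, doc_id) - Gets the embeddings for a document.
  • has_doc?(plaid, doc_id) - Checks if a document exists in the index.
  • index_documents(plaid, documents) - Indexes documents into the PLAID index.
  • load(path) - Loads a PLAID index from disk.
  • new(opts \\ []) - Creates a new PLAID index.
  • save(plaid, path) - Saves the PLAID index to disk.
  • search(plaid, query_embeddings, opts \\ []) - Searches the PLAID index for documents matching a query.
  • size(plaid) - Returns the number of documents in the index.
  • update(plaid, doc_id, embeddings) - Updates a document in the index by replacing its embeddings.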

Types

doc_id()

@type doc_id() :: term()

t()

@type t() :: %Stephen.Plaid{
  centroids: Nx.Tensor.t() | nil,
  doc_count: non_neg_integer(),
  doc_embeddings: %{required(term()) => Nx.Tensor.t()},
  embedding_dim: pos_integer(),
  inverted_index: %{required(non_neg_integer()) => MapSet.t()},
  num_centroids: pos_integer()
}

Functions

add_document(plaid, doc_id, embeddings)

@spec add_document(t(), doc_id(), Nx.Tensor.t()) :: t()

Adds a single document to the index.
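One plausible way to add a document, following step 2 of the overview, is to assign each of its tokens to the nearest existing centroid and add the doc ID to those centroids' inverted lists. The sketch below is hypothetical and is not the library's code. `PlaidAddSketch` is an illustrative name. The index is a plain map that reuses the field names of `t()`, with lists of floats in place of Nx tensors. Nearest-centroid assignment is assumed to use the dot product.

```elixir
defmodule PlaidAddSketch do
  # Hypothetical sketch: `index` is a plain map with Stephen.Plaid.t() field
  # names; embeddings are lists of float lists rather than Nx tensors.

  defp dot(a, b), do: Enum.zip_with(a, b, &(&1 * &2)) |> Enum.sum()

  # Index of the centroid with the highest dot product for `token`.
  defp nearest_centroid(token, centroids) do
    centroids
    |> Enum.with_index()
    |> Enum.max_by(fn {c, _idx} -> dot(token, c) end)
    |> elem(1)
  end

  def add_document(index, doc_id, embeddings) do
    # Put doc_id into the inverted list of each distinct nearest centroid.
    inverted =
      embeddings
      |> Enum.map(&nearest_centroid(&1, index.centroids))
      |> Enum.uniq()
      |> Enum.reduce(index.inverted_index, fn cid, inv ->
        Map.update(inv, cid, MapSet.new([doc_id]), &MapSet.put(&1, doc_id))
      end)

    doc_embeddings = Map.put(index.doc_embeddings, doc_id, embeddings)

    %{index | inverted_index: inverted, doc_embeddings: doc_embeddings, doc_count: map_size(doc_embeddings)}
  end
end
```

The sketch assumes the centroids are already trained and that `doc_id` is new. Re-adding an existing ID this way would leave stale entries in lists its old tokens mapped to, which is why `update/3` deletes first.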

delete(plaid, doc_id)

@spec delete(t(), doc_id()) :: t()

Removes a document from the index.

Arguments

  • plaid - PLAID index
  • doc_id - The document ID to remove

Returns

The updated PLAID index, or the original index unchanged if doc_id is not found.
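Deletion has to undo step 2 of the overview by removing the ID from every inverted list that references it. A minimal sketch, assuming a plain map with the field names of `t()`, follows. `PlaidDeleteSketch` is hypothetical. Dropping inverted lists that become empty is an assumed design choice, not documented behavior.

```elixir
defmodule PlaidDeleteSketch do
  # Hypothetical sketch over a plain map with Stephen.Plaid.t() field names.
  def delete(index, doc_id) do
    if Map.has_key?(index.doc_embeddings, doc_id) do
      # Remove doc_id from every inverted list; drop lists that become empty.
      inverted =
        index.inverted_index
        |> Enum.map(fn {cid, docs} -> {cid, MapSet.delete(docs, doc_id)} end)
        |> Enum.reject(fn {_cid, docs} -> MapSet.size(docs) == 0 end)
        |> Map.new()

      embeddings = Map.delete(index.doc_embeddings, doc_id)
      %{index | inverted_index: inverted, doc_embeddings: embeddings, doc_count: map_size(embeddings)}
    else
      # Unknown doc_id: return the index unchanged.
      index
    end
  end
end
```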

delete_all(plaid, doc_ids)

@spec delete_all(t(), [doc_id()]) :: t()

Removes multiple documents from the index.

Arguments

  • plaid - PLAID index
  • doc_ids - List of document IDs to remove

Returns

Updated PLAID index.

doc_ids(plaid)

@spec doc_ids(t()) :: [doc_id()]

Returns all document IDs in the index.

get_embeddings(plaid, doc_id)

@spec get_embeddings(t(), doc_id()) :: Nx.Tensor.t() | nil

Gets the embeddings for a document.

has_doc?(plaid, doc_id)

@spec has_doc?(t(), doc_id()) :: boolean()

Checks if a document exists in the index.

index_documents(plaid, documents)

@spec index_documents(t(), [{doc_id(), Nx.Tensor.t()}]) :: t()

Indexes documents into the PLAID index.

The first call trains centroids on the provided embeddings. Subsequent calls reuse those existing centroids instead of retraining them.

Arguments

  • plaid - PLAID index
  • documents - List of {doc_id, embeddings} tuples
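Centroid training on the first call corresponds to step 1 of the overview. The page does not say which clustering algorithm Stephen.Plaid uses, so the sketch below swaps in plain Lloyd's k-means with Euclidean distance as a stand-in. It seeds the centroids with the first `k` points and runs a fixed number of iterations. `KMeansSketch`, its deterministic seeding, and the iteration count are all assumptions for illustration.

```elixir
defmodule KMeansSketch do
  # Hypothetical stand-in for centroid training: Lloyd's k-means on lists of
  # floats. Stephen.Plaid's actual training procedure may differ.

  defp dist2(a, b), do: Enum.zip_with(a, b, fn x, y -> (x - y) * (x - y) end) |> Enum.sum()

  # Component-wise mean of a non-empty list of vectors.
  defp mean(points) do
    n = length(points)
    points |> Enum.zip_with(&Enum.sum/1) |> Enum.map(&(&1 / n))
  end

  # Index of the closest centroid by squared Euclidean distance.
  def nearest(point, centroids) do
    centroids
    |> Enum.with_index()
    |> Enum.min_by(fn {c, _idx} -> dist2(point, c) end)
    |> elem(1)
  end

  # `points` must contain at least `k` vectors; the first `k` seed the centroids.
  def train(points, k, iterations \\ 10) do
    Enum.reduce(1..iterations, Enum.take(points, k), fn _i, centroids ->
      clusters = Enum.group_by(points, &nearest(&1, centroids))

      # Move each centroid to the mean of its points; keep empty ones in place.
      centroids
      |> Enum.with_index()
      |> Enum.map(fn {c, idx} ->
        case Map.get(clusters, idx) do
          nil -> c
          members -> mean(members)
        end
      end)
    end)
  end
end
```

Production systems usually seed with random samples or k-means++ and train on a subsample of token embeddings. The fixed seeding here only keeps the sketch deterministic.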

load(path)

@spec load(Path.t()) :: {:ok, t()} | {:error, term()}

Loads a PLAID index from disk.

Arguments

  • path - File path to load from

Returns

{:ok, plaid} or {:error, reason}
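The on-disk format used by `save/2` and `load/1` is not documented here. The sketch below is therefore only one plausible way to get the same `{:ok, term} | {:error, reason}` contract, using Erlang's external term format through `:erlang.term_to_binary/1` and `:erlang.binary_to_term/1`. `PersistSketch` is hypothetical. For Nx tensors specifically, Nx also provides `Nx.serialize/2` and `Nx.deserialize/2`.

```elixir
defmodule PersistSketch do
  # Hypothetical sketch; not Stephen.Plaid's real file format.

  # Returns :ok or {:error, reason}, mirroring the save/2 spec.
  def save(term, path) do
    File.write(path, :erlang.term_to_binary(term))
  end

  # Returns {:ok, term} or passes File.read's {:error, reason} through.
  # binary_to_term can create atoms and functions, so only load trusted files.
  def load(path) do
    with {:ok, binary} <- File.read(path) do
      {:ok, :erlang.binary_to_term(binary)}
    end
  end
end
```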

new(opts \\ [])

@spec new(keyword()) :: t()

Creates a new PLAID index.

Options

  • :embedding_dim - Dimension of embeddings (required)
  • :num_centroids - Number of centroids for clustering (default: 1024)

Examples

plaid = Stephen.Plaid.new(embedding_dim: 128, num_centroids: 1024)

save(plaid, path)

@spec save(t(), Path.t()) :: :ok | {:error, term()}

Saves the PLAID index to disk.

Arguments

  • plaid - PLAID index to save
  • path - File path to save to

Returns

:ok or {:error, reason}

search(plaid, query_embeddings, opts \\ [])

@spec search(t(), Nx.Tensor.t(), keyword()) :: [%{doc_id: doc_id(), score: float()}]

Searches the PLAID index for documents matching a query.

Arguments

  • plaid - PLAID index
  • query_embeddings - Query token embeddings
  • opts - Search options

Options

  • :top_k - Number of results to return (default: 10)
  • :nprobe - Number of centroids to probe per query token (default: 32)

Returns

List of %{doc_id: term(), score: float()} sorted by score descending.
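The reranking step (step 5) uses ColBERT's late-interaction MaxSim score: for each query token, take the highest similarity with any document token, then sum those maxima. The sketch below computes that over plain lists of floats with a dot-product similarity. It then returns the documented result shape, sorted by score descending and truncated to `top_k`. `MaxSimSketch` is a hypothetical name, and the candidate list format `{doc_id, doc_tokens}` is an assumption.

```elixir
defmodule MaxSimSketch do
  # Hypothetical sketch of MaxSim reranking over lists of floats.

  defp dot(a, b), do: Enum.zip_with(a, b, &(&1 * &2)) |> Enum.sum()

  # Sum over query tokens of the best dot product with any document token.
  def maxsim(query_tokens, doc_tokens) do
    query_tokens
    |> Enum.map(fn q -> doc_tokens |> Enum.map(&dot(q, &1)) |> Enum.max() end)
    |> Enum.sum()
  end

  # `candidates` is a list of {doc_id, doc_tokens}; returns maps with
  # :doc_id and :score, highest score first, at most `top_k` of them.
  def rerank(query_tokens, candidates, top_k \\ 10) do
    candidates
    |> Enum.map(fn {doc_id, tokens} -> %{doc_id: doc_id, score: maxsim(query_tokens, tokens)} end)
    |> Enum.sort_by(& &1.score, :desc)
    |> Enum.take(top_k)
  end
end
```

Because each query token picks its own best-matching document token, a document scores highly only if it covers every part of the query.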

size(plaid)

@spec size(t()) :: non_neg_integer()

Returns the number of documents in the index.

update(plaid, doc_id, embeddings)

@spec update(t(), doc_id(), Nx.Tensor.t()) :: t()

Updates a document in the index by replacing its embeddings.

This is equivalent to deleting and re-adding the document.

Arguments

  • plaid - PLAID index
  • doc_id - The document ID to update
  • embeddings - New embeddings tensor

Returns

Updated PLAID index.
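Implementing update as delete followed by re-add matters because the new embeddings may map to different centroids than the old ones. The sketch below shows that composition on a plain map with the field names of `t()`, using lists of floats and dot-product nearest-centroid assignment. `PlaidUpdateSketch` and its private helpers are hypothetical, not the library's internals.

```elixir
defmodule PlaidUpdateSketch do
  # Hypothetical sketch: `index` is a plain map with Stephen.Plaid.t() field
  # names; embeddings are lists of float lists rather than Nx tensors.

  defp dot(a, b), do: Enum.zip_with(a, b, &(&1 * &2)) |> Enum.sum()

  defp nearest(token, centroids) do
    centroids
    |> Enum.with_index()
    |> Enum.max_by(fn {c, _idx} -> dot(token, c) end)
    |> elem(1)
  end

  # Remove doc_id from every inverted list, dropping lists that become empty.
  defp delete(index, doc_id) do
    inverted =
      index.inverted_index
      |> Enum.map(fn {cid, docs} -> {cid, MapSet.delete(docs, doc_id)} end)
      |> Enum.reject(fn {_cid, docs} -> MapSet.size(docs) == 0 end)
      |> Map.new()

    %{index | inverted_index: inverted, doc_embeddings: Map.delete(index.doc_embeddings, doc_id)}
  end

  # Add doc_id to the inverted list of each token's nearest centroid.
  defp add(index, doc_id, embeddings) do
    inverted =
      Enum.reduce(embeddings, index.inverted_index, fn token, inv ->
        cid = nearest(token, index.centroids)
        Map.update(inv, cid, MapSet.new([doc_id]), &MapSet.put(&1, doc_id))
      end)

    %{index | inverted_index: inverted, doc_embeddings: Map.put(index.doc_embeddings, doc_id, embeddings)}
  end

  # Deleting first guarantees no stale entries survive in lists that only
  # the old embeddings mapped to.
  def update(index, doc_id, embeddings) do
    updated = index |> delete(doc_id) |> add(doc_id, embeddings)
    %{updated | doc_count: map_size(updated.doc_embeddings)}
  end
end
```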