Edifice.Blocks.KVCache (Edifice v0.2.0)


KV Cache: Inference-time Key-Value Caching for Autoregressive Decoding.

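Provides utilities for caching key and value projections during autoregressive generation. Without caching, every decoding step recomputes K/V for all previous tokens, so step n costs O(n) projections and generating n tokens costs O(n²) projections in total. With caching, each step computes K/V only for the new token and appends it to the cache, so generation costs O(n) projections in total. Attention over the cached history still grows with sequence length.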
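To make the counting concrete, here is a small Elixir sketch that tallies per-token K/V projections when a sequence is decoded one token at a time for n steps, where step i sees i tokens. `ProjectionCount` is a hypothetical module, not part of Edifice, and it counts projections only; attention scores against the cached keys are still computed at every step.

```elixir
defmodule ProjectionCount do
  # Hypothetical helper: counts per-token K/V projections over n decoding
  # steps, where step i sees i tokens. Illustrative only.

  # Without a cache, step i recomputes K/V for all i tokens: 1 + 2 + ... + n.
  def without_cache(n) when n >= 0, do: Enum.sum(0..n)

  # With a cache, step i computes K/V for the single new token only.
  def with_cache(n) when n >= 0, do: n
end

IO.inspect({ProjectionCount.without_cache(4), ProjectionCount.with_cache(4)})
# {10, 4}
```

For n = 1000 the totals are 500,500 projections without a cache versus 1,000 with one.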

How It Works

Step 1: prompt "The cat"
  K_cache = [K_the, K_cat]
  V_cache = [V_the, V_cat]

Step 2: generate "sat"
  Compute K_sat, V_sat only for new token
  K_cache = [K_the, K_cat, K_sat]    (append)
  V_cache = [V_the, V_cat, V_sat]    (append)
  Attend: Q_sat × K_cache^T → weights → weights × V_cache

Step 3: generate "on"
  K_cache = [K_the, K_cat, K_sat, K_on]
  ...
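The append steps above can be mirrored with plain `Nx.concatenate/2` along the sequence axis (axis 2 in head-first layout). This is the naive approach, shown only to match the diagram: `NaiveAppend` is a hypothetical module, and the toy shapes (batch 1, 2 heads, head_dim 4) are assumptions. Each call allocates a new, larger tensor, which is what the pre-allocated design of init/1 avoids.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule NaiveAppend do
  # Hypothetical helper: appends new K/V along the sequence axis (axis 2)
  # of head-first [batch, num_heads, seq_len, head_dim] tensors.
  def append({k_cache, v_cache}, new_k, new_v) do
    {Nx.concatenate([k_cache, new_k], axis: 2),
     Nx.concatenate([v_cache, new_v], axis: 2)}
  end
end

# Toy K/V for `seq` positions (values are arbitrary).
kv = fn seq -> Nx.iota({1, 2, seq, 4}, type: :f32) end

# Step 1: prompt "The cat" -> 2 positions
cache = {kv.(2), kv.(2)}
# Step 2: "sat" -> 1 new position appended
cache = NaiveAppend.append(cache, kv.(1), kv.(1))
# Step 3: "on" -> 1 more position appended
{k_cache, _v_cache} = NaiveAppend.append(cache, kv.(1), kv.(1))

IO.inspect(Nx.shape(k_cache))
# {1, 2, 4, 4}
```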

Usage

alias Edifice.Blocks.KVCache

# Initialize an empty cache for a model
cache = KVCache.init(batch_size: 2, num_layers: 4, num_heads: 4, head_dim: 64)

# During the generation loop (compute_kv/1 and attend/3 stand in for
# your model's own projection and attention code):
{new_k, new_v} = compute_kv(token)
{cache, full_k, full_v} = KVCache.update(cache, layer_idx, new_k, new_v)
output = attend(q, full_k, full_v)

# Get the current sequence length (seq_length/1 takes only the cache state)
len = KVCache.seq_length(cache)

Design

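The cache is a map of {layer_idx => {k_tensor, v_tensor}}; init/1 wraps it in a state map together with a single position pointer and max_seq_len. Because the K/V buffers are pre-allocated, it is the valid portion of each buffer that grows along the sequence dimension (axis 2 for head-first layout, axis 1 for seq-first layout). This module assumes head-first layout: [batch, num_heads, seq_len, head_dim].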
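If a model's projections come out seq-first ([batch, seq_len, num_heads, head_dim]), one way to get the head-first layout this module expects is to swap axes 1 and 2 with `Nx.transpose/2`. The tensor and its shape below are illustrative assumptions.

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical projection output in seq-first layout:
# [batch, seq_len, num_heads, head_dim] = [2, 5, 4, 64]
seq_first = Nx.broadcast(Nx.tensor(0.0, type: :f32), {2, 5, 4, 64})

# Swap the seq_len and num_heads axes to get head-first layout:
# [batch, num_heads, seq_len, head_dim]
head_first = Nx.transpose(seq_first, axes: [0, 2, 1, 3])

# In head-first layout the sequence dimension is axis 2.
seq_len = Nx.axis_size(head_first, 2)

IO.inspect({Nx.shape(head_first), seq_len})
# {{2, 4, 5, 64}, 5}
```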

Summary

Types

init_opt()
Options for init/1.

t()
A KV cache: map from layer index to {k_tensor, v_tensor}.

Functions

build_cached_attention(opts)
Build a cached attention function.

get(map, layer_idx)
Get the cached K/V tensors for a layer (valid portion only).

init(opts)
Initialize an empty KV cache.

reset(state)
Reset the cache to empty (reuse the allocated buffers).

seq_length(map)
Get the current sequence length stored in the cache.

update(state, layer_idx, new_k, new_v)
Append new K/V entries to the cache for a given layer.
Types

init_opt()

@type init_opt() ::
  {:batch_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:head_dim, pos_integer()}
  | {:max_seq_len, pos_integer()}
  | {:type, Nx.Type.t()}

Options for init/1.

t()

@type t() :: %{required(non_neg_integer()) => {Nx.Tensor.t(), Nx.Tensor.t()}}

A KV cache: map from layer index to {k_tensor, v_tensor}.

Functions

build_cached_attention(opts)

@spec build_cached_attention(keyword()) :: function()

Build a cached attention function.

Returns a function that, given Q, K, V for the new tokens and a cache state, computes attention using the full cached K/V history.

Options

  • :num_heads - Number of attention heads (required)
  • :head_dim - Dimension per head (required)
  • :layer_idx - Layer index for cache lookup (required)

Returns

A function fn(q, k, v, cache_state) -> {output, updated_cache_state}.
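A minimal sketch of what such a closure could look like, assuming the cache state is a plain {k_cache, v_cache} tuple of head-first tensors and using standard scaled dot-product attention. `CachedAttentionSketch` is hypothetical, not Edifice's implementation: it reads only `:head_dim`, ignores `:num_heads` and `:layer_idx`, and applies no causal mask, so it is only correct when every new token may attend to every cached position (for example, decoding one token per step).

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule CachedAttentionSketch do
  # Hypothetical stand-in for the closure returned by build_cached_attention/1.
  # Assumed cache state: {k_cache, v_cache}, head-first
  # [batch, num_heads, seq_len, head_dim]. No causal mask is applied.
  def build(opts) do
    head_dim = Keyword.fetch!(opts, :head_dim)
    scale = 1.0 / :math.sqrt(head_dim)

    fn q, k, v, {k_cache, v_cache} ->
      # Append the new tokens' K/V to the history along the sequence axis.
      full_k = Nx.concatenate([k_cache, k], axis: 2)
      full_v = Nx.concatenate([v_cache, v], axis: 2)

      # Scores: contract head_dim, batch over [batch, heads]
      # -> [batch, num_heads, new_len, total_len]
      scores = q |> Nx.dot([3], [0, 1], full_k, [3], [0, 1]) |> Nx.multiply(scale)

      # Numerically stable softmax over the key axis.
      shifted = Nx.subtract(scores, Nx.reduce_max(scores, axes: [3], keep_axes: true))
      exp = Nx.exp(shifted)
      weights = Nx.divide(exp, Nx.sum(exp, axes: [3], keep_axes: true))

      # Weighted sum of values -> [batch, num_heads, new_len, head_dim]
      output = Nx.dot(weights, [3], [0, 1], full_v, [2], [0, 1])
      {output, {full_k, full_v}}
    end
  end
end

attend = CachedAttentionSketch.build(head_dim: 4)

# Cached K/V for a 2-token prompt (toy values), head-first.
state = {Nx.iota({1, 2, 2, 4}, type: :f32), Nx.iota({1, 2, 2, 4}, type: :f32)}

# Q/K/V for one new token.
new = Nx.broadcast(Nx.tensor(1.0, type: :f32), {1, 2, 1, 4})
{out, {full_k, _full_v}} = attend.(new, new, new, state)

IO.inspect({Nx.shape(out), Nx.shape(full_k)})
# {{1, 2, 1, 4}, {1, 2, 3, 4}}
```

The output has one row per new token, while the returned state now holds all three positions for the next step.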

get(map, layer_idx)

@spec get(map(), non_neg_integer()) :: {Nx.Tensor.t(), Nx.Tensor.t()}

Get the cached K/V tensors for a layer (valid portion only).

Returns {k, v} sliced to the current position.

init(opts)

@spec init([init_opt()]) :: map()

Initialize an empty KV cache.

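Creates pre-allocated zero tensors for each layer. Writing into pre-allocated tensors with a position pointer is more efficient on accelerators than repeated concatenation: the buffer shape stays fixed, so JIT-compiled code need not be recompiled for each new sequence length, and no new, larger tensor is allocated at every step.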

Options

  • :batch_size - Batch size (required)
  • :num_layers - Number of transformer layers (required)
  • :num_heads - Number of attention heads (required)
  • :head_dim - Dimension per head (required)
  • :max_seq_len - Maximum sequence length to pre-allocate (default: 2048)
  • :type - Numeric type (default: :f32)

Returns

A map %{cache: %{layer => {k, v}}, position: 0, max_seq_len: n}.
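The documented return shape can be reproduced in a few lines of Nx. This is a hypothetical re-creation (`PreallocSketch`) based only on the options and return value listed above, not Edifice's source; `max_seq_len: 16` keeps the example small.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule PreallocSketch do
  # Hypothetical re-creation of the state documented for init/1.
  def init(opts) do
    batch = Keyword.fetch!(opts, :batch_size)
    layers = Keyword.fetch!(opts, :num_layers)
    heads = Keyword.fetch!(opts, :num_heads)
    head_dim = Keyword.fetch!(opts, :head_dim)
    max_seq_len = Keyword.get(opts, :max_seq_len, 2048)
    type = Keyword.get(opts, :type, :f32)

    # One zero buffer of shape [batch, num_heads, max_seq_len, head_dim];
    # Nx tensors are immutable, so sharing it across layers is safe.
    zeros = Nx.broadcast(Nx.tensor(0, type: type), {batch, heads, max_seq_len, head_dim})

    cache = Map.new(0..(layers - 1), fn layer -> {layer, {zeros, zeros}} end)
    %{cache: cache, position: 0, max_seq_len: max_seq_len}
  end
end

state =
  PreallocSketch.init(batch_size: 2, num_layers: 4, num_heads: 4, head_dim: 64, max_seq_len: 16)

{k0, _v0} = state.cache[0]
IO.inspect({map_size(state.cache), Nx.shape(k0), state.position})
# {4, {2, 4, 16, 64}, 0}
```

With the default max_seq_len of 2048 and :f32, each buffer for batch_size: 2, num_heads: 4, head_dim: 64 holds 2 × 4 × 2048 × 64 = 1,048,576 elements (4 MiB), so K and V across 4 layers take 32 MiB up front.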

reset(state)

@spec reset(map()) :: map()

Reset the cache to empty (reuse the allocated buffers).

seq_length(map)

@spec seq_length(map()) :: non_neg_integer()

Get the current sequence length stored in the cache.

update(state, layer_idx, new_k, new_v)

@spec update(map(), non_neg_integer(), Nx.Tensor.t(), Nx.Tensor.t()) ::
  {map(), Nx.Tensor.t(), Nx.Tensor.t()}

Append new K/V entries to the cache for a given layer.

Takes new K/V tensors of shape [batch, num_heads, new_len, head_dim] and writes them into the pre-allocated cache at the current position.

Parameters

  • state - The cache state from init/1 or a previous update/4
  • layer_idx - Which transformer layer (0-indexed)
  • new_k - New key tensor [batch, num_heads, new_len, head_dim]
  • new_v - New value tensor [batch, num_heads, new_len, head_dim]

Returns

{updated_state, cached_k, cached_v}, where cached_k and cached_v contain all entries up to and including the new ones, sliced from the pre-allocated buffer.
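A sketch of the write-then-slice step using `Nx.put_slice/3` and `Nx.slice_along_axis/4`. `BufferWriteSketch` is hypothetical and takes the write position as an explicit argument; how Edifice's update/4 advances its single position pointer across layers is not specified on this page, so that part is left out.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule BufferWriteSketch do
  # Hypothetical helper, not Edifice's update/4. Buffers are head-first:
  # [batch, num_heads, max_seq_len, head_dim]; new_k/new_v are
  # [batch, num_heads, new_len, head_dim].
  def write({k_buf, v_buf}, new_k, new_v, position) do
    new_len = Nx.axis_size(new_k, 2)
    max_len = Nx.axis_size(k_buf, 2)

    # Refuse to write past the end of the pre-allocated buffer.
    if position + new_len > max_len do
      raise ArgumentError, "cache overflow: #{position + new_len} > #{max_len}"
    end

    # Write the new entries starting at `position` on the sequence axis.
    k_buf = Nx.put_slice(k_buf, [0, 0, position, 0], new_k)
    v_buf = Nx.put_slice(v_buf, [0, 0, position, 0], new_v)

    # Return the updated buffers plus the valid prefix [0, position + new_len).
    valid = position + new_len

    {{k_buf, v_buf}, Nx.slice_along_axis(k_buf, 0, valid, axis: 2),
     Nx.slice_along_axis(v_buf, 0, valid, axis: 2)}
  end
end

zeros = Nx.broadcast(Nx.tensor(0.0, type: :f32), {1, 2, 8, 4})

# Prefill: 3 prompt positions (all ones) written at position 0.
prompt = Nx.broadcast(Nx.tensor(1.0, type: :f32), {1, 2, 3, 4})
{bufs, _k, _v} = BufferWriteSketch.write({zeros, zeros}, prompt, prompt, 0)

# Decode step: 1 new position (all twos) written at position 3.
step = Nx.broadcast(Nx.tensor(2.0, type: :f32), {1, 2, 1, 4})
{_bufs, k, _v} = BufferWriteSketch.write(bufs, step, step, 3)

IO.inspect(Nx.shape(k))
# {1, 2, 4, 4}
```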