KV Cache: Inference-time Key-Value Caching for Autoregressive Decoding.
Provides utilities for caching key and value projections during autoregressive generation. Without caching, each new token requires re-computing K/V for all previous tokens (O(n) work per step, O(n²) over the whole sequence). With caching, only the new token's K/V are computed and appended (O(1) per step, O(n) total).
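To make the asymptotics concrete, here is a minimal language-agnostic sketch in Python (the module itself is Elixir/Nx); `project_kv` is a hypothetical stand-in for the model's per-token K/V projection, and the counters tally how many projections each strategy performs.

```python
# Sketch: decoding with and without a KV cache.
# `project_kv` is a hypothetical stand-in for the model's K/V projection.

def project_kv(token):
    # In a real model this would be a matmul; any per-token function works here.
    return (hash(token) % 7, hash(token) % 11)  # (k, v)

def decode_without_cache(tokens):
    work = 0
    for step in range(1, len(tokens) + 1):
        # Recompute K/V for every token seen so far: O(step) work per step.
        kv = [project_kv(t) for t in tokens[:step]]
        work += step
    return work  # O(n^2) total

def decode_with_cache(tokens):
    work = 0
    kv = []
    for t in tokens:
        # Only the new token's K/V are computed and appended: O(1) per step.
        kv.append(project_kv(t))
        work += 1
    return work  # O(n) total

print(decode_without_cache(list("The cat sat")))  # 66 = 11 * 12 / 2
print(decode_with_cache(list("The cat sat")))     # 11
```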
How It Works

Step 1: prompt "The cat"

    K_cache = [K_the, K_cat]
    V_cache = [V_the, V_cat]

Step 2: generate "sat"

    Compute K_sat, V_sat only for the new token
    K_cache = [K_the, K_cat, K_sat]  (append)
    V_cache = [V_the, V_cat, V_sat]  (append)
    Attend: Q_sat × K_cache^T → weights → V_cache

Step 3: generate "on"

    K_cache = [K_the, K_cat, K_sat, K_on]
    ...

Usage

    # Initialize an empty cache for a model
    cache = KVCache.init(batch_size: 2, num_layers: 4, num_heads: 4, head_dim: 64)

    # During the generation loop:
    {new_k, new_v} = compute_kv(token)
    {cache, full_k, full_v} = KVCache.update(cache, layer_idx, new_k, new_v)
    output = attend(q, full_k, full_v)

    # Get the current sequence length
    len = KVCache.seq_length(cache)

Design
At its core the cache is a map of %{layer_idx => {k_tensor, v_tensor}}, wrapped in a state map that also tracks the current position and max_seq_len. The valid region grows along the sequence dimension (axis 2 for head-first layout, axis 1 for seq-first layout). This module assumes head-first layout: [batch, num_heads, seq_len, head_dim].
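As a quick check of which axis is the sequence axis in each layout, here is a small Python sketch using nested lists in place of Nx tensors (illustrative only; `shape` is a helper defined here, not part of the module):

```python
# Which axis holds seq_len in each layout? Nested lists stand in for tensors.
batch, num_heads, seq_len, head_dim = 2, 4, 3, 8

def shape(t):
    # Shape of a rectangular nested-list "tensor".
    s = []
    while isinstance(t, list):
        s.append(len(t))
        t = t[0]
    return s

# Head-first layout: [batch, num_heads, seq_len, head_dim]
head_first = [[[[0.0] * head_dim for _ in range(seq_len)]
               for _ in range(num_heads)] for _ in range(batch)]

# Seq-first layout: [batch, seq_len, num_heads, head_dim]
seq_first = [[[[0.0] * head_dim for _ in range(num_heads)]
              for _ in range(seq_len)] for _ in range(batch)]

print(shape(head_first))  # [2, 4, 3, 8] -> seq_len sits at axis 2
print(shape(seq_first))   # [2, 3, 4, 8] -> seq_len sits at axis 1
```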
Summary
Functions
Build a cached attention function.
Get the cached K/V tensors for a layer (valid portion only).
Initialize an empty KV cache.
Reset the cache to empty (reuse the allocated buffers).
Get the current sequence length stored in the cache.
Append new K/V entries to the cache for a given layer.
Types
@type init_opt() ::
        {:batch_size, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:head_dim, pos_integer()}
        | {:max_seq_len, pos_integer()}
        | {:type, Nx.Type.t()}
Options for init/1.
@type t() :: %{required(non_neg_integer()) => {Nx.Tensor.t(), Nx.Tensor.t()}}
A KV cache: map from layer index to {k_tensor, v_tensor}.
Functions
Build a cached attention function.
Returns a function that, given Q, K, V for the new tokens and a cache state, computes attention using the full cached K/V history.
Options
* :num_heads - Number of attention heads (required)
* :head_dim - Dimension per head (required)
* :layer_idx - Layer index for cache lookup (required)
Returns
A function fn(q, k, v, cache_state) -> {output, updated_cache_state}.
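The shape of the returned closure can be pictured with the hypothetical single-head, batch-free Python sketch below (the real function operates on Nx tensors and uses :num_heads/:head_dim to reshape; here the cache is a plain dict of growing lists rather than a pre-allocated buffer):

```python
import math

def make_cached_attention(layer_idx):
    """Hypothetical sketch: returns fn(q, k, v, cache) -> (output, cache).

    Single head, no batch: q is one query vector, k/v are the new token's
    key/value vectors, cache maps layer_idx -> (list_of_ks, list_of_vs)."""
    def attend(q, k, v, cache):
        ks, vs = cache.get(layer_idx, ([], []))
        ks, vs = ks + [k], vs + [v]                # append the new K/V
        cache = {**cache, layer_idx: (ks, vs)}
        d = len(q)
        # Scaled dot-product scores against the full cached history.
        scores = [sum(qi * ki for qi, ki in zip(q, key)) / math.sqrt(d)
                  for key in ks]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum over all cached values.
        out = [sum(w * val[i] for w, val in zip(weights, vs))
               for i in range(d)]
        return out, cache

    return attend

attn = make_cached_attention(0)
cache = {}
out, cache = attn([1.0, 0.0], [1.0, 0.0], [2.0, 3.0], cache)
# With a single cached entry the softmax weight is 1, so out == that value.
print(out)  # [2.0, 3.0]
```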
@spec get(map(), non_neg_integer()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Get the cached K/V tensors for a layer (valid portion only).
Returns {k, v} sliced to the current position.
Initialize an empty KV cache.
Creates pre-allocated zero tensors for each layer. Using pre-allocated tensors with a position pointer is more efficient than repeated concatenation on accelerators.
Options
* :batch_size - Batch size (required)
* :num_layers - Number of transformer layers (required)
* :num_heads - Number of attention heads (required)
* :head_dim - Dimension per head (required)
* :max_seq_len - Maximum sequence length to pre-allocate (default: 2048)
* :type - Numeric type (default: :f32)
Returns
A map %{cache: %{layer => {k, v}}, position: 0, max_seq_len: n}.
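A sketch of that state shape in Python, using zero-filled nested lists in place of pre-allocated Nx tensors (the field names mirror the documented return value; this `init` is a hypothetical stand-in, not the module's implementation):

```python
def init(batch_size, num_layers, num_heads, head_dim, max_seq_len=2048):
    """Sketch of init/1: pre-allocate zero buffers plus a position pointer.

    Mirrors the documented state %{cache: ..., position: 0, max_seq_len: n};
    nested lists stand in for Nx tensors."""
    def zeros():
        # Head-first layout: [batch, num_heads, max_seq_len, head_dim]
        return [[[[0.0] * head_dim for _ in range(max_seq_len)]
                 for _ in range(num_heads)] for _ in range(batch_size)]

    return {
        "cache": {layer: (zeros(), zeros()) for layer in range(num_layers)},
        "position": 0,
        "max_seq_len": max_seq_len,
    }

state = init(batch_size=1, num_layers=2, num_heads=2, head_dim=4, max_seq_len=8)
print(state["position"])       # 0
print(sorted(state["cache"]))  # [0, 1]
```

Pre-allocating to max_seq_len up front means each decoding step is a write at a fixed offset rather than an allocation, which is what makes the position-pointer design cheap on accelerators.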
Reset the cache to empty (reuse the allocated buffers).
@spec seq_length(map()) :: non_neg_integer()
Get the current sequence length stored in the cache.
@spec update(map(), non_neg_integer(), Nx.Tensor.t(), Nx.Tensor.t()) :: {map(), Nx.Tensor.t(), Nx.Tensor.t()}
Append new K/V entries to the cache for a given layer.
Takes new K/V tensors of shape [batch, num_heads, new_len, head_dim]
and writes them into the pre-allocated cache at the current position.
Parameters
* state - The cache state from init/1 or a previous update/4
* layer_idx - Which transformer layer (0-indexed)
* new_k - New key tensor [batch, num_heads, new_len, head_dim]
* new_v - New value tensor [batch, num_heads, new_len, head_dim]
Returns
{updated_state, cached_k, cached_v} where cached_k/v contain all
entries up to and including the new ones (sliced from pre-allocated buffer).
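The write-at-position-then-slice behavior can be sketched as follows (a hypothetical pure-Python stand-in: one layer, and per-layer buffers are flat lists of per-token vectors rather than 4-D tensors, so buffer writes use list slicing in place of a tensor put-slice):

```python
def update(state, layer_idx, new_k, new_v):
    """Sketch of update/4: write new entries at the current position and
    return the valid slice. new_k/new_v are lists of per-token vectors
    (stand-ins for [batch, num_heads, new_len, head_dim] tensors)."""
    ks, vs = state["cache"][layer_idx]
    pos, new_len = state["position"], len(new_k)
    assert pos + new_len <= state["max_seq_len"], "cache overflow"
    # Write into the pre-allocated buffer instead of concatenating.
    ks[pos:pos + new_len] = new_k
    vs[pos:pos + new_len] = new_v
    state = {**state, "position": pos + new_len}
    # Return only the valid portion, up to and including the new entries.
    return state, ks[:pos + new_len], vs[:pos + new_len]

max_len = 8
state = {"cache": {0: ([None] * max_len, [None] * max_len)},
         "position": 0, "max_seq_len": max_len}

# Step 1: two prompt tokens written at once.
state, k, v = update(state, 0, [[1.0], [2.0]], [[10.0], [20.0]])
print(state["position"], k)  # 2 [[1.0], [2.0]]

# Step 2: one generated token appended after them.
state, k, v = update(state, 0, [[3.0]], [[30.0]])
print(state["position"], k)  # 3 [[1.0], [2.0], [3.0]]
```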