# `Edifice.Blocks.KVCache`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/blocks/kv_cache.ex#L1)

KV Cache: Inference-time Key-Value Caching for Autoregressive Decoding.

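Provides utilities for caching key and value projections during
autoregressive generation. Without caching, each decoding step re-runs
the full prefix: K/V are recomputed for all n previous tokens, and
attention over n queries and n keys costs O(n²) per step. With caching,
only the new token's K/V are computed and appended, and its single query
attends to the n cached keys, so attention costs O(n) per step.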
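To make the per-step costs concrete, the sketch below counts K/V projections and query-key attention scores when a sequence is decoded one token at a time. `DecodeCost` is a hypothetical helper, and the cost model (one projection per token, one score per query-key pair, no causal mask, no separate prompt prefill) is a deliberate simplification, not a measurement of this module.

```elixir
# Back-of-the-envelope operation counts for token-by-token decoding.
# Simplified model: one K/V projection per token, one score per (query, key)
# pair; causal masking, prompt prefill, and constant factors are ignored.
defmodule DecodeCost do
  # Step with prefix length n and no cache: recompute K/V for all n tokens
  # and score all n queries against all n keys.
  def step_without_cache(n), do: %{kv_projections: n, attention_scores: n * n}

  # Step with prefix length n and a cache: project only the new token and
  # score its single query against the n cached keys.
  def step_with_cache(n), do: %{kv_projections: 1, attention_scores: n}

  # Sum per-step costs over prefix lengths 1..total_len.
  def total(step_fun, total_len) do
    Enum.reduce(1..total_len, %{kv_projections: 0, attention_scores: 0}, fn n, acc ->
      step = step_fun.(n)

      %{
        kv_projections: acc.kv_projections + step.kv_projections,
        attention_scores: acc.attention_scores + step.attention_scores
      }
    end)
  end
end

DecodeCost.total(&DecodeCost.step_without_cache/1, 100)
# => %{attention_scores: 338350, kv_projections: 5050}

DecodeCost.total(&DecodeCost.step_with_cache/1, 100)
# => %{attention_scores: 5050, kv_projections: 100}
```

Summed over the whole sequence, the uncached totals grow as O(n³) scores and O(n²) projections, versus O(n²) and O(n) with a cache.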

## How It Works

```
Step 1: prompt "The cat"
  K_cache = [K_the, K_cat]
  V_cache = [V_the, V_cat]

Step 2: generate "sat"
  Compute K_sat, V_sat only for new token
  K_cache = [K_the, K_cat, K_sat]    (append)
  V_cache = [V_the, V_cat, V_sat]    (append)
  Attend: Q_sat × K_cache^T → weights → V_cache

Step 3: generate "on"
  K_cache = [K_the, K_cat, K_sat, K_on]
  ...
```
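The append steps in the diagram can be reproduced with plain Nx. This sketch grows the cache with `Nx.concatenate/2` purely for readability (the module itself writes into pre-allocated buffers, as described under `init/1`), and the constant K/V values are made-up stand-ins for real projections.

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Illustrative only: grow the cache by concatenation so each step mirrors the
# diagram. Head-first shapes: [batch, num_heads, seq_len, head_dim].
append = fn {k_cache, v_cache}, k_new, v_new ->
  # Axis 2 is the sequence axis in head-first layout.
  {Nx.concatenate([k_cache, k_new], axis: 2), Nx.concatenate([v_cache, v_new], axis: 2)}
end

# Hypothetical per-token K/V: batch 1, 2 heads, head_dim 4, filled with `value`.
kv_for = fn value -> Nx.broadcast(Nx.tensor(value), {1, 2, 1, 4}) end

# Step 1: prompt "The cat" -> two cached positions.
prompt = Nx.concatenate([kv_for.(1.0), kv_for.(2.0)], axis: 2)
cache = {prompt, prompt}

# Step 2: "sat" -> compute K/V for the new token only, then append.
cache = append.(cache, kv_for.(3.0), kv_for.(3.0))

# Step 3: "on" -> append again; earlier entries are reused unchanged.
{k_cache, _v_cache} = append.(cache, kv_for.(4.0), kv_for.(4.0))

Nx.shape(k_cache)
# => {1, 2, 4, 4}
```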

## Usage

    alias Edifice.Blocks.KVCache

    # Initialize an empty cache for a model
    cache = KVCache.init(batch_size: 2, num_layers: 4, num_heads: 4, head_dim: 64)

    # During the generation loop (compute_kv/1 and attend/3 stand in for
    # your model's own projection and attention code):
    {new_k, new_v} = compute_kv(token)
    {cache, full_k, full_v} = KVCache.update(cache, layer_idx, new_k, new_v)
    output = attend(q, full_k, full_v)

    # Get the current sequence length
    len = KVCache.seq_length(cache)

## Design

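The cache state is a map that holds a per-layer map
`%{layer_idx => {k_tensor, v_tensor}}` under `:cache`, together with a
shared write `:position` and the pre-allocated capacity `:max_seq_len`
(see `init/1`). New entries are written along the sequence dimension
(axis 2 for head-first layout, axis 1 for seq-first layout). This module
assumes head-first layout: `[batch, num_heads, seq_len, head_dim]`.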
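If an attention block produces projections in seq-first layout, they would need to be transposed to head-first before being handed to this module. A minimal sketch with plain Nx, reusing the Usage example's sizes (batch 2, 4 heads, `head_dim` 64) and an illustrative sequence length of 5:

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Seq-first projection output: [batch, seq_len, num_heads, head_dim] (sequence on axis 1).
seq_first = Nx.iota({2, 5, 4, 64}, type: :f32)

# Swap axes 1 and 2 to get head-first: [batch, num_heads, seq_len, head_dim]
# (sequence on axis 2), the layout this module assumes.
head_first = Nx.transpose(seq_first, axes: [0, 2, 1, 3])

Nx.shape(head_first)
# => {2, 4, 5, 64}

# The same (batch, token, head, dim) entry sits at swapped indices in the two layouts.
Nx.to_number(seq_first[[1, 3, 2, 7]]) == Nx.to_number(head_first[[1, 2, 3, 7]])
# => true
```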

# `init_opt`

```elixir
@type init_opt() ::
  {:batch_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:head_dim, pos_integer()}
  | {:max_seq_len, pos_integer()}
  | {:type, Nx.Type.t()}
```

Options for `init/1`.

# `t`

```elixir
@type t() :: %{required(non_neg_integer()) => {Nx.Tensor.t(), Nx.Tensor.t()}}
```

A KV cache: map from layer index to `{k_tensor, v_tensor}`.

# `build_cached_attention`

```elixir
@spec build_cached_attention(keyword()) :: function()
```

Build a cached attention function.

Returns a function that, given Q, K, V for the new tokens and a cache state,
computes attention using the full cached K/V history.

## Options

  - `:num_heads` - Number of attention heads (required)
  - `:head_dim` - Dimension per head (required)
  - `:layer_idx` - Layer index for cache lookup (required)

## Returns

  A function `fn(q, k, v, cache_state) -> {output, updated_cache_state}`.
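The sketch below shows what a function with that `fn(q, k, v, cache_state) -> {output, updated_cache_state}` contract could look like. It is a hypothetical stand-in, not Edifice's `build_cached_attention/1`: `CachedAttentionSketch` is an invented name, `cache_state` is assumed to be a plain `%{layer_idx => {k, v}}` map grown by concatenation, attention is standard scaled dot-product attention, and each call is assumed to carry one new token (prefilling several tokens at once would also need a causal mask).

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical sketch, not Edifice's implementation. Head-first tensors:
# [batch, num_heads, seq_len, head_dim]; one new token per call (no causal mask).
defmodule CachedAttentionSketch do
  # Returns fn q, k, v, cache_state -> {output, updated_cache_state}.
  def build(opts) do
    # Required like the documented option; the tensor shapes already carry it.
    _num_heads = Keyword.fetch!(opts, :num_heads)
    head_dim = Keyword.fetch!(opts, :head_dim)
    layer_idx = Keyword.fetch!(opts, :layer_idx)
    scale = 1.0 / :math.sqrt(head_dim)

    fn q, k, v, cache_state ->
      # Append the new K/V to this layer's history, or start the history.
      {full_k, full_v} =
        case Map.fetch(cache_state, layer_idx) do
          {:ok, {k_cache, v_cache}} ->
            {Nx.concatenate([k_cache, k], axis: 2), Nx.concatenate([v_cache, v], axis: 2)}

          :error ->
            {k, v}
        end

      # Scores [batch, heads, new_len, total_len]: contract head_dim,
      # batch over (batch, heads), then scale by 1/sqrt(head_dim).
      scores = Nx.dot(q, [3], [0, 1], full_k, [3], [0, 1]) |> Nx.multiply(scale)

      # Numerically stable softmax over the key axis.
      exp = Nx.exp(Nx.subtract(scores, Nx.reduce_max(scores, axes: [3], keep_axes: true)))
      weights = Nx.divide(exp, Nx.sum(exp, axes: [3], keep_axes: true))

      # Weighted sum of values: [batch, heads, new_len, head_dim].
      output = Nx.dot(weights, [3], [0, 1], full_v, [2], [0, 1])

      {output, Map.put(cache_state, layer_idx, {full_k, full_v})}
    end
  end
end

attn = CachedAttentionSketch.build(num_heads: 2, head_dim: 4, layer_idx: 0)
token = fn value -> Nx.broadcast(Nx.tensor(value), {1, 2, 1, 4}) end

# First call: a single cached key, so the output equals that token's value (3.0).
{out1, state} = attn.(token.(1.0), token.(1.0), token.(3.0), %{})

# Second call: both keys are identical, so the weights are uniform and the
# output is the mean of the cached values 3.0 and 5.0, i.e. 4.0.
{out2, state} = attn.(token.(1.0), token.(1.0), token.(5.0), state)
```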

# `get`

```elixir
@spec get(map(), non_neg_integer()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
```

Get the cached K/V tensors for a layer (valid portion only).

Returns `{k, v}` sliced to the current position.
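The "valid portion only" slicing can be pictured with `Nx.slice_along_axis/4`. This is a sketch, assuming the state shape documented under `init/1` (a `:cache` map of head-first buffers plus a shared `:position`); the anonymous `get` function and the buffer contents are illustrative, not the module's code.

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical state in the shape documented for init/1: buffers are
# [batch, num_heads, max_seq_len, head_dim], and only the first `position`
# entries along axis 2 hold written data.
max_seq_len = 8
k_buf = Nx.iota({1, 2, max_seq_len, 4}, type: :f32)
v_buf = Nx.iota({1, 2, max_seq_len, 4}, type: :f32)
state = %{cache: %{0 => {k_buf, v_buf}}, position: 3, max_seq_len: max_seq_len}

# A get/2-style lookup: slice the valid prefix along the sequence axis.
get = fn %{cache: cache, position: pos}, layer_idx ->
  {k, v} = Map.fetch!(cache, layer_idx)
  {Nx.slice_along_axis(k, 0, pos, axis: 2), Nx.slice_along_axis(v, 0, pos, axis: 2)}
end

{k, _v} = get.(state, 0)

Nx.shape(k)
# => {1, 2, 3, 4}
```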

# `init`

```elixir
@spec init([init_opt()]) :: map()
```

Initialize an empty KV cache.

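Creates pre-allocated zero tensors for each layer. Writing into
fixed-size buffers at a position pointer is more efficient on
accelerators than repeated concatenation, which allocates and copies a
larger tensor at every step and changes tensor shapes (which can trigger
recompilation under shape-specialized JIT compilers such as EXLA).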

## Options

  - `:batch_size` - Batch size (required)
  - `:num_layers` - Number of transformer layers (required)
  - `:num_heads` - Number of attention heads (required)
  - `:head_dim` - Dimension per head (required)
  - `:max_seq_len` - Maximum sequence length to pre-allocate (default: 2048)
  - `:type` - Numeric type (default: `:f32`)

## Returns

  A map `%{cache: %{layer => {k, v}}, position: 0, max_seq_len: n}`.
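As a rough picture of that return value, the sketch below builds the same documented structure with plain Nx. `InitSketch` is a hypothetical name, and the per-buffer shape `[batch_size, num_heads, max_seq_len, head_dim]` is an assumption derived from the head-first layout and the `:max_seq_len` option rather than taken from the module's source.

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical re-creation of the structure documented for init/1; not Edifice's code.
defmodule InitSketch do
  def init(opts) do
    batch = Keyword.fetch!(opts, :batch_size)
    layers = Keyword.fetch!(opts, :num_layers)
    heads = Keyword.fetch!(opts, :num_heads)
    head_dim = Keyword.fetch!(opts, :head_dim)
    max_seq_len = Keyword.get(opts, :max_seq_len, 2048)
    type = Keyword.get(opts, :type, :f32)

    # Assumed head-first buffer: [batch, heads, max_seq_len, head_dim].
    # Nx tensors are immutable, so every slot can start from the same zero tensor.
    zeros = Nx.broadcast(Nx.tensor(0, type: type), {batch, heads, max_seq_len, head_dim})
    cache = Map.new(0..(layers - 1), fn layer -> {layer, {zeros, zeros}} end)

    %{cache: cache, position: 0, max_seq_len: max_seq_len}
  end
end

state = InitSketch.init(batch_size: 2, num_layers: 4, num_heads: 4, head_dim: 64, max_seq_len: 128)
{k0, _v0} = state.cache[0]

Nx.shape(k0)
# => {2, 4, 128, 64}
```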

# `reset`

```elixir
@spec reset(map()) :: map()
```

Reset the cache to empty (reuse the allocated buffers).
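One way to read "reuse the allocated buffers" is sketched below: only the write position is rewound, and stale entries stay in memory but fall outside the valid prefix that lookups slice. This is an assumption about the behavior, shown on a hypothetical state in the shape documented for `init/1`; the anonymous `reset` function is illustrative only.

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical state in the shape documented for init/1, after three tokens were written.
k_buf = Nx.iota({1, 2, 8, 4}, type: :f32)
v_buf = Nx.iota({1, 2, 8, 4}, type: :f32)
state = %{cache: %{0 => {k_buf, v_buf}}, position: 3, max_seq_len: 8}

# A reset/1-style helper (assumed behavior): rewind the write position and keep
# the buffers. Old entries remain but lie outside the valid range [0, position).
reset = fn state -> %{state | position: 0} end

fresh = reset.(state)

fresh.position
# => 0

fresh.cache == state.cache
# => true
```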

# `seq_length`

```elixir
@spec seq_length(map()) :: non_neg_integer()
```

Get the current sequence length stored in the cache.

# `update`

```elixir
@spec update(map(), non_neg_integer(), Nx.Tensor.t(), Nx.Tensor.t()) ::
  {map(), Nx.Tensor.t(), Nx.Tensor.t()}
```

Append new K/V entries to the cache for a given layer.

Takes new K/V tensors of shape `[batch, num_heads, new_len, head_dim]`
and writes them into the pre-allocated cache at the current position.

## Parameters

  - `state` - The cache state from `init/1` or a previous `update/4`
  - `layer_idx` - Which transformer layer (0-indexed)
  - `new_k` - New key tensor `[batch, num_heads, new_len, head_dim]`
  - `new_v` - New value tensor `[batch, num_heads, new_len, head_dim]`

## Returns

  `{updated_state, cached_k, cached_v}`, where `cached_k` and `cached_v` contain
  all entries up to and including the new ones (sliced from the pre-allocated buffer).
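The write-then-slice behavior described above can be sketched with `Nx.put_slice/3` and `Nx.slice_along_axis/4`. `UpdateSketch` is hypothetical and not Edifice's implementation. In particular, the docs don't say when the shared `:position` advances, so this sketch assumes `update/4` leaves it alone and a separate, invented `advance/2` moves it once every layer has written the step. The explicit overflow check matters because `Nx.put_slice/3` clamps start indices instead of raising.

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical sketch of update/4's documented behavior, not Edifice's code.
defmodule UpdateSketch do
  # Write new K/V ([batch, heads, new_len, head_dim]) at the current position of
  # `layer_idx` and return the valid prefix, including the new entries.
  def update(%{cache: cache, position: pos, max_seq_len: max_len} = state, layer_idx, new_k, new_v) do
    new_len = Nx.axis_size(new_k, 2)

    # Nx.put_slice/3 would silently clamp the start index, so check capacity first.
    if pos + new_len > max_len do
      raise ArgumentError, "KV cache overflow: #{pos + new_len} > #{max_len}"
    end

    {k_buf, v_buf} = Map.fetch!(cache, layer_idx)

    # put_slice returns a tensor with the buffer's shape, so shapes stay fixed.
    k_buf = Nx.put_slice(k_buf, [0, 0, pos, 0], new_k)
    v_buf = Nx.put_slice(v_buf, [0, 0, pos, 0], new_v)

    # Slice the valid prefix along the sequence axis (axis 2).
    valid = pos + new_len
    cached_k = Nx.slice_along_axis(k_buf, 0, valid, axis: 2)
    cached_v = Nx.slice_along_axis(v_buf, 0, valid, axis: 2)

    {%{state | cache: Map.put(cache, layer_idx, {k_buf, v_buf})}, cached_k, cached_v}
  end

  # Hypothetical helper: advance the shared position after all layers are updated.
  def advance(state, new_len), do: %{state | position: state.position + new_len}
end

zeros = Nx.broadcast(Nx.tensor(0.0), {1, 2, 8, 4})
state = %{cache: %{0 => {zeros, zeros}, 1 => {zeros, zeros}}, position: 0, max_seq_len: 8}

# Prefill a 3-token prompt into both layers, then advance once.
prompt = Nx.broadcast(Nx.tensor(1.0), {1, 2, 3, 4})
{state, _k, _v} = UpdateSketch.update(state, 0, prompt, prompt)
{state, k_prompt, _v} = UpdateSketch.update(state, 1, prompt, prompt)
state = UpdateSketch.advance(state, 3)

# Decode one token into layer 0.
token = Nx.broadcast(Nx.tensor(2.0), {1, 2, 1, 4})
{state, k0, _v} = UpdateSketch.update(state, 0, token, token)

Nx.shape(k_prompt)
# => {1, 2, 3, 4}

Nx.shape(k0)
# => {1, 2, 4, 4}
```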

---

*Consult [api-reference.md](api-reference.md) for the complete listing.*
