
KV Cache: The Trick That Lets LLMs Remember Without Recomputing

Large language models generate text one token at a time. At every step, the model attends to all previous tokens. Done naively, this would mean re-running attention over the entire sequence every time a new token is generated.

The KV cache fixes this inefficiency.

Instead of recomputing attention for past tokens, the model stores previously computed Key (K) and Value (V) matrices and reuses them for future tokens.

Let’s walk through a tiny example.

Toy Setup

Vocabulary:

["<s>", "the", "cat", "sat"]

Prompt:

"<s> the cat"

Goal: predict "sat"

Model setup:

d_model = 4
num_heads = 1

Without KV Cache (Naive Approach)

Each generation step recomputes attention for all previous tokens.

Step 1: Predict "cat"

Input:

["<s>", "the"]

Compute query for the current token:

Q2 = embedding("the") * W_Q

Compute keys and values for all tokens so far:

K1,K2 = embedding(["<s>","the"]) * W_K
V1,V2 = embedding(["<s>","the"]) * W_V

Run attention:

Attention(Q2, K1:K2, V1:V2) → logits("cat")

Step 2: Predict "sat"

Input:

["<s>", "the", "cat"]

Compute query:

Q3 = embedding("cat") * W_Q

Recompute keys and values for all tokens again:

K1,K2,K3 = embedding(["<s>","the","cat"]) * W_K   ❌ recompute
V1,V2,V3 = embedding(["<s>","the","cat"]) * W_V   ❌ recompute

Run attention:

Attention(Q3, K1:K3, V1:V3) → logits("sat")

The model repeatedly recomputes attention for earlier tokens.
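The naive loop can be sketched in NumPy. The random embeddings and weight matrices below are illustrative stand-ins for a trained model, and `attend` / `naive_step` are hypothetical helper names, not real library calls:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 4

# Toy vocabulary with random embeddings and projection weights (not a trained model)
vocab = ["<s>", "the", "cat", "sat"]
emb = {tok: rng.standard_normal(d_model) for tok in vocab}
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_model)) for _ in range(3))

def attend(q, K, V):
    """Single-head scaled dot-product attention for one query vector."""
    scores = K @ q / np.sqrt(d_model)      # (n,)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()               # softmax
    return weights @ V                     # (d_model,)

def naive_step(tokens):
    """Recompute K and V for the WHOLE sequence at every generation step."""
    X = np.stack([emb[t] for t in tokens])  # (n, d_model)
    q = emb[tokens[-1]] @ W_Q               # query for the current token only
    K = X @ W_K                             # recomputed on every call
    V = X @ W_V                             # recomputed on every call
    return attend(q, K, V)

out = naive_step(["<s>", "the", "cat"])
print(out.shape)  # (4,)
```

Note that `K = X @ W_K` and `V = X @ W_V` run over the full sequence on every call; that redundant work is exactly what the cache removes.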

With KV Cache (Efficient Approach)

Instead of recomputing keys and values, we store them once and reuse them.

Step 1: Predict "cat"

Same as before.

After computing keys and values:

K_cache = [K1, K2]
V_cache = [V1, V2]

Step 2: Predict "sat"

Input:

["<s>", "the", "cat"]

Only compute new projections:

Q3 = embedding("cat") * W_Q
K3 = embedding("cat") * W_K
V3 = embedding("cat") * W_V

Append to cache:

K_cache = [K1, K2, K3]
V_cache = [V1, V2, V3]

Run attention:

Attention(Q3, K_cache, V_cache) → logits("sat")

Past tokens are never recomputed.
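Under the same toy assumptions (random weights standing in for a trained model, hypothetical helper name `cached_step`), the cached version projects only the newest token and appends to the cache:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 4
vocab = ["<s>", "the", "cat", "sat"]
emb = {tok: rng.standard_normal(d_model) for tok in vocab}
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_model)) for _ in range(3))

K_cache, V_cache = [], []

def cached_step(token):
    """Project only the new token; reuse cached K/V for all past tokens."""
    x = emb[token]
    q = x @ W_Q
    K_cache.append(x @ W_K)        # one new key, appended once
    V_cache.append(x @ W_V)        # one new value, appended once
    K = np.stack(K_cache)          # (n, d_model) — past entries reused as-is
    V = np.stack(V_cache)
    scores = K @ q / np.sqrt(d_model)
    w = np.exp(scores - scores.max())
    w /= w.sum()                   # softmax
    return w @ V

for tok in ["<s>", "the", "cat"]:
    out = cached_step(tok)
print(len(K_cache))  # 3 — one cached K per token seen
```

Each call does a single row of projection work regardless of how long the sequence has grown; the cache lists are the only state carried between steps.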

Visualizing Cache Growth

Step 1

Predict "cat"

[ Q2 | K1 K2 ] [ V1 V2 ]
Cache = [K1,K2] [V1,V2]

Step 2

Predict "sat"

[ Q3 | K1 K2 K3 ] [ V1 V2 V3 ]
Cache = [K1,K2,K3] [V1,V2,V3]

Step 3

Predict "on"

[ Q4 | K1..K4 ] [ V1..V4 ]
Cache grows by one entry
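The growth pattern above can be sketched with placeholder strings standing in for the real key/value vectors:

```python
# One (K, V) pair is appended per generated token — the cache grows linearly.
tokens = ["<s>", "the", "cat", "sat", "on"]
K_cache, V_cache = [], []
for step, tok in enumerate(tokens, start=1):
    K_cache.append(f"K{step}")  # placeholder for the real key vector
    V_cache.append(f"V{step}")  # placeholder for the real value vector
    print(f"step {step}: cache holds {len(K_cache)} K/V pairs")
```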

Time Complexity Analysis

To understand the time complexity of the attention mechanism, we need to establish two variables: $n$ (the sequence length) and $d_{model}$ (the embedding dimension or hidden state size).

The impact of the Key-Value (KV) cache only applies to the autoregressive decoding phase (when the model is generating text token-by-token).

Here is the breakdown of the time complexity for generating the $n$-th token in a sequence.

1. Without KV Cache (Naive Generation)

If we do not use a KV cache, the model has no “memory” of the calculations it performed for previous tokens. To generate the $n$-th token, the model must re-process the entire sequence of $n$ tokens from scratch.

  • Linear Projections ($Q, K, V$): Multiplying the $n \times d_{model}$ input matrix by the $d_{model} \times d_{model}$ weight matrices takes $O(n \cdot d_{model}^2)$.
  • Attention Scores ($Q \cdot K^T$): Multiplying the $n \times d_{model}$ Query matrix by the $d_{model} \times n$ Key matrix takes $O(n^2 \cdot d_{model})$.
  • Weighted Sum (Scores $\cdot V$): Multiplying the $n \times n$ attention matrix by the $n \times d_{model}$ Value matrix takes $O(n^2 \cdot d_{model})$.

Time Complexity for the $n$-th token: $O(n^2 \cdot d_{model} + n \cdot d_{model}^2)$

Because sequence length ($n$) is usually the scaling bottleneck, the attention complexity alone is typically expressed as $O(n^2 \cdot d_{model})$.

Note: If we are generating a full sequence of $n$ tokens without a cache, this step runs $n$ times, making the cumulative complexity a massive $O(n^3 \cdot d_{model})$.

2. With KV Cache (Optimized Generation)

During generation, the past tokens do not change, which means their Key ($K$) and Value ($V$) vectors do not change either. A KV cache stores these previously computed $K$ and $V$ vectors in memory.

When generating the $n$-th token, the model only needs to process the single new token (a $1 \times d_{model}$ vector):

  • Linear Projections: Compute $q, k, v$ for just the 1 new token. This takes $O(d_{model}^2)$.
  • Update Cache: The new $k$ and $v$ vectors are appended to the cached matrices (which now become size $n \times d_{model}$).
  • Attention Scores ($q \cdot K^T$): Multiply the $1 \times d_{model}$ query vector by the $d_{model} \times n$ cached Key matrix. This takes $O(n \cdot d_{model})$.
  • Weighted Sum (Scores $\cdot V$): Multiply the $1 \times n$ score vector by the $n \times d_{model}$ cached Value matrix. This takes $O(n \cdot d_{model})$.

Time Complexity for the $n$-th token: $O(n \cdot d_{model} + d_{model}^2)$

By caching the past, the attention bottleneck for generating a new token drops from quadratic to linear: $O(n \cdot d_{model})$.

Note: If we are generating a full sequence of $n$ tokens with a cache, this step runs $n$ times, making the cumulative complexity $O(n^2 \cdot d_{model})$.

Summary Comparison

Here is how the time complexity compares when generating the $n$-th token:

| Generation Phase | Attention Complexity | Total Complexity (incl. Projections) |
| --- | --- | --- |
| Without KV Cache | $O(n^2 \cdot d_{model})$ | $O(n^2 \cdot d_{model} + n \cdot d_{model}^2)$ |
| With KV Cache | $O(n \cdot d_{model})$ | $O(n \cdot d_{model} + d_{model}^2)$ |

The KV cache trades memory (space complexity) for speed (time complexity).
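As a sanity check on the per-token formulas, here is a rough FLOP estimate counting only the matmul terms from the breakdowns above; `d = 640` is an illustrative hidden size and the function names are hypothetical:

```python
# Rough per-token FLOP counts, following the complexity terms above
# (constants dropped; only the projection and attention matmuls are counted).

def flops_naive(n, d):
    # projections for all n tokens + attention scores + weighted sum
    return n * d**2 + 2 * (n**2 * d)

def flops_cached(n, d):
    # projections for 1 token + scores/weighted sum against n cached entries
    return d**2 + 2 * (n * d)

n, d = 2048, 640  # illustrative sequence length and hidden size
print(flops_naive(n, d) / flops_cached(n, d))  # → 2048.0
```

With these terms the ratio is exactly $n$, since $n \cdot d^2 + 2n^2 d = n(d^2 + 2nd)$: the naive step does $n$ times the work of the cached step.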

KV Cache Memory Footprint on Gemma-3 270M Model

The cache consumes memory proportional to sequence length.

Approximate memory:

2 × dtype_size × num_layers × h_kv × d_h × N

Where:

  • The leading 2 accounts for both the Key and Value matrices
  • dtype_size: bytes per cached element (e.g., 2 for 16-bit precision)
  • num_layers: number of hidden layers
  • $h_{kv}$: number of KV heads
  • $d_h$: head dimension
  • N: sequence length

Example:

  • Model: Gemma-3-270M
  • Sequence length = 2048
  • Precision: 16-bit (FP16 or bfloat16), so each cached value takes 2 bytes
  • Number of hidden layers ($l$): 18
  • Number of KV heads ($h_{kv}$): 1 (it uses Multi-Query Attention)
  • Head dimension ($d_h$): 256

KV-cache size = 2 × 2 × 18 × 1 × 256 × 2048 bytes ≈ 36 MB
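The arithmetic above can be checked with a small helper (`kv_cache_bytes` is a hypothetical function name, not part of any library):

```python
def kv_cache_bytes(num_layers, h_kv, d_h, seq_len, dtype_size=2):
    """KV-cache size: 2 × dtype_size × num_layers × h_kv × d_h × N bytes."""
    return 2 * dtype_size * num_layers * h_kv * d_h * seq_len

# Gemma-3-270M figures from the example above
size = kv_cache_bytes(num_layers=18, h_kv=1, d_h=256, seq_len=2048)
print(size, size / 2**20)  # 37748736 36.0 — i.e., ~36 MB
```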

This extra memory dramatically reduces compute.

Key Insight

KV cache stores the Key and Value projections of past tokens.

Each new token only computes:

  • Its own query
  • Its own key/value
  • Attention against cached history

This avoids recomputing attention for previous tokens and enables fast autoregressive generation.

This post is licensed under CC BY 4.0 by the author.