KV Cache: The Trick That Lets LLMs Remember Without Recomputing
Large language models generate text one token at a time. At every step, the model attends to all previous tokens. Naively, this means redoing the full attention computation over the entire sequence every time a new token is generated.
KV cache fixes this inefficiency.
Instead of recomputing attention for past tokens, the model stores previously computed Key (K) and Value (V) matrices and reuses them for future tokens.
Let’s walk through a tiny example.
Toy Setup
Vocabulary:

```
["<s>", "the", "cat", "sat"]
```
Prompt:

```
"<s> the cat"
```
Goal: predict "sat"
Model setup:

```
d_model = 4
num_heads = 1
```
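To make the walkthrough runnable, here is a minimal NumPy version of this toy model. The random weights stand in for trained ones, and every name here (`E`, `embed`, the `W_*` matrices) is illustrative rather than from any real library:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab = ["<s>", "the", "cat", "sat"]
d_model = 4

# Toy embedding table and attention projection weights (random stand-ins)
E = rng.normal(size=(len(vocab), d_model))
W_Q, W_K, W_V = (rng.normal(size=(d_model, d_model)) for _ in range(3))

def embed(tokens):
    """Embedding rows for a list of tokens -> [len(tokens), d_model]."""
    return E[[vocab.index(t) for t in tokens]]
```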
Without KV Cache (Naive Approach)
Each generation step recomputes attention for all previous tokens.
Step 1: Predict "cat"
Input:

```
["<s>", "the"]
```
Compute query for the current token:

```
Q2 = embedding("the") * W_Q
```
Compute keys and values for all tokens so far:

```
K1, K2 = embedding(["<s>", "the"]) * W_K
V1, V2 = embedding(["<s>", "the"]) * W_V
```
Run attention:

```
Attention(Q2, K1:K2, V1:V2) → logits("cat")
```
Step 2: Predict "sat"
Input:

```
["<s>", "the", "cat"]
```
Compute query:

```
Q3 = embedding("cat") * W_Q
```
Recompute keys and values for all tokens again:

```
K1, K2, K3 = embedding(["<s>", "the", "cat"]) * W_K   ❌ recompute
V1, V2, V3 = embedding(["<s>", "the", "cat"]) * W_V   ❌ recompute
```
Run attention:

```
Attention(Q3, K1:K3, V1:V3) → logits("sat")
```
The model repeatedly recomputes keys and values for tokens it has already processed.
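Continuing the NumPy setup above, one naive decoding step might look like this (single head, no batching; since we only query the last position, no causal mask is needed — an illustrative sketch, not any real library's API):

```python
def naive_decode_step(tokens):
    """Attention output for the newest token WITHOUT a cache:
    K and V are rebuilt for the entire sequence on every call."""
    X = embed(tokens)                   # [n, d_model] — all tokens, every step
    Q = X[-1:] @ W_Q                    # query for the newest token only
    K = X @ W_K                         # recomputed for ALL n tokens
    V = X @ W_V                         # recomputed for ALL n tokens
    scores = (Q @ K.T) / np.sqrt(d_model)
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)   # softmax over all positions
    return weights @ V                  # [1, d_model]

out = naive_decode_step(["<s>", "the", "cat"])  # step 2: predict "sat"
```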
With KV Cache (Efficient Approach)
Instead of recomputing keys and values, we store them once and reuse them.
Step 1: Predict "cat"
Same as before.
After computing keys and values:

```
K_cache = [K1, K2]
V_cache = [V1, V2]
```
Step 2: Predict "sat"
Input:

```
["<s>", "the", "cat"]
```
Only compute new projections:

```
Q3 = embedding("cat") * W_Q
K3 = embedding("cat") * W_K
V3 = embedding("cat") * W_V
```
Append to cache:

```
K_cache = [K1, K2, K3]
V_cache = [V1, V2, V3]
```
Run attention:

```
Attention(Q3, K_cache, V_cache) → logits("sat")
```
Past tokens are never recomputed.
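The cached version, in the same toy setup (again a sketch; production implementations typically preallocate the cache rather than concatenating on every step):

```python
def cached_decode_step(token, K_cache, V_cache):
    """Attention output for the newest token WITH a cache:
    only the new token's projections are computed; past K/V are reused."""
    x = embed([token])                          # [1, d_model] — new token only
    q, k, v = x @ W_Q, x @ W_K, x @ W_V
    K_cache = np.concatenate([K_cache, k])      # append K3; K1, K2 untouched
    V_cache = np.concatenate([V_cache, v])
    scores = (q @ K_cache.T) / np.sqrt(d_model)
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)
    return weights @ V_cache, K_cache, V_cache
```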
Visualizing Cache Growth
Step 1

```
Predict "cat"
[ Q2 | K1 K2 ] [ V1 V2 ]
Cache = [K1,K2] [V1,V2]
```
Step 2

```
Predict "sat"
[ Q3 | K1 K2 K3 ] [ V1 V2 V3 ]
Cache = [K1,K2,K3] [V1,V2,V3]
```
Step 3

```
Predict "on"
[ Q4 | K1..K4 ] [ V1..V4 ]
Cache grows by one entry
```
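The same growth pattern falls out of the `cached_decode_step` sketch above: feeding tokens one at a time, each call returns a cache that is one row longer:

```python
K_cache = np.empty((0, d_model))    # start empty; here we prefill token by token
V_cache = np.empty((0, d_model))
for step, token in enumerate(["<s>", "the", "cat"], start=1):
    out, K_cache, V_cache = cached_decode_step(token, K_cache, V_cache)
    print(f"step {step}: cache holds {len(K_cache)} K/V pairs")
# step 1: cache holds 1 K/V pairs
# step 2: cache holds 2 K/V pairs
# step 3: cache holds 3 K/V pairs
```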
Time Complexity Analysis
To understand the time complexity of the attention mechanism, we need to establish two variables: $n$ (the sequence length) and $d_{model}$ (the embedding dimension or hidden state size).
The Key-Value (KV) cache matters only during the autoregressive decoding phase, when the model generates text token by token.
Here is the breakdown of the time complexity for generating the $n$-th token in a sequence.
1. Without KV Cache (Naive Generation)
If we do not use a KV cache, the model has no “memory” of the calculations it performed for previous tokens. To generate the $n$-th token, the model must re-process the entire sequence of $n$ tokens from scratch.
- Linear Projections ($Q, K, V$): Multiplying the $n \times d_{model}$ input matrix by the $d_{model} \times d_{model}$ weight matrices takes $O(n \cdot d_{model}^2)$.
- Attention Scores ($Q \cdot K^T$): Multiplying the $n \times d_{model}$ Query matrix by the $d_{model} \times n$ Key matrix takes $O(n^2 \cdot d_{model})$.
- Weighted Sum (Scores $\cdot V$): Multiplying the $n \times n$ attention matrix by the $n \times d_{model}$ Value matrix takes $O(n^2 \cdot d_{model})$.
Time Complexity for the $n$-th token: $O(n^2 \cdot d_{model} + n \cdot d_{model}^2)$
Because sequence length ($n$) is usually the scaling bottleneck, the attention complexity alone is typically expressed as $O(n^2 \cdot d_{model})$.
Note: If we are generating a full sequence of $n$ tokens without a cache, this step runs $n$ times, making the cumulative complexity a massive $O(n^3 \cdot d_{model})$.
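A quick numeric check of that cumulative claim (attention term only, constants dropped; the 1/3 is just the constant hidden by the big-O):

```python
d_model, n = 1024, 4096
total = sum(i**2 * d_model for i in range(1, n + 1))   # i-th step costs i²·d
print(total / (n**3 * d_model))                        # -> ~0.333, i.e. Θ(n³·d)
```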
2. With KV Cache (Optimized Generation)
During generation, the past tokens do not change, which means their Key ($K$) and Value ($V$) vectors do not change either. A KV cache stores these previously computed $K$ and $V$ vectors in memory.
When generating the $n$-th token, the model only needs to process the single new token (a $1 \times d_{model}$ vector):
- Linear Projections: Compute $q, k, v$ for just the 1 new token. This takes $O(d_{model}^2)$.
- Update Cache: The new $k$ and $v$ vectors are appended to the cached matrices (which now become size $n \times d_{model}$).
- Attention Scores ($q \cdot K^T$): Multiply the $1 \times d_{model}$ query vector by the $d_{model} \times n$ cached Key matrix. This takes $O(n \cdot d_{model})$.
- Weighted Sum (Scores $\cdot V$): Multiply the $1 \times n$ score vector by the $n \times d_{model}$ cached Value matrix. This takes $O(n \cdot d_{model})$.
Time Complexity for the $n$-th token: $O(n \cdot d_{model} + d_{model}^2)$
By caching the past, the attention bottleneck for generating a new token drops from quadratic to linear: $O(n \cdot d_{model})$.
Note: If we are generating a full sequence of $n$ tokens with the cache, this step runs $n$ times, making the cumulative complexity $O(n^2 \cdot d_{model})$.
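And the analogous check with the cache (projection term included this time, since it is not dominated at small $n$):

```python
d_model, n = 1024, 4096
total = sum(i * d_model + d_model**2 for i in range(1, n + 1))
print(total / (n**2 * d_model))   # -> ~0.75 here; tends to 0.5 as n grows, i.e. Θ(n²·d)
```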
Summary Comparison
Here is how the time complexity compares when generating the $n$-th token:
| Generation Phase | Attention Complexity | Total Complexity (incl. Projections) |
|---|---|---|
| Without KV Cache | $O(n^2 \cdot d_{model})$ | $O(n^2 \cdot d_{model} + n \cdot d_{model}^2)$ |
| With KV Cache | $O(n \cdot d_{model})$ | $O(n \cdot d_{model} + d_{model}^2)$ |
The KV cache trades memory (space complexity) for speed (time complexity).
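A back-of-envelope calculator makes the table concrete (operation counts with the big-O constants dropped; with these terms the per-step ratio works out to exactly $n$):

```python
def per_token_ops(n, d_model, cached):
    """Rough operation count for generating the n-th token."""
    if cached:
        return n * d_model + d_model**2        # O(n·d_model + d_model²)
    return n**2 * d_model + n * d_model**2     # O(n²·d_model + n·d_model²)

d = 1024
for n in (512, 2048, 8192):
    ratio = per_token_ops(n, d, cached=False) / per_token_ops(n, d, cached=True)
    print(f"n={n}: a naive step costs ~{ratio:.0f}x a cached step")
```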
KV Cache Memory Footprint on Gemma-3 270M Model
The cache consumes memory proportional to sequence length.
Approximate memory:

```
2 × dtype_size × num_layers × h_kv × d_h × N
```
Where:
- The leading 2 accounts for storing both the Key and Value matrices
- dtype_size: bytes per cached value (e.g., 2 for FP16)
- num_layers: number of hidden layers
- $h_{kv}$: number of KV heads
- $d_h$: head dimension
- N: sequence length
Example:
- Model: Gemma-3-270M
- Sequence length = 2048
- We assume the model runs in 16-bit precision (FP16 or bfloat16), so each cached value takes up 2 bytes of memory.
- Number of hidden layers ($l$): 18
- Number of KV heads ($h_{kv}$): 1 (It uses Multi-Query Attention)
- Head dimension ($d_h$): 256
KV-cache size = 2 × 2 × 18 × 1 × 256 × 2048 bytes ≈ 36 MB
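The same arithmetic as a small helper (the Gemma-3-270M numbers are taken from the example above; the function name is ours):

```python
def kv_cache_bytes(dtype_size, num_layers, h_kv, d_h, seq_len):
    # The leading 2 accounts for storing both K and V
    return 2 * dtype_size * num_layers * h_kv * d_h * seq_len

size = kv_cache_bytes(dtype_size=2, num_layers=18, h_kv=1, d_h=256, seq_len=2048)
print(f"{size / 2**20:.0f} MiB")   # -> 36 MiB
```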
This extra memory dramatically reduces compute.
Key Insight
KV cache stores the Key and Value projections of past tokens.
Each new token only computes:
- Its own query
- Its own key/value
- Attention against cached history
This avoids recomputing attention for previous tokens and enables fast autoregressive generation.