Vishal Pandey | Applied ML Research

KV Caching

A small note on KV caching. In short, it is one of the most powerful tricks for making LLM inference faster.

Let's dive into it...



1. Where the problem comes from

At inference (i.e., autoregressive generation), you generate tokens one at a time:

Step 1: Input prompt → predict next token.

Step 2: Append predicted token → feed full sequence again → predict next.

Step 3: Repeat…

Naïvely, at each step you’d recompute attention across the entire prefix plus all past tokens. That’s O(T²) of work per step as the sequence grows, which is way too slow.
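
To see where the waste comes from, here is a minimal sketch of the naive loop. The `toy_forward` function is a hypothetical stand-in for a transformer forward pass (it just returns random logits), so the names and shapes are illustrative, not any particular library’s API:

```python
# Naive autoregressive decoding: the full sequence is re-fed at every step,
# so step t re-runs attention over all t tokens even though t-1 of them
# were already processed at the previous step.
import torch

torch.manual_seed(0)
VOCAB_SIZE = 100

def toy_forward(token_ids: torch.Tensor) -> torch.Tensor:
    """Stand-in for a transformer forward pass: logits for every position.
    (Hypothetical model; in a real LM this is where the O(T^2) attention happens.)"""
    return torch.randn(token_ids.shape[0], VOCAB_SIZE)

def generate_naive(prompt_ids: list[int], num_new_tokens: int) -> list[int]:
    ids = list(prompt_ids)
    for _ in range(num_new_tokens):
        logits = toy_forward(torch.tensor(ids))  # re-processes the whole prefix
        next_id = int(logits[-1].argmax())       # greedy pick from the last position
        ids.append(next_id)
    return ids

print(generate_naive([1, 2, 3], num_new_tokens=5))
```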


2. The Trick: Cache K, V Once

Attention formula:

$$\text{Attn}(Q_t, K_{1:t}, V_{1:t}) = \text{softmax}\!\left(\frac{Q_t K_{1:t}^{\top}}{\sqrt{d_k}}\right) V_{1:t}$$

Notice:

Keys K_{1:t} and values V_{1:t} come from all previous tokens (including the prefix).

At time t+1, you don’t need to recompute K_{1:t}, V_{1:t}. You already had them at step t.

You only need to compute K_{t+1}, V_{t+1} for the new token, then append them.

So each step just does this: compute Q_{t+1}, K_{t+1}, V_{t+1} for the single new token, append K_{t+1}, V_{t+1} to the cache, and attend with Q_{t+1} over the cached K_{1:t+1}, V_{1:t+1}, instead of reprocessing the whole history. A sketch of one such step is below.
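
Here is a minimal single-head sketch of one cached decoding step, assuming toy projection matrices `W_q`, `W_k`, `W_v` (illustrative names, not from any specific model):

```python
# One cached decoding step for a single attention head: only the new token's
# q, k, v are computed; K and V for the prefix come straight from the cache.
import torch
import torch.nn.functional as F

d_k = 64
W_q = torch.randn(d_k, d_k)  # toy projection weights (illustrative only)
W_k = torch.randn(d_k, d_k)
W_v = torch.randn(d_k, d_k)

def cached_step(x_new, K_cache, V_cache):
    """x_new: (1, d_k) hidden state of the newest token.
    K_cache, V_cache: (t, d_k) keys/values of all previous tokens."""
    q = x_new @ W_q                        # (1, d_k)
    k = x_new @ W_k                        # (1, d_k)
    v = x_new @ W_v                        # (1, d_k)
    K = torch.cat([K_cache, k], dim=0)     # append, never recompute
    V = torch.cat([V_cache, v], dim=0)
    scores = (q @ K.T) / (d_k ** 0.5)      # (1, t+1)
    out = F.softmax(scores, dim=-1) @ V    # (1, d_k)
    return out, K, V                       # updated cache is returned

# usage: start with an empty cache and feed tokens one at a time
K_cache = torch.empty(0, d_k)
V_cache = torch.empty(0, d_k)
for step in range(4):
    x_new = torch.randn(1, d_k)
    out, K_cache, V_cache = cached_step(x_new, K_cache, V_cache)
print(K_cache.shape)  # torch.Size([4, 64])
```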


3. What’s Actually Cached?

For every layer:

Store the K and V matrices for all tokens processed so far (prompt + generated).

In practice, they’re stored in shape:

- (batch_size, num_heads, seq_len, d_k)

During generation, each new token’s K and V are appended along the seq_len dimension, so every layer’s cache grows by one position per step.
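
A minimal sketch of that layout, with made-up sizes (real implementations often preallocate the full max sequence length up front instead of concatenating, so treat this purely as an illustration of the shapes):

```python
# One (K, V) pair per layer, each of shape (batch_size, num_heads, seq_len, d_k),
# grown along the seq_len axis at every generation step.
import torch

num_layers, batch_size, num_heads, d_k = 2, 1, 4, 64

# start with an empty cache for every layer
kv_cache = [
    (torch.empty(batch_size, num_heads, 0, d_k),   # K
     torch.empty(batch_size, num_heads, 0, d_k))   # V
    for _ in range(num_layers)
]

def append_to_cache(layer_idx, k_new, v_new):
    """k_new, v_new: (batch_size, num_heads, 1, d_k) for the single new token."""
    K, V = kv_cache[layer_idx]
    kv_cache[layer_idx] = (torch.cat([K, k_new], dim=2),   # dim=2 is seq_len
                           torch.cat([V, v_new], dim=2))

# simulate generating 3 tokens
for _ in range(3):
    for layer in range(num_layers):
        k_new = torch.randn(batch_size, num_heads, 1, d_k)
        v_new = torch.randn(batch_size, num_heads, 1, d_k)
        append_to_cache(layer, k_new, v_new)

print(kv_cache[0][0].shape)  # torch.Size([1, 4, 3, 64])
```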


4. Why This Helps

With the cache, each step computes Q, K, V only for the single new token and attends over the stored keys and values, so the per-step cost drops from O(T²) to O(T). This is why large LMs (like GPT) can generate long sequences in real time.
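
To make the saving concrete, here is a rough back-of-the-envelope count for generating a sequence of length T (an estimate that ignores constant factors, the number of heads, and the MLP blocks):

$$\text{no cache: } \sum_{t=1}^{T} O(t^{2}) = O(T^{3}) \qquad\qquad \text{with cache: } \sum_{t=1}^{T} O(t) = O(T^{2})$$

Per step, the cache turns O(t²) of recomputed attention into O(t) of work against stored keys and values.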


5. Prefix-Tuning Twist

In Prefix-Tuning, the prefix KV pairs (P_K, P_V) are constant across steps (they don’t change with new tokens).

You can pre-compute them once per task and store them in the cache at the start.

At inference, they’re concatenated with the cached KV from the real tokens: K = [P_K ; K_{1:t}] and V = [P_V ; V_{1:t}].

That means every query token can always attend to both prefix memory and past tokens without recomputing either.

So, KV caching + prefix tuning = super efficient adaptation: the task-specific prefix is computed once, stored in the cache, and reused for free at every generation step.
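
A minimal sketch of that combination, assuming the prefix KV tensors (`P_K`, `P_V`) already exist as trained parameters; names and shapes are illustrative, not any specific library’s API:

```python
# Prefix-tuning on top of the KV cache: the learned prefix keys/values (P_K, P_V)
# are computed once and simply sit at the front of the cache.
import torch

batch_size, num_heads, prefix_len, d_k = 1, 4, 10, 64

# learned, task-specific prefix KV (in practice these come from trained parameters)
P_K = torch.randn(batch_size, num_heads, prefix_len, d_k)
P_V = torch.randn(batch_size, num_heads, prefix_len, d_k)

# seed the cache with the prefix once, before any real tokens are processed
K_cache, V_cache = P_K.clone(), P_V.clone()

# every subsequent step appends the real token's k, v behind the prefix
for _ in range(3):
    k_new = torch.randn(batch_size, num_heads, 1, d_k)
    v_new = torch.randn(batch_size, num_heads, 1, d_k)
    K_cache = torch.cat([K_cache, k_new], dim=2)
    V_cache = torch.cat([V_cache, v_new], dim=2)

# queries now attend over [prefix ; past tokens] with no extra recomputation
print(K_cache.shape)  # torch.Size([1, 4, 13, 64])
```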


6. KV Caching in a Nutshell

That’s KV caching in a nutshell: compute each token’s K and V once, store them, and reuse them at every later step, so generation cost grows linearly with the sequence instead of quadratically.