Vishal Pandey | Applied ML Research

Prefix Tuning

This paper was introduced by Xiang Lisa Li & Percy Liang from Stanford University.

Paper Link: Prefix-Tuning: Optimizing Continuous Prompts for Generation

Let's break it down.


Points from Abstract:

[Figure: Prefix-Tuning architecture overview]


1. What problem are we solving?

You’ve got a big, pretrained language model f_θ. You want it to perform a new task (summarize, translate, classify...) without updating all of θ (which is expensive, slow, and can cause forgetting).

Prefix-Tuning says:
→ Keep θ frozen.
→ Learn a tiny set of task-specific parameters ϕ, called a prefix, that "steer" the model.

Think of the prefix as a tiny “task memory” that the model reads before it reads the actual input.


2. First principles: Conditioning a language model

A decoder-only LM predicts tokens by:

$$p_\theta(y_1, \dots, y_T \mid x) = \prod_{t=1}^{T} p_\theta(y_t \mid x, y_{<t})$$

Full fine-tuning: Change θ

Prefix-Tuning: Keep θ fixed, and introduce auxiliary variables ϕ that modify computation:

$$p_{\theta,\phi}(y_t \mid x, y_{<t}) = p_\theta\big(y_t \mid x, y_{<t}; \mathrm{prefix}(\phi)\big)$$

We optimize ϕ to maximize likelihood on the task; θ stays frozen.
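To make the objective concrete, here is a tiny sketch (made-up shapes, random tensors standing in for the frozen LM's outputs) of how the sequence log-likelihood decomposes into per-token terms:

import torch
import torch.nn.functional as F

# toy stand-ins: a vocabulary of 100 tokens and a 6-token target y
vocab_size, T = 100, 6
logits = torch.randn(1, T, vocab_size)        # pretend these are the frozen LM's next-token logits
y = torch.randint(0, vocab_size, (1, T))      # target tokens y_1 ... y_T

log_probs = F.log_softmax(logits, dim=-1)
token_ll = log_probs.gather(-1, y.unsqueeze(-1)).squeeze(-1)   # log p(y_t | x, y_<t)
sequence_ll = token_ll.sum(dim=-1)            # log p(y_1, ..., y_T | x): the sum over t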


3. Where does the prefix live? (Inside attention)

Recall a single self-attention head:

$$\mathrm{Attn}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

Normally, K,V come from the past tokens (causal mask).

Prefix-Tuning augments every layer’s attention with m extra key–value pairs that don't correspond to actual tokens.

For layer ℓ:

$$K^\ast = [P_K; K], \qquad V^\ast = [P_V; V]$$

Where:

$$P_K, P_V \in \mathbb{R}^{m \times d_k}$$

are learned and task-specific (parameters ϕ).
The queries Q come from your actual input as usual.

So attention becomes:

$$\mathrm{Attn}(Q, K^\ast, V^\ast) = \mathrm{softmax}\!\left(\frac{Q\,[P_K; K]^\top}{\sqrt{d_k}}\right)[P_V; V]$$

Interpretation:
Before looking at the real context, each query attends to learned “memory slots” that encode task instructions or biases.
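As a minimal sketch (the function name and shapes are mine, not the paper's code), a single attention head with a prepended prefix looks like this:

import torch
import torch.nn.functional as F

def prefix_attention(Q, K, V, P_K, P_V):
    # Q: [batch, n_q, d_k], K/V: [batch, n_kv, d_k] come from the real input;
    # P_K/P_V: [m, d_k] are the learned, task-specific prefix (parameters ϕ)
    batch = Q.size(0)
    K_star = torch.cat([P_K.unsqueeze(0).expand(batch, -1, -1), K], dim=1)
    V_star = torch.cat([P_V.unsqueeze(0).expand(batch, -1, -1), V], dim=1)
    logits = Q @ K_star.transpose(-2, -1) / K.size(-1) ** 0.5
    weights = F.softmax(logits, dim=-1)      # attends over prefix slots + real tokens
    return weights @ V_star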


4. How are prefixes parameterized?

Naively, you could learn P_K, P_V directly.

Instead, the paper proposes a reparameterization:

Start from m virtual token embeddings:

$$E \in \mathbb{R}^{m \times d_{\mathrm{model}}}$$

Pass them through a small 2-layer MLP to generate per-layer key-values:

$$[P_K, P_V] = \mathrm{MLP}(E)$$

This keeps parameter count small and stable.
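A rough sketch of that reparameterization, assuming one shared 2-layer MLP that maps the m virtual-token embeddings to keys and values for every layer (the hidden size and tanh are my choices, not prescribed by the text above):

import torch
import torch.nn as nn

class PrefixReparam(nn.Module):
    def __init__(self, m, d_model, num_layers, d_k, hidden=512):
        super().__init__()
        self.E = nn.Parameter(torch.randn(m, d_model))     # m virtual-token embeddings
        self.mlp = nn.Sequential(                          # small 2-layer MLP
            nn.Linear(d_model, hidden),
            nn.Tanh(),
            nn.Linear(hidden, num_layers * 2 * d_k),       # a (key, value) pair per layer
        )
        self.num_layers, self.d_k = num_layers, d_k

    def forward(self):
        out = self.mlp(self.E)                             # [m, num_layers * 2 * d_k]
        out = out.view(-1, self.num_layers, 2, self.d_k)
        P_K = out[:, :, 0, :].permute(1, 0, 2)             # [num_layers, m, d_k]
        P_V = out[:, :, 1, :].permute(1, 0, 2)             # [num_layers, m, d_k]
        return P_K, P_V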

In encoder–decoder models (like T5), you add prefixes to both the encoder and the decoder, so the prefix can steer how the input is encoded as well as how the output is generated.


5. Training objective (simple)

Freeze θ. Optimize ϕ using maximum likelihood:

$$\max_\phi \; \sum_{(x, y) \in \mathcal{D}} \; \sum_{t=1}^{T} \log p_\theta\big(y_t \mid x, y_{<t}; \mathrm{prefix}_\phi\big)$$

Backpropagation flows only into ϕ.
→ That’s why Prefix-Tuning is parameter-efficient and fast to adapt.
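Sketching the optimization setup: `model`, `dataloader`, and the `prefix=` keyword are hypothetical placeholders here; how the prefix actually reaches each layer depends on your model code (e.g. the attention function sketched earlier).

import torch
import torch.nn.functional as F

for p in model.parameters():               # freeze every pretrained weight θ
    p.requires_grad = False

prefix_module = PrefixReparam(m=10, d_model=1024, num_layers=24, d_k=64)
optimizer = torch.optim.AdamW(prefix_module.parameters(), lr=5e-5)   # only ϕ is optimized

for x, y in dataloader:                    # task-specific (input, target) batches
    P_K, P_V = prefix_module()
    logits = model(x, prefix=(P_K, P_V))   # hypothetical hook that feeds the prefix to each layer
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
    loss.backward()                        # gradients reach only prefix_module
    optimizer.step()
    optimizer.zero_grad()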


6. Why does this work? (The math intuition)

Look at the attention logits with concatenated keys:

$$\frac{QK^\top}{\sqrt{d_k}} \;\longrightarrow\; \frac{Q\,[P_K; K]^\top}{\sqrt{d_k}} = \left[\, \frac{Q P_K^\top}{\sqrt{d_k}} \;\Big|\; \frac{QK^\top}{\sqrt{d_k}} \,\right]$$

The first block:

$$\frac{Q P_K^\top}{\sqrt{d_k}}$$

adds learned, content-dependent biases to the attention distribution.

After softmax, the model can shift some probability mass toward the prefix slots and mix in their values P_V.

Result: You can shape what each layer “pays attention to,” effectively steering the network’s computation path without editing its weights.
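A quick numeric check of that block structure (arbitrary small shapes): concatenating prefix keys just prepends m extra logit columns.

import torch

d_k, n, m = 16, 5, 3
Q, K = torch.randn(1, n, d_k), torch.randn(1, n, d_k)
P_K = torch.randn(1, m, d_k)

full = Q @ torch.cat([P_K, K], dim=1).transpose(-2, -1) / d_k ** 0.5
blocks = torch.cat([Q @ P_K.transpose(-2, -1), Q @ K.transpose(-2, -1)], dim=-1) / d_k ** 0.5
assert torch.allclose(full, blocks)      # [Q P_K^T | Q K^T] / sqrt(d_k), exactly as above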

Alternate view: a learned basis

Each layer's attention output is a convex combination of the value rows V (the softmax weights are non-negative and sum to 1).

By augmenting V with learned P_V, you expand the span of representable outputs, like adding a small, learnable basis the model can project onto for this task.


7. How is this different from other PEFT methods?

PEFT: Parameter-Efficient Fine-tuning

Prompt Tuning (soft prompts only at input): learns continuous "virtual token" embeddings that are prepended only at the input embedding layer; deeper layers are steered only indirectly.

Adapters: insert small bottleneck MLP modules inside every transformer layer; the pretrained weights stay frozen, but the forward pass gains new modules.

LoRA: adds trainable low-rank updates to frozen weight matrices (typically the attention projections), changing the effective weights.

Tradeoff:
Prefixes inject no new weights into the model itself, just extra key-value pairs per layer, which makes them simple to serve and to compose across tasks.


8. Quick Recap (Mental Model)

→ Freeze the pretrained θ; learn only a small task-specific prefix ϕ.
→ The prefix starts as m virtual-token embeddings, passed through a small MLP to produce per-layer P_K, P_V.
→ Every layer's attention sees [P_K; K] and [P_V; V], so queries can read task "memory slots" before the real context.
→ Training maximizes task likelihood; gradients flow only into ϕ.


9. Pytorch-ish Pseudocode

# given: a frozen model and a learnable prefix_module (the only trainable parameters ϕ)
Kp, Vp = prefix_module()                  # per-layer prefix keys/values, shapes [L, m, d_k]
for l in layers:
    Q_l, K_l, V_l = frozen_qkv(l)         # projections of the real input at layer l (frozen θ)
    K_star = cat([Kp[l], K_l], dim=1)     # prepend prefix keys
    V_star = cat([Vp[l], V_l], dim=1)     # prepend prefix values
    out = attn(Q_l, K_star, V_star)       # rest of the layer unchanged
loss = cross_entropy(outputs, targets)    # backprop ONLY into prefix_module

10. Brain-Teaser

If we only added learnable keys P_K, but not values P_V, what behavior would you expect during attention, and why?

Queries Q can still produce attention logits against those prefix keys: the logits gain the block QP_K^⊤/√d_k. So the attention distribution can still shift probability mass toward the prefix slots.

But since the corresponding values P_V are missing (or zero), attending to them doesn't inject any new information. It's like pointing to an empty memory slot.

That means attention is wasted on meaningless content, and the model can't effectively adapt to the task.

The real issue isn’t "no bias", it’s "no useful content."
Attention requires both:
→ keys, to decide where to look, and
→ values, to decide what gets retrieved.

The refined answer: keys alone create attention positions, but without values those positions hold nothing, so the model can't change its behavior.

A. Normal Attention (No Prefix)

The attention mechanism:

$$\mathrm{Attn}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

With Q, K, V all computed from the actual input tokens:

Each row of QK^⊤/√d_k gives attention logits → softmax → weights used to average the rows of V.

So: each output position is a weighted average of information drawn only from the real context.

B. With Full Prefix (P_K, P_V)

Augment keys and values:

$$K^\ast = [P_K; K], \qquad V^\ast = [P_V; V]$$

Now attention becomes:

$$\mathrm{Attn}(Q, K^\ast, V^\ast) = \mathrm{softmax}\!\left(\frac{Q\,[P_K; K]^\top}{\sqrt{d_k}}\right)[P_V; V]$$

Queries can now spread attention across both the prefix slots (keys P_K) and the real tokens (keys K).

If attention shifts to P_K, the model retrieves the corresponding P_V, which is learnable → task-specific information enters the computation.

C. With Only Prefix Keys P_K (No Prefix Values)

We still augment keys:

$$K^\ast = [P_K; K]$$

But now the values are:

$$V^\ast = [\mathbf{0}; V]$$

So attention becomes:

$$\mathrm{Attn}(Q, K^\ast, V^\ast) = \mathrm{softmax}\!\left(\frac{Q\,[P_K; K]^\top}{\sqrt{d_k}}\right)[\mathbf{0}; V]$$

D. Why This Collapses

The logits term:

$$\frac{Q\,[P_K; K]^\top}{\sqrt{d_k}} = \left[\, \frac{Q P_K^\top}{\sqrt{d_k}} \;\Big|\; \frac{QK^\top}{\sqrt{d_k}} \,\right]$$

So:

$$\mathrm{Attn}(Q, K^\ast, V^\ast) = \underbrace{\alpha_{\mathrm{prefix}} \cdot \mathbf{0}}_{\text{wasted}} + \alpha_{\mathrm{real}} \cdot V = \alpha_{\mathrm{real}} \cdot V$$

Net effect: attending to the prefix only drains probability mass from the real tokens without contributing any content, so a keys-only prefix can only dilute the real attention output.

What happens during training?
The model learns to avoid the prefix keys entirely → their logits get pushed toward −∞, and the prefix is ignored.

Intuition Check:

Keys without values create empty memory addresses. Attention can point to them, but there's nothing to retrieve.
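A small sketch that makes this concrete (arbitrary shapes): with prefix keys but zero prefix values, the output is just the real-token values weighted by the leftover, sub-normalized attention mass.

import torch
import torch.nn.functional as F

d_k, n, m = 16, 4, 2
Q, K, V = torch.randn(1, n, d_k), torch.randn(1, n, d_k), torch.randn(1, n, d_k)
P_K, P_V = torch.randn(1, m, d_k), torch.zeros(1, m, d_k)     # keys only; values are zero

logits = Q @ torch.cat([P_K, K], dim=1).transpose(-2, -1) / d_k ** 0.5
weights = F.softmax(logits, dim=-1)
out = weights @ torch.cat([P_V, V], dim=1)                    # prefix rows contribute nothing
assert torch.allclose(out, weights[..., m:] @ V, atol=1e-6)   # only real values survive, scaled down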


Soft Prompt Implementation

import torch
import torch.nn as nn

class SoftPrompt(nn.Module):
    def __init__(self, prompt_length, embedding_dim):
        super(SoftPrompt, self).__init__()
        self.prompt_embeddings = nn.Embedding(
            prompt_length, embedding_dim
        )

    def forward(self, input_embeds):
        # input_embeds: [batch, seq_len, embedding_dim] (token embeddings, not raw ids)
        prompt_ids = torch.arange(
            self.prompt_embeddings.num_embeddings,
            device=input_embeds.device
        )
        prompt_embeddings = self.prompt_embeddings(prompt_ids)
        # broadcast the soft prompt across the batch and prepend it to the sequence
        prompt_embeddings = prompt_embeddings.unsqueeze(0).expand(
            input_embeds.size(0), -1, -1
        )
        return torch.cat(
            (prompt_embeddings, input_embeds), dim=1
        )
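
For example, with hypothetical shapes (a batch of 4 sequences of 32 token embeddings), the prompt is prepended along the sequence dimension:

soft_prompt = SoftPrompt(prompt_length=10, embedding_dim=768)
input_embeds = torch.randn(4, 32, 768)     # [batch, seq_len, embedding_dim]
print(soft_prompt(input_embeds).shape)     # torch.Size([4, 42, 768])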

Prefix Tuning Implementation

import torch
import torch.nn as nn

class PrefixTuning(nn.Module):
    def __init__(self, prefix_length, embedding_dim, num_layers):
        super(PrefixTuning, self).__init__()
        self.prefix_embeddings = nn.Embedding(
            prefix_length * num_layers, embedding_dim
        )
        self.prefix_length = prefix_length
        self.num_layers = num_layers

    def forward(self, input_ids):
        batch_size = input_ids.size(0)
        total_prefix_length = self.prefix_length * self.num_layers
        prefix_ids = torch.arange(total_prefix_length, device=input_ids.device)
        prefix_embeddings = self.prefix_embeddings(prefix_ids)
        # one prefix_length block per layer, broadcast across the batch:
        # [num_layers, prefix_length, embedding_dim] -> [batch, num_layers, prefix_length, embedding_dim]
        prefix_embeddings = prefix_embeddings.view(
            self.num_layers, self.prefix_length, -1
        )
        return prefix_embeddings.unsqueeze(0).expand(batch_size, -1, -1, -1)
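
For example, with hypothetical sizes, each of the 12 layers gets its own 10-slot prefix:

prefix_tuning = PrefixTuning(prefix_length=10, embedding_dim=768, num_layers=12)
input_ids = torch.randint(0, 50257, (4, 32))   # [batch, seq_len]
print(prefix_tuning(input_ids).shape)          # torch.Size([4, 12, 10, 768])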