Vishal Pandey | ML Research Engineer; Neuroscience

REFRAG: Recursive Fragmentation for Efficient Retrieval-Augmented Decoding

Problem Statement

Retrieval-Augmented Generation (RAG) pipelines empower LLMs by pulling external knowledge into the context window. But a fundamental issue persists: the retrieved documents together usually contain far more tokens than the model's context budget allows, so something must be truncated, dropped, or compressed.

Mathematically: if each retrieved document $D_j$ is mapped into fragments $f_{j,i}$ and the LLM can only handle $B$ tokens, we want to minimize the information loss

$$\mathcal{L} = \lVert F - \tilde{F} \rVert_2^2$$

where $F$ is the full set of fragment embeddings and $\tilde{F}$ is their compressed reconstruction, subject to $|\tilde{F}| \le B$.

This is essentially a low-rank approximation under a budget constraint.


Big Idea: Fragment then Compress

Instead of compressing entire documents:

  1. Fragment each document into semantically coherent chunks.
  2. Compress each chunk into a small latent embedding.
  3. Score relevance between query and compressed chunks.
  4. Select & Expand the most promising ones back to full token detail.

This way, the model filters through compressed summaries but attends to full tokens only for the relevant few.
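
A minimal sketch of step 1 under a naive sliding-window assumption (the window and stride sizes are arbitrary placeholders; a semantic-boundary splitter would be a drop-in replacement):

def fragment(document: str, window: int = 64, stride: int = 48) -> list[str]:
    # Naive whitespace "tokens" for illustration; a real tokenizer would be used in practice.
    tokens = document.split()
    fragments = []
    for start in range(0, max(len(tokens) - window, 0) + 1, stride):
        fragments.append(" ".join(tokens[start:start + window]))
    return fragments

doc = "Retrieval-augmented generation pulls external documents into the context window. " * 20
print(len(fragment(doc)), "fragments")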


Mathematical Intuitions

1. Fragment Embeddings

Each fragment $f_{j,i}$ is encoded as

$$F_{j,i} = \phi(f_{j,i}) \in \mathbb{R}^d$$

where $\phi$ is a pretrained encoder.
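
As one concrete way to realize $\phi$ (an illustrative assumption, not a requirement of ReFrag), a pretrained sentence encoder can embed each fragment; the model name and its output dimension below are arbitrary choices:

# Sketch: realizing the encoder phi with a pretrained sentence encoder.
# Assumes the sentence-transformers package is installed; the model choice is illustrative.
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer("all-MiniLM-L6-v2")   # d = 384 for this particular model

fragments = [
    "ReFrag splits each retrieved document into coherent chunks.",
    "Each chunk is compressed into a small latent embedding.",
]
F = encoder.encode(fragments, convert_to_tensor=True)   # shape: (num_fragments, d)
print(F.shape)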

2. Compression Mapping

We compress via linear projection:

$$z_{j,i} = W F_{j,i} + b, \qquad W \in \mathbb{R}^{d' \times d}, \quad d' < d$$

Interpretation: this reduces the dimension while preserving the dominant variance directions, provided $W$ is chosen well (e.g., from an SVD of the fragment embeddings).

Error bound (from matrix approximation theory):

$$\min_{W} \lVert F - W^\top W F \rVert_F^2 = \sum_{k > d'} \sigma_k^2$$

where $\sigma_k$ are the singular values of $F$. So the compression error is exactly the “energy” in the discarded dimensions.
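
A quick numerical sanity check of this bound (a sketch with arbitrary sizes): build the optimal $W$ from the top-$d'$ left singular vectors of $F$ and compare the reconstruction error against the discarded singular-value energy.

# Sketch: verifying the error bound for the linear compressor on a random matrix.
# F stores fragment embeddings as columns: shape (d, N). All sizes are arbitrary.
import torch

d, d_prime, N = 64, 16, 200
F = torch.randn(d, N)

U, S, Vh = torch.linalg.svd(F, full_matrices=False)
W = U[:, :d_prime].T                       # optimal projection: top-d' left singular vectors
F_hat = W.T @ (W @ F)                      # compress, then reconstruct

err = torch.norm(F - F_hat) ** 2           # ||F - W^T W F||_F^2
tail = (S[d_prime:] ** 2).sum()            # sum of discarded sigma_k^2
print(err.item(), tail.item())             # the two numbers agree up to float error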

3. Diversity-Preserving Selection

Selecting the top-$k$ fragments by relevance alone is not enough; we also need coverage. Define the objective:

$$\max_{S,\ |S|=k} \sum_{f \in S} q^\top z_f + \lambda \cdot \log\det\!\left(Z_S Z_S^\top\right)$$
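
The first term rewards relevance to the query; the $\log\det$ term rewards spread, since it grows when the selected embeddings $Z_S$ span diverse directions and collapses when they are nearly parallel (the same mechanism used in determinantal point processes). Exact maximization is combinatorial, so a greedy sketch like the one below is a natural baseline; it is an illustrative heuristic, not something prescribed by ReFrag.

# Sketch: greedy selection under the relevance + log-det diversity objective.
# q: (d',) query in compressed space; Z: (N, d') compressed fragment embeddings. Illustrative only.
import torch

def greedy_select(q, Z, k=5, lam=0.1, eps=1e-6):
    selected = []
    for _ in range(k):
        best_idx, best_gain = None, -float("inf")
        for i in range(Z.size(0)):
            if i in selected:
                continue
            S = selected + [i]
            gram = Z[S] @ Z[S].T + eps * torch.eye(len(S))   # regularized Gram matrix Z_S Z_S^T
            gain = (q @ Z[i] + lam * torch.logdet(gram)).item()
            if gain > best_gain:
                best_idx, best_gain = i, gain
        selected.append(best_idx)
    return selected

q, Z = torch.randn(128), torch.randn(50, 128)
print(greedy_select(q, Z, k=5))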

4. Cross-Attention Expansion

After selection, expand the chosen fragments back to full tokens and feed them to the LLM via cross-attention:

$$H = \mathrm{Attention}(Q, K, V), \qquad Q = W_Q\, q, \quad K = W_K Z_S, \quad V = W_V Z_S$$

This focuses LLM compute on the most relevant token spans.


The ReFrag Pipeline

[Figure: the ReFrag pipeline, fragment → compress → score → select & expand → cross-attend]


Complexity Analysis

Standard Full Attention

Feeding all retrieved content to the LLM means self-attending over every one of the $nC$ tokens ($n$ fragments of roughly $C$ tokens each), which costs $O\!\left((nC)^2 d\right)$ per attention layer.

ReFrag Attention

ReFrag encodes and compresses the $nC$ fragment tokens once (linear in the token count), scores the $n$ compressed embeddings, and cross-attends only over the tokens of the $k$ expanded fragments ($L$ tokens each).

Total cost:

$$O(nCd + kLd)$$

Break-even condition:
When only a small fraction of the retrieved tokens are ever expanded, i.e. $kL \ll nC$, ReFrag yields massive savings over full attention.
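
To make the break-even concrete, here is a back-of-envelope comparison; every number below is an illustrative assumption, not a measurement.

# Back-of-envelope cost comparison with made-up, illustrative numbers.
n, C = 100, 200      # 100 retrieved fragments, ~200 tokens each
k, L = 5, 200        # expand only 5 fragments back to full token detail
d = 768              # embedding dimension

full_attention = (n * C) ** 2 * d        # attend over all nC tokens: O((nC)^2 d)
refrag_cost = n * C * d + k * L * d      # compress everything + attend over kL tokens
print(f"full attention: {full_attention:.2e}")
print(f"refrag:         {refrag_cost:.2e}")
print(f"speedup:        {full_attention / refrag_cost:.0f}x")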


PyTorch Prototype

import torch
import torch.nn as nn

class FragmentCompressor(nn.Module):
    def __init__(self, d_in=768, d_out=128):
        super().__init__()
        self.linear = nn.Linear(d_in, d_out)
    def forward(self, x):
        return self.linear(x)

class ReFragPipeline(nn.Module):
    def __init__(self, d_in=768, d_out=128, d_model=256, n_heads=4):
        super().__init__()
        self.compressor = FragmentCompressor(d_in, d_out)
        self.cross_att = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.q_proj = nn.Linear(d_in, d_model)
        self.k_proj = nn.Linear(d_out, d_model)
        self.v_proj = nn.Linear(d_out, d_model)

    def forward(self, query, frags, topk=5):
        # query: (B, d_in) pooled query embedding, frags: (B, N, d_in) fragment embeddings
        z = self.compressor(frags)                      # (B, N, d_out): compress fragments
        q = self.q_proj(query)                          # (B, d_model): project query
        k_all = self.k_proj(z)                          # (B, N, d_model): keys for scoring
        scores = torch.einsum('bd,bnd->bn', q, k_all)   # relevance of each compressed fragment
        topk_idx = torch.topk(scores, k=topk, dim=1).indices
        batch_idx = torch.arange(frags.size(0)).unsqueeze(-1).expand_as(topk_idx)
        z_sel = z[batch_idx, topk_idx]                  # (B, topk, d_out): select top-k fragments

        Q = q.unsqueeze(1)                              # (B, 1, d_model)
        K = k_all[batch_idx, topk_idx]                  # reuse projected keys of the selected fragments
        V = self.v_proj(z_sel)
        out, _ = self.cross_att(Q, K, V)                # cross-attend query over selected fragments
        return out.squeeze(1)                           # (B, d_model)

# toy run
B, N, d_in = 2, 10, 768
frags = torch.randn(B, N, d_in)
query = torch.randn(B, d_in)
model = ReFragPipeline()
final_repr = model(query, frags)
print(final_repr.shape)   # torch.Size([2, 256])

Discussion & Open Questions

  1. Fragment granularity: what is the best way to split documents, sliding windows or semantic boundaries?
  2. Compression functions: a linear map is simple; could autoencoders yield better trade-offs?
  3. Adaptive budgets: can the expansion budget k depend dynamically on query difficulty? (A heuristic sketch follows this list.)
  4. Diversity term optimization: computing the exact determinant is expensive; efficient surrogates are needed.
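
On point 3, one speculative heuristic (not part of ReFrag as described above) is to let the shape of the relevance-score distribution set the budget: a peaked distribution suggests a few fragments suffice, a flat one suggests expanding more.

# Sketch: adaptive expansion budget driven by the entropy of fragment relevance scores.
# A speculative heuristic for discussion, not part of the pipeline above.
import torch

def adaptive_topk(scores, k_min=2, k_max=16):
    # scores: (N,) relevance scores of compressed fragments for one query
    p = torch.softmax(scores, dim=-1)
    entropy = -(p * p.clamp_min(1e-12).log()).sum()
    max_entropy = torch.log(torch.tensor(float(scores.numel())))
    frac = (entropy / max_entropy).item()   # 0 = peaked (confident), 1 = uniform (uncertain)
    return int(round(k_min + frac * (k_max - k_min)))

scores = torch.randn(50)
print(adaptive_topk(scores))   # budget grows with score uncertainty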

Key Takeaway

ReFrag reframes retrieval compression as a fragment-level low-rank approximation with adaptive expansion. By fragmenting, compressing, scoring, and then expanding, it preserves semantic signals at a fraction of the token cost.

In short: it’s not just about making things smaller; it’s about compressing the right fragments.