💡 Dynamic Token-Level KV Cache Selection: Use Query-Key dot products to measure per-head KV Cache criticality at the token level.
💡 Per-head Soft Voting Mechanism: Compute the per-head criticality, normalize it through a softmax, and sum over all heads; this offers better performance and efficiency.
💡 Selection Cache: Allow consecutive similar queries to share token selection results, reducing the selection frequency while preserving its effectiveness (a sketch of all three ideas follows this list).
✅ TokenSelect – A model-agnostic, training-free method for efficient and accurate long-context inference. It selectively involves a small number of critical KV cache tokens in the attention calculation without sacrificing accuracy.
📊 Result – Up to
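The three ideas above fit together as follows. This is a minimal PyTorch sketch; all names, shapes, and the cosine-similarity threshold for the Selection Cache are illustrative assumptions, not TokenSelect's actual implementation:

```python
import torch

def token_criticality(q, k_cache):
    """Per-head soft voting over token criticality.

    q:       (num_heads, head_dim)          current decoding query
    k_cache: (num_heads, seq_len, head_dim) cached keys
    returns: (seq_len,) one criticality score per cached token
    """
    # Query-Key dot products give a per-head criticality for every token.
    per_head = torch.einsum("hd,hsd->hs", q, k_cache)   # (num_heads, seq_len)
    # Softmax puts each head's scores on a common scale; summing the
    # normalized scores pools the votes across heads.
    return torch.softmax(per_head, dim=-1).sum(dim=0)   # (seq_len,)

def select_tokens(q, k_cache, k_top, sel_cache, sim_threshold=0.9):
    """Top-k token selection with a Selection Cache: if the current query is
    similar enough to the one that produced the cached selection, reuse it
    instead of re-scoring the whole KV cache."""
    if sel_cache["query"] is not None:
        sim = torch.cosine_similarity(q.flatten(), sel_cache["query"].flatten(), dim=0)
        if sim >= sim_threshold:
            return sel_cache["indices"]                 # reuse previous selection
    scores = token_criticality(q, k_cache)
    indices = torch.topk(scores, k=min(k_top, scores.numel())).indices
    sel_cache["query"], sel_cache["indices"] = q, indices
    return indices

# Toy usage: only the selected critical tokens would enter the attention kernel.
num_heads, head_dim, seq_len = 32, 128, 8192
q = torch.randn(num_heads, head_dim)
k_cache = torch.randn(num_heads, seq_len, head_dim)
sel_cache = {"query": None, "indices": None}
idx = select_tokens(q, k_cache, k_top=512, sel_cache=sel_cache)
print(idx.shape)  # torch.Size([512])
```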
Performance Comparison on a single A100-80G. The prompt is:

```python
prompt = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 5000 + "The pass key is 71432. Remember it. 71432 is the pass key. " + "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 5000 + "What is the pass key?"
```

Feel free to replicate this using the provided scripts/serve.sh and benchmark/send_request.py. Please refer to our paper for more evaluation results.

(Demo video: comparison.mov)
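To reproduce the comparison by hand, you can also time a request end to end once the server described below is running. This is a minimal sketch; the port 62726 and the model name are taken from the serving example later in this README and may differ in your setup:

```python
import time

import openai

# Same passkey prompt as above.
filler = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. "
prompt = filler * 5000 + "The pass key is 71432. Remember it. 71432 is the pass key. " + filler * 5000 + "What is the pass key?"

client = openai.Client(base_url="http://127.0.0.1:62726/v1", api_key="None")
start = time.perf_counter()
response = client.chat.completions.create(
    model="Qwen/Qwen2-7B-Instruct",
    messages=[{"role": "user", "content": prompt}],
    temperature=0,
)
print(f"End-to-end latency: {time.perf_counter() - start:.1f}s")
print(response.choices[0].message.content)  # should mention 71432
```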
TokenSelect is built on top of SGLang and FlashInfer.
```bash
conda create -n sglang python=3.10
conda activate sglang
pip install torch==2.4.0 -i https://download.pytorch.org/whl/cu121
pip install flashinfer==0.1.6+cu121torch2.4 -i https://flashinfer.ai/whl/cu121/torch2.4
pip install -r requirements.txt
```
Launch the SGLang server with TokenSelect:

```bash
bash scripts/serve.sh
```

Send a request to the SGLang server using the OpenAI Python client. You can also use the benchmark/send_request.py script.
```python
import openai

client = openai.Client(base_url="http://127.0.0.1:62726/v1", api_key="None")
prompt = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 1000 + "The pass key is 71432. Remember it. 71432 is the pass key. " + "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. " * 1000 + "What is the pass key?"
response = client.chat.completions.create(
    model="Qwen/Qwen2-7B-Instruct",
    messages=[
        {"role": "user", "content": prompt},
    ],
    temperature=0,
)
print(response)
```

Download data from https://github.com/OpenBMB/InfiniteBench.
```bash
# using llama3
bash scripts/infinitebench-mp-llama.sh
# using qwen2
bash scripts/infinitebench-mp-qwen.sh
```

Download data from https://github.com/NVIDIA/RULER.
```bash
cd ruler
# usage: bash scripts/run.sh model_name benchmark_name config_name port (choose an idle port)
# using llama3
bash scripts/run.sh llama3-8b-inst synthetic llama-token-retrieval 63333
# using qwen2
bash scripts/run.sh qwen2-7b-inst synthetic qwen-token-retrieval 63333
```