Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 5d357ba

Browse files
authored
feat: enable Anthropic prompt caching on system prompt and tools (#69)
* feat: enable Anthropic prompt caching on system prompt and tools Mark the rendered system prompt and the tool block with cache_control breakpoints when calling Anthropic models. The static prefix (~4-5K tokens of system prompt + 15+ tool definitions) was being re-billed at full input rate on every turn, every retry, and every research sub-agent iteration (up to 60 per task). With ephemeral cache breakpoints, subsequent turns within the 5-minute TTL are billed at cache-read pricing (~10% of input cost). Expected savings: 40-50% input tokens on multi-turn conversations, 60-80% on research sub-agent loops. Caching is GA in the Anthropic API and natively supported by litellm 1.83+ via cache_control blocks (no beta header required). Non-Anthropic models (HF router, OpenAI) are passed through unchanged. The helper does not mutate the caller's message list or tool list, so the persisted ContextManager.items history stays in its original string-content form. * refactor: hoist prompt_caching imports to module level, drop cached_ prefix
1 parent e2552e8 commit 5d357ba

4 files changed

Lines changed: 77 additions & 4 deletions

File tree

‎agent/context_manager/manager.py‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from jinja2 import Template
1414
from litellm import Message, acompletion
1515

16+
from agent.core.prompt_caching import with_prompt_caching
17+
1618
logger = logging.getLogger(__name__)
1719

1820
_HF_WHOAMI_URL = "https://huggingface.co/api/whoami-v2"
@@ -114,6 +116,9 @@ async def summarize_messages(
114116

115117
prompt_messages = list(messages) + [Message(role="user", content=prompt)]
116118
llm_params = _resolve_llm_params(model_name, hf_token, reasoning_effort="high")
119+
prompt_messages, tool_specs = with_prompt_caching(
120+
prompt_messages, tool_specs, llm_params.get("model")
121+
)
117122
response = await acompletion(
118123
messages=prompt_messages,
119124
max_completion_tokens=max_tokens,

‎agent/core/agent_loop.py‎

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from agent.config import Config
1515
from agent.core.doom_loop import check_for_doom_loop
1616
from agent.core.llm_params import _resolve_llm_params
17+
from agent.core.prompt_caching import with_prompt_caching
1718
from agent.core.session import Event, OpType, Session
1819
from agent.core.tools import ToolRouter
1920
from agent.tools.jobs_tool import CPU_FLAVORS
@@ -296,6 +297,7 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
296297
"""Call the LLM with streaming, emitting assistant_chunk events."""
297298
response = None
298299
_healed_effort = False # one-shot safety net per call
300+
messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
299301
for _llm_attempt in range(_MAX_LLM_RETRIES):
300302
try:
301303
response = await acompletion(
@@ -390,6 +392,7 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
390392
"""Call the LLM without streaming, emit assistant_message at the end."""
391393
response = None
392394
_healed_effort = False
395+
messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
393396
for _llm_attempt in range(_MAX_LLM_RETRIES):
394397
try:
395398
response = await acompletion(
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Anthropic prompt caching breakpoints for outgoing LLM requests.
2+
3+
Caching is GA on Anthropic's API and natively supported by litellm >=1.83
4+
via ``cache_control`` blocks. We apply two breakpoints (out of 4 allowed):
5+
6+
1. The tool block — caches all tool definitions as a single prefix.
7+
2. The system message — caches the rendered system prompt.
8+
9+
Together these cover the ~4-5K static tokens that were being re-billed on
10+
every turn. Subsequent turns within the 5-minute TTL hit cache_read pricing
11+
(~10% of input cost) instead of full input.
12+
13+
Non-Anthropic models (HF router, OpenAI) are passed through unchanged.
14+
"""
15+
16+
from typing import Any
17+
18+
19+
def with_prompt_caching(
20+
messages: list[Any],
21+
tools: list[dict] | None,
22+
model_name: str | None,
23+
) -> tuple[list[Any], list[dict] | None]:
24+
"""Return (messages, tools) with cache_control breakpoints for Anthropic.
25+
26+
No-op for non-Anthropic models. Original objects are not mutated; a fresh
27+
list with replaced first message and last tool is returned, so callers
28+
that share the underlying ``ContextManager.items`` list don't see their
29+
persisted history rewritten.
30+
"""
31+
if not model_name or not model_name.startswith("anthropic/"):
32+
return messages, tools
33+
34+
if tools:
35+
new_tools = list(tools)
36+
last = dict(new_tools[-1])
37+
last["cache_control"] = {"type": "ephemeral"}
38+
new_tools[-1] = last
39+
tools = new_tools
40+
41+
if messages:
42+
first = messages[0]
43+
role = first.get("role") if isinstance(first, dict) else getattr(first, "role", None)
44+
if role == "system":
45+
content = (
46+
first.get("content")
47+
if isinstance(first, dict)
48+
else getattr(first, "content", None)
49+
)
50+
if isinstance(content, str) and content:
51+
cached_block = [{
52+
"type": "text",
53+
"text": content,
54+
"cache_control": {"type": "ephemeral"},
55+
}]
56+
new_first = {"role": "system", "content": cached_block}
57+
messages = [new_first] + list(messages[1:])
58+
59+
return messages, tools

‎agent/tools/research_tool.py‎

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from agent.core.doom_loop import check_for_doom_loop
1717
from agent.core.llm_params import _resolve_llm_params
18+
from agent.core.prompt_caching import with_prompt_caching
1819
from agent.core.session import Event
1920

2021
logger = logging.getLogger(__name__)
@@ -323,8 +324,9 @@ async def _log(text: str) -> None:
323324
),
324325
))
325326
try:
327+
_msgs, _ = with_prompt_caching(messages, None, llm_params.get("model"))
326328
response = await acompletion(
327-
messages=messages,
329+
messages=_msgs,
328330
tools=None, # no tools — force text response
329331
stream=False,
330332
timeout=120,
@@ -348,9 +350,12 @@ async def _log(text: str) -> None:
348350
))
349351

350352
try:
353+
_msgs, _tools = with_prompt_caching(
354+
messages, tool_specs if tool_specs else None, llm_params.get("model")
355+
)
351356
response = await acompletion(
352-
messages=messages,
353-
tools=tool_specs if tool_specs else None,
357+
messages=_msgs,
358+
tools=_tools,
354359
tool_choice="auto",
355360
stream=False,
356361
timeout=120,
@@ -446,8 +451,9 @@ async def _log(text: str) -> None:
446451
),
447452
))
448453
try:
454+
_msgs, _ = with_prompt_caching(messages, None, llm_params.get("model"))
449455
response = await acompletion(
450-
messages=messages,
456+
messages=_msgs,
451457
tools=None,
452458
stream=False,
453459
timeout=120,

0 commit comments

Comments
 (0)