Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 6973824

Browse files
authored
feat(prompts): add fallback to get_prompts (langfuse#792)
1 parent 94c401a commit 6973824

File tree

4 files changed

+135
-13
lines changed

4 files changed

+135
-13
lines changed

‎langfuse/client.py‎

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
from langfuse.api.resources.prompts.types import (
5050
CreatePromptRequest_Chat,
5151
CreatePromptRequest_Text,
52+
Prompt_Text,
53+
Prompt_Chat,
5254
)
5355
from langfuse.api.resources.trace.types.traces import Traces
5456
from langfuse.api.resources.utils.resources.pagination.types.meta_response import (
@@ -876,6 +878,7 @@ def get_prompt(
876878
label: Optional[str] = None,
877879
type: Literal["chat"],
878880
cache_ttl_seconds: Optional[int] = None,
881+
fallback: Optional[List[ChatMessageDict]] = None,
879882
) -> ChatPromptClient: ...
880883

881884
@overload
@@ -887,6 +890,7 @@ def get_prompt(
887890
label: Optional[str] = None,
888891
type: Literal["text"] = "text",
889892
cache_ttl_seconds: Optional[int] = None,
893+
fallback: Optional[str] = None,
890894
) -> TextPromptClient: ...
891895

892896
def get_prompt(
@@ -897,6 +901,7 @@ def get_prompt(
897901
label: Optional[str] = None,
898902
type: Literal["chat", "text"] = "text",
899903
cache_ttl_seconds: Optional[int] = None,
904+
fallback: Union[Optional[List[ChatMessageDict]], Optional[str]] = None,
900905
) -> PromptClient:
901906
"""Get a prompt.
902907
@@ -914,6 +919,7 @@ def get_prompt(
914919
cache_ttl_seconds: Optional[int]: Time-to-live in seconds for caching the prompt. Must be specified as a
915920
keyword argument. If not set, defaults to 60 seconds.
916921
type: Literal["chat", "text"]: The type of the prompt to retrieve. Defaults to "text".
922+
fallback: Union[Optional[List[ChatMessageDict]], Optional[str]]: The prompt string to return if fetching the prompt fails. Important on the first call where no cached prompt is available. Follows Langfuse prompt formatting with double curly braces for variables. Defaults to None.
917923
918924
Returns:
919925
The prompt object retrieved from the cache or directly fetched if not cached or expired of type
@@ -936,9 +942,40 @@ def get_prompt(
936942
cached_prompt = self.prompt_cache.get(cache_key)
937943

938944
if cached_prompt is None:
939-
return self._fetch_prompt_and_update_cache(
940-
name, version=version, label=label, ttl_seconds=cache_ttl_seconds
941-
)
945+
self.log.debug(f"Prompt '{cache_key}' not found in cache.")
946+
try:
947+
return self._fetch_prompt_and_update_cache(
948+
name, version=version, label=label, ttl_seconds=cache_ttl_seconds
949+
)
950+
except Exception as e:
951+
if fallback:
952+
self.log.warn(
953+
f"Returning fallback prompt for '{cache_key}' due to fetch error: {e}"
954+
)
955+
956+
fallback_client_args = {
957+
"name": name,
958+
"prompt": fallback,
959+
"type": type,
960+
"version": version or 0,
961+
"config": {},
962+
"labels": [label] if label else [],
963+
"tags": [],
964+
}
965+
966+
if type == "text":
967+
return TextPromptClient(
968+
prompt=Prompt_Text(**fallback_client_args),
969+
is_fallback=True,
970+
)
971+
972+
if type == "chat":
973+
return ChatPromptClient(
974+
prompt=Prompt_Chat(**fallback_client_args),
975+
is_fallback=True,
976+
)
977+
978+
raise e
942979

943980
if cached_prompt.is_expired():
944981
try:
@@ -973,7 +1010,10 @@ def _fetch_prompt_and_update_cache(
9731010

9741011
self.log.debug(f"Fetching prompt '{cache_key}' from server...")
9751012
promptResponse = self.client.prompts.get(
976-
self._url_encode(name), version=version, label=label
1013+
self._url_encode(name),
1014+
version=version,
1015+
label=label,
1016+
request_options={"max_retries": 2},
9771017
)
9781018

9791019
if promptResponse.type == "chat":

‎langfuse/model.py‎

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,13 @@ class BasePromptClient(ABC):
6161
labels: List[str]
6262
tags: List[str]
6363

64-
def __init__(self, prompt: Prompt):
64+
def __init__(self, prompt: Prompt, is_fallback: bool = False):
6565
self.name = prompt.name
6666
self.version = prompt.version
6767
self.config = prompt.config
6868
self.labels = prompt.labels
6969
self.tags = prompt.tags
70+
self.is_fallback = is_fallback
7071

7172
@abstractmethod
7273
def compile(self, **kwargs) -> Union[str, List[ChatMessage]]:
@@ -127,8 +128,8 @@ def _compile_template_string(content: str, data: Dict[str, Any] = {}) -> str:
127128

128129

129130
class TextPromptClient(BasePromptClient):
130-
def __init__(self, prompt: Prompt_Text):
131-
super().__init__(prompt)
131+
def __init__(self, prompt: Prompt_Text, is_fallback: bool = False):
132+
super().__init__(prompt, is_fallback)
132133
self.prompt = prompt.prompt
133134

134135
def compile(self, **kwargs) -> str:
@@ -158,8 +159,8 @@ def get_langchain_prompt(self):
158159

159160

160161
class ChatPromptClient(BasePromptClient):
161-
def __init__(self, prompt: Prompt_Chat):
162-
super().__init__(prompt)
162+
def __init__(self, prompt: Prompt_Chat, is_fallback: bool = False):
163+
super().__init__(prompt, is_fallback)
163164
self.prompt = [
164165
ChatMessageDict(role=p.role, content=p.content) for p in prompt.prompt
165166
]

‎langfuse/utils/__init__.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def _get_timestamp():
2121
def _create_prompt_context(
2222
prompt: typing.Optional[PromptClient] = None,
2323
):
24-
if prompt is not None:
24+
if prompt is not None and not prompt.is_fallback:
2525
return {"prompt_version": prompt.version, "prompt_name": prompt.name}
2626

2727
return {"prompt_version": None, "prompt_name": None}

‎tests/test_prompt.py‎

Lines changed: 84 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -421,8 +421,10 @@ def test_get_fresh_prompt(langfuse):
421421
mock_server_call = langfuse.client.prompts.get
422422
mock_server_call.return_value = prompt
423423

424-
result = langfuse.get_prompt(prompt_name)
425-
mock_server_call.assert_called_once_with(prompt_name, version=None, label=None)
424+
result = langfuse.get_prompt(prompt_name, fallback="fallback")
425+
mock_server_call.assert_called_once_with(
426+
prompt_name, version=None, label=None, request_options={"max_retries": 2}
427+
)
426428

427429
assert result == TextPromptClient(prompt)
428430

@@ -480,7 +482,7 @@ def test_get_valid_cached_prompt(langfuse):
480482
mock_server_call = langfuse.client.prompts.get
481483
mock_server_call.return_value = prompt
482484

483-
result_call_1 = langfuse.get_prompt(prompt_name)
485+
result_call_1 = langfuse.get_prompt(prompt_name, fallback="fallback")
484486
assert mock_server_call.call_count == 1
485487
assert result_call_1 == prompt_client
486488

@@ -742,3 +744,82 @@ def test_get_fresh_prompt_when_version_changes(langfuse):
742744
result_call_2 = langfuse.get_prompt(prompt_name, version=2)
743745
assert mock_server_call.call_count == 2
744746
assert result_call_2 == version_changed_prompt_client
747+
748+
749+
def test_do_not_return_fallback_if_fetch_success():
750+
langfuse = Langfuse()
751+
prompt_name = create_uuid()
752+
prompt_client = langfuse.create_prompt(
753+
name=prompt_name,
754+
prompt="test prompt",
755+
labels=["production"],
756+
)
757+
758+
second_prompt_client = langfuse.get_prompt(prompt_name, fallback="fallback")
759+
760+
assert prompt_client.name == second_prompt_client.name
761+
assert prompt_client.version == second_prompt_client.version
762+
assert prompt_client.prompt == second_prompt_client.prompt
763+
assert prompt_client.config == second_prompt_client.config
764+
assert prompt_client.config == {}
765+
766+
767+
def test_fallback_text_prompt():
768+
langfuse = Langfuse()
769+
770+
fallback_text_prompt = "this is a fallback text prompt with {{variable}}"
771+
772+
# Should throw an error if prompt not found and no fallback provided
773+
with pytest.raises(Exception):
774+
langfuse.get_prompt("nonexistent_prompt")
775+
776+
prompt = langfuse.get_prompt("nonexistent_prompt", fallback=fallback_text_prompt)
777+
778+
assert prompt.prompt == fallback_text_prompt
779+
assert (
780+
prompt.compile(variable="value") == "this is a fallback text prompt with value"
781+
)
782+
783+
784+
def test_fallback_chat_prompt():
785+
langfuse = Langfuse()
786+
fallback_chat_prompt = [
787+
{"role": "system", "content": "fallback system"},
788+
{"role": "user", "content": "fallback user name {{name}}"},
789+
]
790+
791+
# Should throw an error if prompt not found and no fallback provided
792+
with pytest.raises(Exception):
793+
langfuse.get_prompt("nonexistent_chat_prompt", type="chat")
794+
795+
prompt = langfuse.get_prompt(
796+
"nonexistent_chat_prompt", type="chat", fallback=fallback_chat_prompt
797+
)
798+
799+
assert prompt.prompt == fallback_chat_prompt
800+
assert prompt.compile(name="Jane") == [
801+
{"role": "system", "content": "fallback system"},
802+
{"role": "user", "content": "fallback user name Jane"},
803+
]
804+
805+
806+
def test_do_not_link_observation_if_fallback():
807+
langfuse = Langfuse()
808+
trace_id = create_uuid()
809+
810+
fallback_text_prompt = "this is a fallback text prompt with {{variable}}"
811+
812+
# Should throw an error if prompt not found and no fallback provided
813+
with pytest.raises(Exception):
814+
langfuse.get_prompt("nonexistent_prompt")
815+
816+
prompt = langfuse.get_prompt("nonexistent_prompt", fallback=fallback_text_prompt)
817+
818+
langfuse.trace(id=trace_id).generation(prompt=prompt, input="this is a test input")
819+
langfuse.flush()
820+
821+
api = get_api()
822+
trace = api.trace.get(trace_id)
823+
824+
assert len(trace.observations) == 1
825+
assert trace.observations[0].prompt_id is None

0 commit comments

Comments (0)