@@ -421,8 +421,10 @@ def test_get_fresh_prompt(langfuse):
421
421
mock_server_call = langfuse .client .prompts .get
422
422
mock_server_call .return_value = prompt
423
423
424
- result = langfuse .get_prompt (prompt_name )
425
- mock_server_call .assert_called_once_with (prompt_name , version = None , label = None )
424
+ result = langfuse .get_prompt (prompt_name , fallback = "fallback" )
425
+ mock_server_call .assert_called_once_with (
426
+ prompt_name , version = None , label = None , request_options = {"max_retries" : 2 }
427
+ )
426
428
427
429
assert result == TextPromptClient (prompt )
428
430
@@ -480,7 +482,7 @@ def test_get_valid_cached_prompt(langfuse):
480
482
mock_server_call = langfuse .client .prompts .get
481
483
mock_server_call .return_value = prompt
482
484
483
- result_call_1 = langfuse .get_prompt (prompt_name )
485
+ result_call_1 = langfuse .get_prompt (prompt_name , fallback = "fallback" )
484
486
assert mock_server_call .call_count == 1
485
487
assert result_call_1 == prompt_client
486
488
@@ -742,3 +744,82 @@ def test_get_fresh_prompt_when_version_changes(langfuse):
742
744
result_call_2 = langfuse .get_prompt (prompt_name , version = 2 )
743
745
assert mock_server_call .call_count == 2
744
746
assert result_call_2 == version_changed_prompt_client
747
+
748
+
749
+ def test_do_not_return_fallback_if_fetch_success ():
750
+ langfuse = Langfuse ()
751
+ prompt_name = create_uuid ()
752
+ prompt_client = langfuse .create_prompt (
753
+ name = prompt_name ,
754
+ prompt = "test prompt" ,
755
+ labels = ["production" ],
756
+ )
757
+
758
+ second_prompt_client = langfuse .get_prompt (prompt_name , fallback = "fallback" )
759
+
760
+ assert prompt_client .name == second_prompt_client .name
761
+ assert prompt_client .version == second_prompt_client .version
762
+ assert prompt_client .prompt == second_prompt_client .prompt
763
+ assert prompt_client .config == second_prompt_client .config
764
+ assert prompt_client .config == {}
765
+
766
+
767
+ def test_fallback_text_prompt ():
768
+ langfuse = Langfuse ()
769
+
770
+ fallback_text_prompt = "this is a fallback text prompt with {{variable}}"
771
+
772
+ # Should throw an error if prompt not found and no fallback provided
773
+ with pytest .raises (Exception ):
774
+ langfuse .get_prompt ("nonexistent_prompt" )
775
+
776
+ prompt = langfuse .get_prompt ("nonexistent_prompt" , fallback = fallback_text_prompt )
777
+
778
+ assert prompt .prompt == fallback_text_prompt
779
+ assert (
780
+ prompt .compile (variable = "value" ) == "this is a fallback text prompt with value"
781
+ )
782
+
783
+
784
+ def test_fallback_chat_prompt ():
785
+ langfuse = Langfuse ()
786
+ fallback_chat_prompt = [
787
+ {"role" : "system" , "content" : "fallback system" },
788
+ {"role" : "user" , "content" : "fallback user name {{name}}" },
789
+ ]
790
+
791
+ # Should throw an error if prompt not found and no fallback provided
792
+ with pytest .raises (Exception ):
793
+ langfuse .get_prompt ("nonexistent_chat_prompt" , type = "chat" )
794
+
795
+ prompt = langfuse .get_prompt (
796
+ "nonexistent_chat_prompt" , type = "chat" , fallback = fallback_chat_prompt
797
+ )
798
+
799
+ assert prompt .prompt == fallback_chat_prompt
800
+ assert prompt .compile (name = "Jane" ) == [
801
+ {"role" : "system" , "content" : "fallback system" },
802
+ {"role" : "user" , "content" : "fallback user name Jane" },
803
+ ]
804
+
805
+
806
+ def test_do_not_link_observation_if_fallback ():
807
+ langfuse = Langfuse ()
808
+ trace_id = create_uuid ()
809
+
810
+ fallback_text_prompt = "this is a fallback text prompt with {{variable}}"
811
+
812
+ # Should throw an error if prompt not found and no fallback provided
813
+ with pytest .raises (Exception ):
814
+ langfuse .get_prompt ("nonexistent_prompt" )
815
+
816
+ prompt = langfuse .get_prompt ("nonexistent_prompt" , fallback = fallback_text_prompt )
817
+
818
+ langfuse .trace (id = trace_id ).generation (prompt = prompt , input = "this is a test input" )
819
+ langfuse .flush ()
820
+
821
+ api = get_api ()
822
+ trace = api .trace .get (trace_id )
823
+
824
+ assert len (trace .observations ) == 1
825
+ assert trace .observations [0 ].prompt_id is None
0 commit comments