@@ -492,7 +492,16 @@ def _handle_llm_end(
             return

         output = self._extract_output(response)
-        token_info = self._extract_token_info(response)
+
+        # Only extract token info if it hasn't been set during streaming
+        step = self.steps[run_id]
+        token_info = {}
+        if not (
+            hasattr(step, "prompt_tokens")
+            and step.prompt_tokens is not None
+            and step.prompt_tokens > 0
+        ):
+            token_info = self._extract_token_info(response)

         self._end_step(
             run_id=run_id,
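Why the guard matters: when usage arrives through streaming chunks, on_llm_new_token logs the counts before _handle_llm_end runs, and re-extracting from the final response (which often carries no usage) would clobber them. A minimal, runnable sketch of the same check, using a stand-in _Step class (an assumption for illustration, not the library's real step type):

class _Step:
    prompt_tokens = None  # populated by the streaming callback when usage is reported

def needs_extraction(step) -> bool:
    # Mirrors the guard above: only fall back to the final response when
    # streaming has not already recorded a positive prompt token count.
    return not (
        hasattr(step, "prompt_tokens")
        and step.prompt_tokens is not None
        and step.prompt_tokens > 0
    )

step = _Step()
assert needs_extraction(step)        # nothing streamed yet -> extract from response
step.prompt_tokens = 42              # as if on_llm_new_token had logged usage
assert not needs_extraction(step)    # streamed counts win; skip re-extraction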
@@ -763,6 +772,35 @@ def _handle_retriever_error(
         """Common logic for retriever error."""
         self._end_step(run_id=run_id, parent_run_id=parent_run_id, error=str(error))

+    def _handle_llm_new_token(self, token: str, **kwargs: Any) -> Any:
+        """Common logic for LLM new token."""
+        # Safely check for chunk and usage_metadata
+        chunk = kwargs.get("chunk")
+        if (
+            chunk
+            and hasattr(chunk, "message")
+            and hasattr(chunk.message, "usage_metadata")
+        ):
+            usage = chunk.message.usage_metadata
+
+            # Only proceed if usage is not None
+            if usage:
+                # Extract run_id from kwargs (should be provided by LangChain)
+                run_id = kwargs.get("run_id")
+                if run_id and run_id in self.steps:
+                    # Convert usage to the expected format like _extract_token_info does
+                    token_info = {
+                        "prompt_tokens": usage.get("input_tokens", 0),
+                        "completion_tokens": usage.get("output_tokens", 0),
+                        "tokens": usage.get("total_tokens", 0),
+                    }
+
+                    # Update the step with token usage information
+                    step = self.steps[run_id]
+                    if isinstance(step, steps.ChatCompletionStep):
+                        step.log(**token_info)
+        return
+

 class OpenlayerHandler(OpenlayerHandlerMixin, BaseCallbackHandlerClass):  # type: ignore[misc]
     """LangChain callback handler that logs to Openlayer."""
@@ -848,7 +886,7 @@ def on_llm_error(

     def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
         """Run on new LLM token. Only available when streaming is enabled."""
-        pass
+        return self._handle_llm_new_token(token, **kwargs)

     def on_chain_start(
         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
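With the delegation in place, streaming runs are logged end to end. A hypothetical usage sketch, assuming the model name, the stream_usage flag, and the import path (stream_usage asks langchain-openai to attach usage_metadata to streamed chunks):

from langchain_openai import ChatOpenAI  # assumed provider; any streaming chat model works

from openlayer.lib.integrations.langchain_callback import OpenlayerHandler  # assumed path

handler = OpenlayerHandler()
llm = ChatOpenAI(model="gpt-4o-mini", streaming=True, stream_usage=True)
for _ in llm.stream("Say hello", config={"callbacks": [handler]}):
    pass  # each chunk fires on_llm_new_token; usage is logged when a chunk carries it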
@@ -1137,7 +1175,7 @@ async def on_llm_error(
         return self._handle_llm_error(error, **kwargs)

     async def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
-        pass
+        return self._handle_llm_new_token(token, **kwargs)

     async def on_chain_start(
         self, serialized : Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
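The async handler gets the same delegation. A sketch of the async path under the same assumptions (AsyncOpenlayerHandler is an assumed class name, since the enclosing class sits outside this hunk); astream drives the async on_llm_new_token hook chunk by chunk:

import asyncio

from langchain_openai import ChatOpenAI
from openlayer.lib.integrations.langchain_callback import AsyncOpenlayerHandler  # assumed name/path

async def main() -> None:
    llm = ChatOpenAI(model="gpt-4o-mini", streaming=True, stream_usage=True)
    async for _ in llm.astream("Say hello", config={"callbacks": [AsyncOpenlayerHandler()]}):
        pass  # the async on_llm_new_token runs per chunk, logging usage as it appears

asyncio.run(main())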